"""Operators represent calculations that will occur in the simulation.
This code is adapted from sigops/operator.py and sigops/operators.py
(https://github.com/jaberg/sigops).
This modified code is included under the terms of their license:
Copyright (c) 2014, James Bergstra
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the
distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import numpy as np
import nengo.utils.numpy as npext
from nengo.exceptions import BuildError, SimulationError
class Operator(object):
    """Base class for operator instances understood by Nengo.

    Within a single simulator timestep, a `.Signal` may be subject to

    1. at most one set operator (optional)
    2. any number of increments
    3. any number of reads
    4. at most one update

    and these happen in exactly that order.

    A ``set`` establishes the signal's state at time :math:`t`, the start of
    the timestep. ``increment`` operations may then modify that state. A
    signal is only ``read`` once every increment has been applied. Finally,
    an ``update`` writes the state the signal should have at time
    :math:`t + dt`.

    Every operator records which signals it touches and which of these four
    kinds of access it performs on each one, so that the simulator can
    schedule all operators in a valid order.

    .. note:: The `~.Operator.reads`, `~.Operator.sets`, `~.Operator.incs`,
              and `~.Operator.updates` properties deliberately have no
              default values, so that subclasses are forced to set them
              explicitly.

    Parameters
    ----------
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    tag : str or None
        A label associated with the operator, for debugging purposes.
    """

    def __init__(self, tag=None):
        self.tag = tag

    def __repr__(self):
        return "<%s%s at 0x%x>" % (
            type(self).__name__, self._tagstr(), id(self))

    def __str__(self):
        # Only the non-empty description/tag pieces are shown.
        pieces = [piece for piece in (self._descstr(), self._tagstr())
                  if piece]
        return "%s{%s}" % (type(self).__name__, ' '.join(pieces))

    def _descstr(self):
        """Short description used by ``__str__``; empty by default."""
        return ''

    def _tagstr(self):
        """The tag wrapped in double quotes, or '' when there is no tag."""
        if self.tag is None:
            return ''
        return '"%s"' % self.tag

    @property
    def all_signals(self):
        """Every signal this operator touches.

        Concatenation of ``reads``, ``sets``, ``incs`` and ``updates``, in
        that order; a signal appearing in several of them is repeated.
        """
        return self.reads + self.sets + self.incs + self.updates

    @property
    def incs(self):
        """Signals incremented by this operator.

        Increments happen after the signal's set (if it has one) and before
        any reads.
        """
        return self._incs

    @incs.setter
    def incs(self, signals):
        self._incs = signals

    @property
    def reads(self):
        """Signals this operator reads without modifying.

        Reads happen after all increments and before the update.
        """
        return self._reads

    @reads.setter
    def reads(self, signals):
        self._reads = signals

    @property
    def sets(self):
        """Signals set by this operator.

        Sets happen first, ahead of any increments. A signal set here must
        not be set or updated by any other operator.
        """
        return self._sets

    @sets.setter
    def sets(self, signals):
        self._sets = signals

    @property
    def updates(self):
        """Signals updated by this operator.

        An update is the final operation applied to a signal in a timestep.
        """
        return self._updates

    @updates.setter
    def updates(self, signals):
        self._updates = signals

    def init_signals(self, signals):
        """Initialize the signals associated with this operator.

        Each signal in ``all_signals`` that is not already present in
        ``signals`` is initialized with ``signals.init``. Operator subclasses
        that use extra buffers should create them here.

        Parameters
        ----------
        signals : SignalDict
            A mapping from signals to their associated live ndarrays.
        """
        # Lazy generator: membership is re-checked right before each init,
        # so a signal listed twice is only initialized once.
        missing = (sig for sig in self.all_signals if sig not in signals)
        for sig in missing:
            signals.init(sig)

    def make_step(self, signals, dt, rng):
        """Return a callable that performs this operator's computation.

        Subclasses must implement this. Reading a subclass's ``make_step``
        is the most direct way to see what that operator actually does.

        Parameters
        ----------
        signals : SignalDict
            A mapping from signals to their associated live ndarrays.
        dt : float
            Length of each simulation timestep, in seconds.
        rng : `numpy.random.RandomState`
            Random number generator for stochastic operators.
        """
        raise NotImplementedError("subclasses must implement this method.")
class TimeUpdate(Operator):
    """Advances the simulation step counter and the simulation time.

    Implements ``step[...] += 1`` and ``time[...] = step * dt``.

    This is a dedicated operator (rather than a combination of `.Copy` and
    `.DotInc`) so that other backends can manage these key parts of the
    simulation state separately from ordinary signals.

    Parameters
    ----------
    step : Signal
        The signal associated with the integer step counter.
    time : Signal
        The signal associated with the time (a float, in seconds).
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    step : Signal
        The signal associated with the integer step counter.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    time : Signal
        The signal associated with the time (a float, in seconds).

    Notes
    -----
    1. sets ``[step, time]``
    2. incs ``[]``
    3. reads ``[]``
    4. updates ``[]``
    """

    def __init__(self, step, time, tag=None):
        self.step, self.time, self.tag = step, time, tag
        self.sets = [step, time]
        self.incs, self.reads, self.updates = [], [], []

    def make_step(self, signals, dt, rng):
        counter = signals[self.step]
        clock = signals[self.time]

        def step_timeupdate():
            counter[...] += 1
            # The time is derived from the step count after it was advanced.
            clock[...] = counter * dt
        return step_timeupdate
class PreserveValue(Operator):
    """Flags a signal as ``set`` for the graph checker without touching it.

    No computation is performed. Marking the signal as ``set`` lets other
    operators be applied to a signal whose value should carry over from one
    timestep to the next. Its main use is in learning rules.

    Parameters
    ----------
    dst : Signal
        The signal whose value we want to preserve.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The signal whose value we want to preserve.
    tag : str or None
        A label associated with the operator, for debugging purposes.

    Notes
    -----
    1. sets ``[dst]``
    2. incs ``[]``
    3. reads ``[]``
    4. updates ``[]``
    """

    def __init__(self, dst, tag=None):
        self.dst, self.tag = dst, tag
        self.sets, self.incs, self.reads, self.updates = [dst], [], [], []

    def _descstr(self):
        return str(self.dst)

    def make_step(self, signals, dt, rng):
        def step_preservevalue():
            """No-op: this operator only exists to mark ``dst`` as set."""
        return step_preservevalue
class Reset(Operator):
    """Fills a Signal with a constant value.

    Implements ``dst[...] = value``.

    Parameters
    ----------
    dst : Signal
        The Signal to reset.
    value : float, optional (Default: 0)
        The constant value to which ``dst`` is set (converted with
        ``float``).
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The Signal to reset.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    value : float
        The constant value to which ``dst`` is set.

    Notes
    -----
    1. sets ``[dst]``
    2. incs ``[]``
    3. reads ``[]``
    4. updates ``[]``
    """

    def __init__(self, dst, value=0, tag=None):
        self.dst, self.value, self.tag = dst, float(value), tag
        self.sets, self.incs, self.reads, self.updates = [dst], [], [], []

    def _descstr(self):
        return str(self.dst)

    def make_step(self, signals, dt, rng):
        out = signals[self.dst]
        fill_value = self.value

        def step_reset():
            out[...] = fill_value
        return step_reset
class Copy(Operator):
    """Copies the full contents of one signal into another.

    Implements ``dst[...] = src``.

    Parameters
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    src : Signal
        The signal that will be copied (read).
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    src : Signal
        The signal that will be copied (read).
    tag : str or None
        A label associated with the operator, for debugging purposes.

    Notes
    -----
    1. sets ``[dst]``
    2. incs ``[]``
    3. reads ``[src]``
    4. updates ``[]``
    """

    def __init__(self, dst, src, tag=None):
        self.dst, self.src, self.tag = dst, src, tag
        self.sets, self.incs = [dst], []
        self.reads, self.updates = [src], []

    def _descstr(self):
        return '%s -> %s' % (self.src, self.dst)

    def make_step(self, signals, dt, rng):
        target = signals[self.dst]
        source = signals[self.src]

        def step_copy():
            target[...] = source
        return step_copy
class SlicedCopy(Operator):
    """Assign the value of a slice of one signal to another slice.

    Implements ``dst[dst_slice] = src[src_slice]``.

    This operator can also implement ``dst[dst_slice] += src[src_slice]``
    using the parameter ``inc``. When incrementing through an index list
    that contains repeated indices, every contribution is accumulated
    (via ``np.add.at``) rather than only one of them.

    Parameters
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    src : Signal
        The signal that will be copied (read).
    dst_slice : slice, list of indices, or Ellipsis, optional
        (Default: Ellipsis)
        Slice or indices associated with ``dst``.
    src_slice : slice, list of indices, or Ellipsis, optional
        (Default: Ellipsis)
        Slice or indices associated with ``src``.
    inc : bool, optional (Default: False)
        Whether this should be an increment rather than a copy.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    dst_slice : list or Ellipsis
        Indices associated with ``dst``.
    src : Signal
        The signal that will be copied (read).
    src_slice : list or Ellipsis
        Indices associated with ``src``
    tag : str or None
        A label associated with the operator, for debugging purposes.

    Notes
    -----
    1. sets ``[] if inc else [dst]``
    2. incs ``[dst] if inc else []``
    3. reads ``[src]``
    4. updates ``[]``
    """

    def __init__(self, dst, src, dst_slice=Ellipsis, src_slice=Ellipsis,
                 inc=False, tag=None):
        # A plain ``slice`` is applied to the Signal itself here, so only
        # index lists or Ellipsis remain to be applied at run time.
        if isinstance(src_slice, slice):
            src = src[src_slice]
            src_slice = Ellipsis
        if isinstance(dst_slice, slice):
            dst = dst[dst_slice]
            dst_slice = Ellipsis
        # ^ src_slice and dst_slice are now either lists of indices or Ellipsis

        self.src = src
        self.dst = dst
        self.src_slice = src_slice
        self.dst_slice = dst_slice
        self.inc = inc
        self.tag = tag

        self.sets = [] if inc else [dst]
        self.incs = [dst] if inc else []
        self.reads = [src]
        self.updates = []

    def _descstr(self):
        return '%s[%s] -> %s[%s], inc=%s' % (
            self.src, self.src_slice, self.dst, self.dst_slice, self.inc)

    @staticmethod
    def _has_repeated_indices(dst, dst_slice):
        """Return True if integer ``dst_slice`` names an element twice."""
        if dst_slice is Ellipsis or np.ndim(dst) == 0:
            return False
        idx = np.asarray(dst_slice)
        if idx.dtype.kind not in 'iu':
            # Boolean masks select each element at most once; other index
            # types (incl. an empty list, which becomes float) are left as-is.
            return False
        # Canonicalize negative indices so that e.g. -1 and n-1 compare equal.
        idx = np.where(idx < 0, idx + np.shape(dst)[0], idx)
        return np.unique(idx).size < idx.size

    def make_step(self, signals, dt, rng):
        src = signals[self.src]
        dst = signals[self.dst]
        src_slice = self.src_slice
        dst_slice = self.dst_slice

        # Choose the step function once at build time instead of branching
        # on ``inc`` every timestep.
        if not self.inc:
            def step_slicedcopy():
                dst[dst_slice] = src[src_slice]
        elif self._has_repeated_indices(dst, dst_slice):
            # ``dst[idx] += v`` is buffered in NumPy, so a repeated index
            # receives only one increment; ``np.add.at`` applies them all.
            def step_slicedcopy():
                np.add.at(dst, dst_slice, src[src_slice])
        else:
            def step_slicedcopy():
                dst[dst_slice] += src[src_slice]
        return step_slicedcopy
class ElementwiseInc(Operator):
    """Increment signal ``Y`` by ``A * X`` (with broadcasting).

    Implements ``Y[...] += A * X``.

    Parameters
    ----------
    A : Signal
        The first signal to be multiplied.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    A : Signal
        The first signal to be multiplied.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.

    Notes
    -----
    1. sets ``[]``
    2. incs ``[Y]``
    3. reads ``[A, X]``
    4. updates ``[]``
    """

    def __init__(self, A, X, Y, tag=None):
        self.A = A
        self.X = X
        self.Y = Y
        self.tag = tag

        self.sets = []
        self.incs = [Y]
        self.reads = [A, X]
        self.updates = []

    def _descstr(self):
        return '%s, %s -> %s' % (self.A, self.X, self.Y)

    def make_step(self, signals, dt, rng):
        """Validate the operand shapes and return the increment function.

        Raises
        ------
        BuildError
            If any operand has more than two axes after
            ``npext.broadcast_shape``, or the shapes are not broadcast
            compatible with ``Y``.
        """
        A = signals[self.A]
        X = signals[self.X]
        Y = signals[self.Y]

        # check broadcasting shapes
        Ashape = npext.broadcast_shape(A.shape, 2)
        Xshape = npext.broadcast_shape(X.shape, 2)
        Yshape = npext.broadcast_shape(Y.shape, 2)
        if not all(len(s) == 2 for s in (Ashape, Xshape, Yshape)):
            # Explicit error instead of ``assert`` so the check is not
            # stripped under ``python -O``; the per-axis check below only
            # makes sense for equal-length (2-axis) shapes, since ``zip``
            # would silently truncate to the shortest one.
            raise BuildError("ElementwiseInc only supports signals with at "
                             "most 2 dimensions: trying to do %s += %s * %s"
                             % (Yshape, Ashape, Xshape))
        for da, dx, dy in zip(Ashape, Xshape, Yshape):
            if not (da in [1, dy] and dx in [1, dy] and max(da, dx) == dy):
                raise BuildError("Incompatible shapes in ElementwiseInc: "
                                 "Trying to do %s += %s * %s" %
                                 (Yshape, Ashape, Xshape))

        def step_elementwiseinc():
            Y[...] += A * X
        return step_elementwiseinc
def reshape_dot(A, X, Y, tag=None):
    """Check the shapes of ``Y += dot(A, X)`` and report if a reshape is needed.

    Parameters
    ----------
    A, X, Y : ndarray
        The operands of the increment.
    tag : str, optional
        Operator tag, used only in the error message.

    Returns
    -------
    bool
        True when both ``dot(A, X)`` and ``Y`` hold exactly one element, in
        which case the product must be reshaped to ``Y.shape`` before adding.

    Raises
    ------
    BuildError
        If the inner dimensions of ``A`` and ``X`` disagree, or if a
        non-scalar ``dot(A, X)`` does not have the same shape as ``Y``.
    """
    badshape = False
    ashape = (1,) if A.shape == () else A.shape
    xshape = (1,) if X.shape == () else X.shape

    # incshape is the shape np.dot(A, X) produces for these operands.
    if A.shape == ():
        incshape = X.shape
    elif X.shape == ():
        incshape = A.shape
    elif X.ndim == 1:
        badshape = ashape[-1] != xshape[0]
        incshape = ashape[:-1]
    else:
        badshape = ashape[-1] != xshape[-2]
        incshape = ashape[:-1] + xshape[:-2] + xshape[-1:]

    # A scalar result may be added into Y without a shape match, but
    # mismatched inner dimensions are an error even when the result would be
    # a scalar (previously they leaked out of np.dot as a raw ValueError).
    if badshape or (incshape != () and incshape != Y.shape):
        raise BuildError("shape mismatch in %s: %s x %s -> %s"
                         % (tag, A.shape, X.shape, Y.shape))

    # Reshape to handle case when np.dot(A, X) and Y are both scalars.
    # The size comes from incshape, avoiding a full product at build time.
    return int(np.prod(incshape)) == Y.size == 1
class DotInc(Operator):
    """Adds the matrix product ``dot(A, X)`` into signal ``Y``.

    Implements ``Y[...] += np.dot(A, X)``.

    .. note:: Only matrix-vector multiplies are supported at the moment,
              for compatibility with Nengo OCL.

    Parameters
    ----------
    A : Signal
        The first signal to be multiplied.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    A : Signal
        The first signal to be multiplied.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.

    Notes
    -----
    1. sets ``[]``
    2. incs ``[Y]``
    3. reads ``[A, X]``
    4. updates ``[]``
    """

    def __init__(self, A, X, Y, tag=None):
        # X and Y may have at most one non-trivial axis (the first one).
        for sig, msg in ((X, "X must be a column vector"),
                         (Y, "Y must be a column vector")):
            if sig.ndim >= 2 and any(d > 1 for d in sig.shape[1:]):
                raise BuildError(msg)

        self.A, self.X, self.Y, self.tag = A, X, Y, tag
        self.sets, self.incs = [], [Y]
        self.reads, self.updates = [A, X], []

    def _descstr(self):
        return '%s, %s -> %s' % (self.A, self.X, self.Y)

    def make_step(self, signals, dt, rng):
        X = signals[self.X]
        A = signals[self.A]
        Y = signals[self.Y]

        # Decide at build time whether the (single-element) product has to
        # be reshaped to Y's shape before it can be added.
        if reshape_dot(A, X, Y, self.tag):
            def step_dotinc():
                Y[...] += np.asarray(np.dot(A, X)).reshape(Y.shape)
        else:
            def step_dotinc():
                Y[...] += np.dot(A, X)
        return step_dotinc
class SimPyFunc(Operator):
    """Set a signal to a Python function with optional arguments.

    Implements ``output[...] = fn(*args)`` where ``args`` can
    include the current simulation time ``t`` and an input signal ``x``.

    Note that ``output`` may also be None, in which case the function is
    called but no output is captured.

    Parameters
    ----------
    output : Signal or None
        The signal to be set. If None, the function is still called.
    fn : callable
        The function to call. Any callable works, including ones without a
        ``__name__`` (e.g. ``functools.partial`` objects).
    t : Signal or None
        The signal associated with the time (a float, in seconds).
        If None, the time will not be passed to ``fn``.
    x : Signal or None
        An input signal to pass to ``fn``.
        If None, an input signal will not be passed to ``fn``.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    fn : callable
        The function to call.
    output : Signal or None
        The signal to be set. If None, the function is still called.
    t : Signal or None
        The signal associated with the time (a float, in seconds).
        If None, the time will not be passed to ``fn``.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    x : Signal or None
        An input signal to pass to ``fn``.
        If None, an input signal will not be passed to ``fn``.

    Notes
    -----
    1. sets ``[] if output is None else [output]``
    2. incs ``[]``
    3. reads ``([] if t is None else [t]) + ([] if x is None else [x])``
    4. updates ``[]``
    """

    def __init__(self, output, fn, t, x, tag=None):
        self.output = output
        self.fn = fn
        self.t = t
        self.x = x
        self.tag = tag

        self.sets = [] if output is None else [output]
        self.incs = []
        self.reads = ([] if t is None else [t]) + ([] if x is None else [x])
        self.updates = []

    def _descstr(self):
        # Fall back to the callable itself when it has no __name__, so that
        # str(op) never raises AttributeError.
        return '%s -> %s, fn=%r' % (
            self.x, self.output, getattr(self.fn, '__name__', self.fn))

    def make_step(self, signals, dt, rng):
        """Return a step function that calls ``fn`` and stores its result.

        The step function raises ``SimulationError`` if ``output`` is not
        None and ``fn`` returns None or a value that cannot be assigned to
        the output array.
        """
        fn = self.fn
        # Callables such as functools.partial have no __name__; use the
        # callable itself so that reporting an error cannot itself fail.
        fn_name = getattr(fn, '__name__', fn)
        output = signals[self.output] if self.output is not None else None
        t = signals[self.t] if self.t is not None else None
        x = signals[self.x] if self.x is not None else None

        def step_simpyfunc():
            # fn receives a copy, so it cannot modify the live input array.
            args = (np.copy(x),) if x is not None else ()
            # The time is passed as a Python scalar via ``.item()``.
            y = fn(t.item(), *args) if t is not None else fn(*args)

            if output is not None:
                if y is None:  # required since Numpy turns None into NaN
                    raise SimulationError(
                        "Function %r returned None" % fn_name)
                try:
                    output[...] = y
                except (ValueError, TypeError):
                    # TypeError covers objects NumPy cannot convert at all
                    # (e.g. a dict); ValueError covers shape mismatches.
                    raise SimulationError("Function %r returned invalid value "
                                          "%r" % (fn_name, y))
        return step_simpyfunc