# Source code for nengo.builder.operator

"""Operators represent calculations that will occur in the simulation.

This code is adapted from sigops/operator.py and sigops/operators.py
(https://github.com/jaberg/sigops).
This modified code is included under the terms of their license:

Copyright (c) 2014, James Bergstra
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the
   distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import numpy as np

import nengo.utils.numpy as npext
from nengo.exceptions import BuildError, SimulationError


class Operator(object):
    """Base class for operator instances understood by Nengo.

    During a single simulator timestep, a `.Signal` can experience

    1. at most one set operator (optional)
    2. any number of increments
    3. any number of reads
    4. at most one update

    in exactly this order.

    A ``set`` fixes the signal's state at time :math:`t`, the beginning of
    the timestep. ``increment`` operations may then modify that state. The
    signal is ``read`` only after every increment has been applied. Lastly,
    an ``update`` finalizes the state, giving the value that the signal
    should have at time :math:`t + dt`.

    Each operator keeps track of the signals it manipulates, and of which of
    these four kinds of manipulation it applies to each one, so that the
    simulator can order all of the operators correctly.

    .. note:: The `~.Operator.reads`, `~.Operator.sets`, `~.Operator.incs`,
              and `~.Operator.updates` properties intentionally have no
              default values, which forces subclasses to set them
              explicitly.

    Parameters
    ----------
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    tag : str or None
        A label associated with the operator, for debugging purposes.
    """

    def __init__(self, tag=None):
        self.tag = tag

    def __repr__(self):
        name = type(self).__name__
        return "<%s%s at 0x%x>" % (name, self._tagstr(), id(self))

    def __str__(self):
        # Keep only the non-empty pieces, so an operator with neither a
        # description nor a tag prints as ``ClassName{}``.
        pieces = [piece for piece in (self._descstr(), self._tagstr())
                  if piece]
        return "%s{%s}" % (type(self).__name__, ' '.join(pieces))

    def _descstr(self):
        """Short description of what the operator acts on ('' by default)."""
        return ''

    def _tagstr(self):
        """The tag wrapped in double quotes, or '' when there is no tag."""
        if self.tag is None:
            return ''
        return '"%s"' % self.tag

    @property
    def all_signals(self):
        """Concatenation of ``reads``, ``sets``, ``incs`` and ``updates``."""
        return self.reads + self.sets + self.incs + self.updates

    @property
    def incs(self):
        """Signals incremented by this operator.

        Increments are applied after the set (if the signal is set), and
        before any reads.
        """
        return self._incs

    @incs.setter
    def incs(self, val):
        self._incs = val

    @property
    def reads(self):
        """Signals that this operator reads but does not modify.

        Reads happen after increments, and before updates.
        """
        return self._reads

    @reads.setter
    def reads(self, val):
        self._reads = val

    @property
    def sets(self):
        """Signals set by this operator.

        Sets happen first, before increments. A signal set here must not be
        set or updated by any other operator.
        """
        return self._sets

    @sets.setter
    def sets(self, val):
        self._sets = val

    @property
    def updates(self):
        """Signals updated by this operator.

        An update is the final operation applied to a signal in a timestep.
        """
        return self._updates

    @updates.setter
    def updates(self, val):
        self._updates = val

    def init_signals(self, signals):
        """Initialize the signals associated with this operator.

        Every signal in ``all_signals`` that ``signals`` does not already
        contain is initialized into ``signals``. Operator subclasses that use
        extra buffers should create them here.

        Parameters
        ----------
        signals : SignalDict
            A mapping from signals to their associated live ndarrays.
        """
        for sig in self.all_signals:
            if sig in signals:
                # Already present: initialized elsewhere, or earlier in this
                # loop because the signal is listed more than once.
                continue
            signals.init(sig)

    def make_step(self, signals, dt, rng):
        """Return a callable that performs this operator's computation.

        Subclasses must implement this method. An operator's
        ``make_step`` is the authoritative description of what it does.

        Parameters
        ----------
        signals : SignalDict
            A mapping from signals to their associated live ndarrays.
        dt : float
            Length of each simulation timestep, in seconds.
        rng : `numpy.random.RandomState`
            Random number generator for stochastic operators.
        """
        raise NotImplementedError("subclasses must implement this method.")
class TimeUpdate(Operator):
    """Advance the simulation step counter and the simulation time.

    Implements ``step[...] += 1`` and ``time[...] = step * dt``.

    This is a dedicated operator, rather than a combination of `.Copy` and
    `.DotInc`, so that other backends can manage these central pieces of
    simulation state separately from ordinary signals.

    Parameters
    ----------
    step : Signal
        The signal associated with the integer step counter.
    time : Signal
        The signal associated with the time (a float, in seconds).
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    step : Signal
        The signal associated with the integer step counter.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    time : Signal
        The signal associated with the time (a float, in seconds).

    Notes
    -----
    1. sets ``[step, time]``
    2. incs ``[]``
    3. reads ``[]``
    4. updates ``[]``
    """

    def __init__(self, step, time, tag=None):
        self.step, self.time = step, time
        self.tag = tag

        # Both the counter and the clock are (re)set by this operator alone.
        self.sets, self.incs = [step, time], []
        self.reads, self.updates = [], []

    def make_step(self, signals, dt, rng):
        step_arr = signals[self.step]
        time_arr = signals[self.time]

        def step_timeupdate():
            step_arr[...] += 1
            # The time is derived from the counter that was just incremented.
            time_arr[...] = step_arr * dt

        return step_timeupdate
class PreserveValue(Operator):
    """Mark a signal as ``set`` for the graph checker.

    This operator performs no computation. It only flags ``dst`` as
    ``set``, so that other operators can be applied to signals whose value
    should be preserved across multiple timesteps. It is used primarily for
    learning rules.

    Parameters
    ----------
    dst : Signal
        The signal whose value we want to preserve.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The signal whose value we want to preserve.
    tag : str or None
        A label associated with the operator, for debugging purposes.

    Notes
    -----
    1. sets ``[dst]``
    2. incs ``[]``
    3. reads ``[]``
    4. updates ``[]``
    """

    def __init__(self, dst, tag=None):
        self.dst = dst
        self.tag = tag

        self.sets, self.incs = [dst], []
        self.reads, self.updates = [], []

    def _descstr(self):
        return str(self.dst)

    def make_step(self, signals, dt, rng):
        def step_preservevalue():
            # Intentionally empty: the operator's only purpose is the
            # ``sets`` bookkeeping done in ``__init__``.
            pass

        return step_preservevalue
class Reset(Operator):
    """Assign a constant value to a Signal.

    Implements ``dst[...] = value``.

    Parameters
    ----------
    dst : Signal
        The Signal to reset.
    value : float, optional (Default: 0)
        The constant value to which ``dst`` is set.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The Signal to reset.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    value : float
        The constant value to which ``dst`` is set.

    Notes
    -----
    1. sets ``[dst]``
    2. incs ``[]``
    3. reads ``[]``
    4. updates ``[]``
    """

    def __init__(self, dst, value=0, tag=None):
        self.dst = dst
        # Stored as a Python float regardless of the type passed in.
        self.value = float(value)
        self.tag = tag

        self.sets, self.incs = [dst], []
        self.reads, self.updates = [], []

    def _descstr(self):
        return str(self.dst)

    def make_step(self, signals, dt, rng):
        dst_arr = signals[self.dst]
        const = self.value

        def step_reset():
            dst_arr[...] = const

        return step_reset
class Copy(Operator):
    """Assign the value of one signal to another.

    Implements ``dst[...] = src``.

    Parameters
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    src : Signal
        The signal that will be copied (read).
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    src : Signal
        The signal that will be copied (read).
    tag : str or None
        A label associated with the operator, for debugging purposes.

    Notes
    -----
    1. sets ``[dst]``
    2. incs ``[]``
    3. reads ``[src]``
    4. updates ``[]``
    """

    def __init__(self, dst, src, tag=None):
        self.dst, self.src = dst, src
        self.tag = tag

        # dst is overwritten with src's value; src is only read.
        self.sets, self.incs = [dst], []
        self.reads, self.updates = [src], []

    def _descstr(self):
        return '%s -> %s' % (self.src, self.dst)

    def make_step(self, signals, dt, rng):
        dst_arr = signals[self.dst]
        src_arr = signals[self.src]

        def step_copy():
            dst_arr[...] = src_arr

        return step_copy
class SlicedCopy(Operator):
    """Assign the value of a slice of one signal to another slice.

    Implements ``dst[dst_slice] = src[src_slice]``.

    This operator can also implement ``dst[dst_slice] += src[src_slice]``
    using the parameter ``inc``. When incrementing through an integer index
    list that contains repeated indices, every repeat contributes to the
    increment (as with ``np.add.at``), not just the last one.

    Parameters
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    src : Signal
        The signal that will be copied (read).
    dst_slice : slice, list of indices, or Ellipsis, optional \
            (Default: Ellipsis)
        Slice associated with ``dst``. A ``slice`` is applied to ``dst``
        immediately and replaced by Ellipsis.
    src_slice : slice, list of indices, or Ellipsis, optional \
            (Default: Ellipsis)
        Slice associated with ``src``. A ``slice`` is applied to ``src``
        immediately and replaced by Ellipsis.
    inc : bool, optional (Default: False)
        Whether this should be an increment rather than a copy.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    dst : Signal
        The signal that will be assigned to (set).
    dst_slice : list or Ellipsis
        Indices associated with ``dst``.
    src : Signal
        The signal that will be copied (read).
    src_slice : list or Ellipsis
        Indices associated with ``src``.
    tag : str or None
        A label associated with the operator, for debugging purposes.

    Notes
    -----
    1. sets ``[] if inc else [dst]``
    2. incs ``[dst] if inc else []``
    3. reads ``[src]``
    4. updates ``[]``
    """

    def __init__(self, dst, src, dst_slice=Ellipsis, src_slice=Ellipsis,
                 inc=False, tag=None):
        if isinstance(src_slice, slice):
            src = src[src_slice]
            src_slice = Ellipsis
        if isinstance(dst_slice, slice):
            dst = dst[dst_slice]
            dst_slice = Ellipsis
        # ^ src_slice and dst_slice are now either lists of indices or Ellipsis

        self.src = src
        self.dst = dst
        self.src_slice = src_slice
        self.dst_slice = dst_slice
        self.inc = inc
        self.tag = tag

        self.sets = [] if inc else [dst]
        self.incs = [dst] if inc else []
        self.reads = [src]
        self.updates = []

    def _descstr(self):
        return '%s[%s] -> %s[%s], inc=%s' % (
            self.src, self.src_slice, self.dst, self.dst_slice, self.inc)

    @staticmethod
    def _has_repeated_indices(index, arr):
        """Whether 1-D integer ``index`` refers to an element of ``arr`` twice.

        Non-integer or non-1-D indices (e.g. Ellipsis or boolean masks) are
        reported as having no repeats.
        """
        if index is Ellipsis or np.ndim(arr) == 0:
            return False
        idx = np.asarray(index)
        if idx.ndim != 1 or idx.dtype.kind not in 'iu':
            return False
        # Normalize negative indices so that e.g. -1 and len(arr) - 1 are
        # recognized as the same element.
        idx = np.where(idx < 0, idx + arr.shape[0], idx)
        return np.unique(idx).size < idx.size

    def make_step(self, signals, dt, rng):
        src = signals[self.src]
        dst = signals[self.dst]
        src_slice = self.src_slice
        dst_slice = self.dst_slice

        if self.inc and self._has_repeated_indices(dst_slice, dst):
            # ``dst[idx] += v`` is buffered: with repeated indices only the
            # last contribution to each element would survive. ``np.add.at``
            # is unbuffered, so every repeat accumulates.
            def step_slicedcopy():
                np.add.at(dst, dst_slice, src[src_slice])
        elif self.inc:
            def step_slicedcopy():
                dst[dst_slice] += src[src_slice]
        else:
            def step_slicedcopy():
                dst[dst_slice] = src[src_slice]

        return step_slicedcopy
class ElementwiseInc(Operator):
    """Increment signal ``Y`` by ``A * X`` (with broadcasting).

    Implements ``Y[...] += A * X``.

    Parameters
    ----------
    A : Signal
        The first signal to be multiplied.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    A : Signal
        The first signal to be multiplied.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.

    Notes
    -----
    1. sets ``[]``
    2. incs ``[Y]``
    3. reads ``[A, X]``
    4. updates ``[]``
    """

    def __init__(self, A, X, Y, tag=None):
        self.A = A
        self.X = X
        self.Y = Y
        self.tag = tag

        self.sets = []
        self.incs = [Y]
        self.reads = [A, X]
        self.updates = []

    def _descstr(self):
        return '%s, %s -> %s' % (self.A, self.X, self.Y)

    def make_step(self, signals, dt, rng):
        """Validate the broadcast shapes and return the increment callable.

        Raises
        ------
        BuildError
            If the shapes of ``A``, ``X`` and ``Y`` are incompatible.
        """
        A = signals[self.A]
        X = signals[self.X]
        Y = signals[self.Y]

        # check broadcasting shapes
        Ashape = npext.broadcast_shape(A.shape, 2)
        Xshape = npext.broadcast_shape(X.shape, 2)
        Yshape = npext.broadcast_shape(Y.shape, 2)
        shape_msg = ("Incompatible shapes in ElementwiseInc: "
                     "Trying to do %s += %s * %s" % (Yshape, Ashape, Xshape))

        # An explicit exception rather than ``assert``: asserts are stripped
        # under ``python -O``, after which ``zip`` below would silently
        # truncate to the shortest shape.
        if any(len(s) != 2 for s in (Ashape, Xshape, Yshape)):
            raise BuildError(shape_msg)
        for da, dx, dy in zip(Ashape, Xshape, Yshape):
            # Each dim of A and X must be 1 or equal to Y's, and the
            # broadcast of A * X must cover Y exactly along that dim.
            if not (da in [1, dy] and dx in [1, dy] and max(da, dx) == dy):
                raise BuildError(shape_msg)

        def step_elementwiseinc():
            Y[...] += A * X

        return step_elementwiseinc
def reshape_dot(A, X, Y, tag=None):
    """Check that ``Y += np.dot(A, X)`` is well-formed and whether to reshape.

    Validates the shapes of ``A`` and ``X`` against each other and against
    ``Y``, and reports whether the result of ``np.dot(A, X)`` must be
    reshaped to ``Y.shape`` before being added into ``Y``.

    Parameters
    ----------
    A, X : ndarray
        The operands of ``np.dot(A, X)``.
    Y : ndarray
        The array the dot product will be added into.
    tag : str, optional (Default: None)
        Label included in the error message.

    Returns
    -------
    bool
        True when both the dot product and ``Y`` contain exactly one
        element, i.e. the caller should reshape the product to ``Y.shape``.

    Raises
    ------
    BuildError
        If the inner dimensions of ``A`` and ``X`` do not match, or if the
        (non-scalar) result shape differs from ``Y.shape``.
    """
    badshape = False
    if A.shape == ():
        incshape = X.shape
    elif X.shape == ():
        incshape = A.shape
    elif X.ndim == 1:
        badshape = A.shape[-1] != X.shape[0]
        incshape = A.shape[:-1]
    else:
        # Same contraction rule as np.dot for N-D operands: last axis of A
        # against the second-to-last axis of X.
        badshape = A.shape[-1] != X.shape[-2]
        incshape = A.shape[:-1] + X.shape[:-2] + X.shape[-1:]

    # A scalar result (incshape == ()) is accepted for any Y shape, but
    # mismatched inner dimensions are always an error -- previously they
    # slipped through whenever the result would have been scalar.
    if badshape or (incshape != Y.shape and incshape != ()):
        raise BuildError("shape mismatch in %s: %s x %s -> %s"
                         % (tag, A.shape, X.shape, Y.shape))

    # Reshape to handle case when np.dot(A, X) and Y are both scalars.
    # The size of the product follows from ``incshape``, so there is no need
    # to evaluate the (possibly large) product here.
    return int(np.prod(incshape)) == Y.size == 1
class DotInc(Operator):
    """Increment signal ``Y`` by ``dot(A, X)``.

    Implements ``Y[...] += np.dot(A, X)``.

    .. note:: Only matrix-vector multiplies are currently supported, for
              compatibility with Nengo OCL.

    Parameters
    ----------
    A : Signal
        The first signal to be multiplied.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    A : Signal
        The first signal to be multiplied.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    X : Signal
        The second signal to be multiplied.
    Y : Signal
        The signal to be incremented.

    Notes
    -----
    1. sets ``[]``
    2. incs ``[Y]``
    3. reads ``[A, X]``
    4. updates ``[]``
    """

    def __init__(self, A, X, Y, tag=None):
        # X and Y must be (column) vectors: every axis after the first may
        # have length at most 1.
        if X.ndim > 1 and max(X.shape[1:]) > 1:
            raise BuildError("X must be a column vector")
        if Y.ndim > 1 and max(Y.shape[1:]) > 1:
            raise BuildError("Y must be a column vector")

        self.A, self.X, self.Y = A, X, Y
        self.tag = tag

        self.sets, self.incs = [], [Y]
        self.reads, self.updates = [A, X], []

    def _descstr(self):
        return '%s, %s -> %s' % (self.A, self.X, self.Y)

    def make_step(self, signals, dt, rng):
        X = signals[self.X]
        A = signals[self.A]
        Y = signals[self.Y]
        # Validates shapes and tells us whether a one-element product must
        # be reshaped to match a one-element Y.
        needs_reshape = reshape_dot(A, X, Y, self.tag)

        def step_dotinc():
            product = np.dot(A, X)
            if needs_reshape:
                product = np.asarray(product).reshape(Y.shape)
            Y[...] += product

        return step_dotinc
class SimPyFunc(Operator):
    """Set a signal to a Python function with optional arguments.

    Implements ``output[...] = fn(*args)`` where ``args`` can include the
    current simulation time ``t`` and an input signal ``x``.

    Note that ``output`` may also be None, in which case the function is
    called but no output is captured.

    Parameters
    ----------
    output : Signal or None
        The signal to be set. If None, the function is still called.
    fn : callable
        The function to call. It does not need a ``__name__`` attribute
        (e.g. ``functools.partial`` objects are fine); messages then use the
        name of its class.
    t : Signal or None
        The signal associated with the time (a float, in seconds).
        If None, the time will not be passed to ``fn``.
    x : Signal or None
        An input signal to pass to ``fn``.
        If None, an input signal will not be passed to ``fn``.
    tag : str, optional (Default: None)
        A label associated with the operator, for debugging purposes.

    Attributes
    ----------
    fn : callable
        The function to call.
    output : Signal or None
        The signal to be set. If None, the function is still called.
    t : Signal or None
        The signal associated with the time (a float, in seconds).
        If None, the time will not be passed to ``fn``.
    tag : str or None
        A label associated with the operator, for debugging purposes.
    x : Signal or None
        An input signal to pass to ``fn``.
        If None, an input signal will not be passed to ``fn``.

    Notes
    -----
    1. sets ``[] if output is None else [output]``
    2. incs ``[]``
    3. reads ``([] if t is None else [t]) + ([] if x is None else [x])``
    4. updates ``[]``
    """

    def __init__(self, output, fn, t, x, tag=None):
        self.output = output
        self.fn = fn
        self.t = t
        self.x = x
        self.tag = tag

        self.sets = [] if output is None else [output]
        self.incs = []
        self.reads = ([] if t is None else [t]) + ([] if x is None else [x])
        self.updates = []

    @staticmethod
    def _fn_name(fn):
        """Name of ``fn`` for messages, falling back to its class name.

        Not every callable has ``__name__`` (e.g. ``functools.partial``
        objects or instances defining ``__call__``).
        """
        return getattr(fn, '__name__', type(fn).__name__)

    def _descstr(self):
        return '%s -> %s, fn=%r' % (self.x, self.output, self._fn_name(self.fn))

    def make_step(self, signals, dt, rng):
        fn = self.fn
        fn_name = self._fn_name(fn)
        output = signals[self.output] if self.output is not None else None
        t = signals[self.t] if self.t is not None else None
        x = signals[self.x] if self.x is not None else None

        def step_simpyfunc():
            # Pass a copy so that ``fn`` cannot mutate the live input signal.
            args = (np.copy(x),) if x is not None else ()
            y = fn(t.item(), *args) if t is not None else fn(*args)

            if output is not None:
                if y is None:  # required since Numpy turns None into NaN
                    raise SimulationError(
                        "Function %r returned None" % fn_name)
                try:
                    output[...] = y
                except ValueError:
                    raise SimulationError("Function %r returned invalid value "
                                          "%r" % (fn_name, y))

        return step_simpyfunc