Source code for nengo.spa.basalganglia
import numpy as np
import nengo
from nengo.exceptions import ValidationError
from nengo.spa.action_objects import DotProduct, Source
from nengo.spa.module import Module
from nengo.utils.compat import is_number
class BasalGanglia(Module):
    """A basal ganglia, performing action selection on a set of given actions.

    See `.networks.BasalGanglia` for more details.

    Parameters
    ----------
    actions : Actions
        The actions to choose between.
    input_synapse : float, optional (Default: 0.002)
        The synaptic filter on all input connections.
    label : str, optional (Default: None)
        A name for the network. Used for debugging and visualization.
    seed : int, optional (Default: None)
        The seed used for random number generation.
    add_to_container : bool, optional (Default: None)
        Determines if this Network will be added to the current container.
        If None, will be true if currently in a Network context.
    """

    def __init__(self, actions, input_synapse=0.002,
                 label=None, seed=None, add_to_container=None):
        self.actions = actions
        self.input_synapse = input_synapse
        # Created lazily by the ``bias`` property, only if some action
        # actually needs a constant utility.
        self._bias = None
        Module.__init__(self, label, seed, add_to_container)
        # Build the underlying basal ganglia network inside this module,
        # with one input dimension per action.
        nengo.networks.BasalGanglia(dimensions=self.actions.count, net=self)

    @property
    def bias(self):
        """Create a bias node, when needed."""
        if self._bias is None:
            with self:
                self._bias = nengo.Node([1], label="basal ganglia bias")
        return self._bias

    def on_add(self, spa):
        """Form the connections into the BG to compute the utility values.

        Each action's condition variable contains the set of computations
        needed for that action's utility value, which is the input to the
        basal ganglia.

        Parameters
        ----------
        spa
            The object this module is being added to. It is stored as
            ``self.spa``; the ``add_*_input`` methods use it to look up
            module outputs and as the context for new connections.

        Raises
        ------
        NotImplementedError
            If a condition contains an inverted source inside a dot
            product, a dot product between two sources, or any other
            unsupported subexpression.
        """
        Module.on_add(self, spa)
        self.spa = spa

        self.actions.process(spa)  # parse the actions

        for i, action in enumerate(self.actions.actions):
            cond = action.condition.expression
            # the basal ganglia handles the condition part of the action;
            # the effect is handled by the thalamus

            # Note: A Source is an output from a module, and a Symbol is
            # text that can be parsed to be a SemanticPointer
            for c in cond.items:
                if isinstance(c, DotProduct):
                    if ((isinstance(c.item1, Source) and c.item1.inverted) or
                            (isinstance(c.item2, Source) and
                             c.item2.inverted)):
                        raise NotImplementedError(
                            "Inversion in subexpression '%s' from action '%s' "
                            "is not supported by the Basal Ganglia." %
                            (c, action))
                    if isinstance(c.item1, Source):
                        if isinstance(c.item2, Source):
                            # dot product between two different sources
                            self.add_compare_input(i, c.item1, c.item2,
                                                   c.scale)
                        else:
                            self.add_dot_input(i, c.item1, c.item2, c.scale)
                    else:
                        # enforced in DotProduct constructor
                        assert isinstance(c.item2, Source)
                        self.add_dot_input(i, c.item2, c.item1, c.scale)
                elif isinstance(c, Source):
                    self.add_scalar_input(i, c)
                elif is_number(c):
                    self.add_bias_input(i, c)
                else:
                    raise NotImplementedError(
                        "Subexpression '%s' from action '%s' is not supported "
                        "by the Basal Ganglia." % (c, action))

    def add_bias_input(self, index, value):
        """Make an input that is just a fixed scalar value.

        Parameters
        ----------
        index : int
            The index of the action.
        value : float or int
            The fixed utility value to add.
        """
        with self.spa:
            nengo.Connection(self.bias, self.input[index:index+1],
                             transform=value, synapse=self.input_synapse)

    def add_compare_input(self, index, source1, source2, scale):
        """Reject an input that is the dot product of two different sources.

        This would correspond to an input action such as
        ``dot(vision, memory)``. It is deliberately unsupported: the method
        makes no connections and always raises.

        Parameters
        ----------
        index : int
            The index of the action.
        source1 : Source
            The first module output in the dot product.
        source2 : Source
            The second module output in the dot product.
        scale : float
            The scaling factor attached to the dot product.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Compare between two sources will never be "
                                  "implemented as discussed in "
                                  "https://github.com/nengo/nengo/issues/759")

    def add_dot_input(self, index, source, symbol, scale):
        """Make an input that is the dot product of a Source and a Symbol.

        This would be used for an input action such as ``dot(vision, A)``.
        The source may have a transformation applied first.

        Parameters
        ----------
        index : int
            The index of the action.
        source : Source
            The module output to read from.
        symbol : Symbol
            The non-Source operand of the dot product; its ``symbol`` text
            is parsed with the source's vocabulary to get the semantic
            pointer to compare against.
        scale : float
            A scaling factor to be applied to the result.
        """
        output, vocab = self.spa.get_module_output(source.name)
        # the first transformation, to handle dot(vision*A, B)
        t1 = vocab.parse(source.transform.symbol).get_convolution_matrix()
        # the linear transform to compute the fixed dot product
        t2 = np.array([vocab.parse(symbol.symbol).v*scale])

        # Fold both steps into a single transform on the connection.
        transform = np.dot(t2, t1)

        with self.spa:
            nengo.Connection(output, self.input[index:index+1],
                             transform=transform, synapse=self.input_synapse)

    def add_scalar_input(self, index, source):
        """Add a scalar input that will vary over time.

        This is used for the output of the `.Compare` module.

        Parameters
        ----------
        index : int
            The index of the action.
        source : Source
            The module output to read from.

        Raises
        ------
        NotImplementedError
            If the source's output is not 1-dimensional.
        ValidationError
            If the source's transform does not evaluate to a scalar.
        """
        output, _ = self.spa.get_module_output(source.name)
        if output.size_out != 1:
            raise NotImplementedError(
                "Only 1-dimensional sources can be scalar inputs")

        # NOTE(review): ``eval`` executes the transform text as Python. That
        # text ultimately comes from the action expressions; confirm those
        # are never built from untrusted input.
        try:
            scale = float(eval(source.transform.symbol))
        except (ValueError, TypeError, NameError, SyntaxError):
            # A non-scalar transform (e.g. a semantic pointer name) fails in
            # ``eval`` with NameError/SyntaxError, or in ``float`` with
            # TypeError/ValueError; report all of them uniformly.
            raise ValidationError("Transform must be scalar; got '%s'"
                                  % source.transform.symbol,
                                  attr='source.transform')

        with self.spa:
            nengo.Connection(output, self.input[index:index+1],
                             transform=scale,
                             synapse=self.input_synapse)