This page was generated from docs/in-depth/api-low-level.ipynb.

🛠 Low-level Module API

The low-level API in Rockpool is designed for minimal efficient implementation of stateful neural networks.

The Module base class provides facilities for configuring, simulating and examining networks of stateful neurons.

Constructing a Module

All Module subclasses accept, at minimum, a shape argument on construction. This should completely specify the input, output and internal dimensionality of the Module, so that the code can determine how many neurons to generate and the sizes of the state variables and parameters.

Some Module subclasses allow you to specify the module shape by setting concrete parameter arrays, e.g. by passing a vector of shape (N,) as the bias parameter for a set of N neurons. These concrete parameter values will be used to initialise the Module, and if the Module is reset, the parameters will return to those concrete values.

Otherwise, all Module subclasses will set reasonable default initialisation values for the parameters.
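The reset behaviour described above can be sketched with a small, hypothetical `ToyParameter` class. This is not Rockpool's `Parameter`. It remembers either the concrete value it was given or the initialisation function for its shape, and restores that value on reset. The `shape` and `init_func` arguments echo the `Parameter` call in the `ffwd_net` example later on this page. The `data` argument and the `reset()` method are assumptions for illustration only.

```python
import numpy as np


class ToyParameter:
    """Hypothetical parameter that remembers how it was initialised."""

    def __init__(self, shape=None, init_func=None, data=None):
        if data is not None:
            # - Concrete value supplied: keep a private copy to restore on reset
            initial = np.array(data, dtype=float)
            self._initialise = lambda: initial.copy()
        else:
            # - No concrete value: use the initialisation function for this shape
            self._initialise = lambda: init_func(shape)
        self.value = self._initialise()

    def reset(self):
        # - Return to the concrete value, or to the default initialisation
        self.value = self._initialise()


# - Default initialisation from a function, as for a freshly built Module
tau_default = ToyParameter(shape=(4,), init_func=lambda s: np.full(s, 0.02))

# - Concrete initialisation, analogous to `Rate(4, tau=np.ones(4))`
tau_concrete = ToyParameter(data=np.ones(4))

tau_concrete.value *= 3.0  # pretend that training changed the parameter
tau_concrete.reset()
print(tau_concrete.value)  # [1. 1. 1. 1.]
print(tau_default.value)   # [0.02 0.02 0.02 0.02]
```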

[1]:
# - Switch off warnings
import warnings

warnings.filterwarnings("ignore")

# - Useful imports
try:
    from rich import print
except ImportError:
    # - Fall back on the built-in print if rich is not installed
    pass

# - Example of constructing a module
from rockpool.nn.modules import Rate
import numpy as np

# - Construct a Module with 4 neurons
mod = Rate(4)
print(mod)
Rate  with shape (4,)
[2]:
# - Construct a Module with concrete parameters
mod = Rate(4, tau=np.ones(4))
print(mod)
Rate  with shape (4,)

Evolving a Module

You evolve the state of a Module simply by calling it. Module subclasses expect clocked, rasterised data as numpy arrays with shape (T, Nin) or (batches, T, Nin). Here batches is the number of batches, T is the number of time steps, and Nin is the input size of the module, mod.size_in.

Calling a Module has the following syntax:

output, new_state, recorded_state = mod(input: np.array, record: bool = False)

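Calling the Module returns its output as a numpy array with shape (batches, T, Nout), where Nout is the output size of the module, mod.size_out. Input without a batch dimension is treated as a single batch, so the output still has a leading dimension of length 1 (see the example below).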

new_state will be a state dictionary containing the final state of the module, and all submodules, at the end of evolution. This will become more relevant when using the functional API (see [𝝺] Low-level functional API).

recorded_state is only populated if the argument record=True is passed to the module. In that case recorded_state will be a nested dictionary containing the recorded state of the module and all submodules. Each element in recorded_state has a time dimension of length T, the number of evolution time steps, followed by whatever dimensions are appropriate for that state variable. As the example below shows, a leading batch dimension may also be present.
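The calling convention above can be mimicked in plain NumPy. The sketch below is a hypothetical `ToyRate` class, not Rockpool's `Rate` module. It makes these assumptions:

- simple forward-Euler dynamics, tau * dx/dt = -x + input + bias;
- unbatched `(T, Nin)` input is treated as a single batch;
- each call returns the `(output, new_state, recorded_state)` triple, leaving `recorded_state` empty unless `record=True`.

```python
import numpy as np


class ToyRate:
    """Hypothetical rate module mimicking the Module calling convention.

    Not part of Rockpool: the dynamics are forward-Euler steps of
    tau * dx/dt = -x + input + bias, and the output is x itself.
    """

    def __init__(self, n, tau=0.02, bias=0.0, dt=1e-3):
        self.size_in = self.size_out = n
        self.tau = np.full(n, tau)    # parameter: time constants (s)
        self.bias = np.full(n, bias)  # parameter: biases
        self.dt = dt                  # simulation parameter: Euler step (s)
        self.x = np.zeros(n)          # state: activations at the start of evolution

    def __call__(self, input_data, record=False):
        input_data = np.asarray(input_data, dtype=float)

        # - Treat unbatched (T, Nin) input as a single batch: (1, T, Nin)
        if input_data.ndim == 2:
            input_data = input_data[np.newaxis, ...]
        batches, T, _ = input_data.shape

        # - Every batch starts from the stored state
        x = np.tile(self.x, (batches, 1))
        x_rec = np.zeros((batches, T, self.size_out))

        for t in range(T):
            # - One forward-Euler step for all batches at once
            x = x + self.dt / self.tau * (-x + input_data[:, t, :] + self.bias)
            x_rec[:, t, :] = x

        output = x_rec.copy()  # (batches, T, Nout)
        new_state = {"x": x}   # final state, shape (batches, N) in this toy
        recorded_state = {"x": x_rec} if record else {}
        return output, new_state, recorded_state


# - Usage, mirroring the cells below
toy = ToyRate(4)
inp = np.random.rand(5, toy.size_in)
out, new_state, rec = toy(inp, record=True)
print(out.shape, rec["x"].shape, new_state["x"].shape)  # (1, 5, 4) (1, 5, 4) (1, 4)
```

This toy never writes `new_state` back onto the object, so every call starts from the same initial state.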

[3]:
# - Generate and evolve over some input
T = 5
input = np.random.rand(T, mod.size_in)
output, _, _ = mod(input)
print(f"Output shape: {output.shape}")
Output shape: (1, 5, 4)
[4]:
# - Request the recorded state
output, _, recorded_state = mod(input, record=True)
print("Recorded state:", recorded_state)
Recorded state:
{
    'rec_input': array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]]),
    'x': array([[[0.00447242, 0.00335577, 0.00469451, 0.00364505],
        [0.00536816, 0.00415351, 0.00567921, 0.00439317],
        [0.00544684, 0.00446443, 0.00618463, 0.00441146],
        [0.00637465, 0.00499181, 0.00717   , 0.00518635],
        [0.00711199, 0.00548246, 0.00750611, 0.00532915]]])
}

Parameters, State and SimulationParameters

Rockpool defines three types of attributes for Modules: Parameter, State and SimulationParameter.

Parameters are, roughly, any values that you would consider part of the configuration of a network. If you need to tell someone else how to specify your network (without going into the details of the simulation backend), you tell them about your Parameters. Often the set of Parameters will be the trainable parameters of a network.

States are any internal values that must be maintained to track how the neurons, synapses or other components of the dynamical system in a Module evolve over time. Examples include neuron membrane potentials and synaptic currents.

SimulationParameters are attributes that need to be specified for simulation purposes, but which in theory shouldn't directly affect the network output and behaviour. For example, the time-step dt of a Module is required by a forward Euler ODE solver. However, the network configuration should be valid and usable regardless of what dt is set to, and you shouldn't need to specify dt when telling someone else about your network configuration.

One more useful wrapper class is Constant. Use it to wrap any model parameters that should never be trainable.

These classes are defined in rockpool.parameters.
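The NumPy sketch below makes the distinction between a Parameter and a SimulationParameter concrete. It integrates a single rate unit with forward Euler at two different step sizes. The dynamics, tau * dx/dt = -x + 1, are an assumed simplification rather than Rockpool's exact Rate equations. The time constant tau plays the role of a Parameter, and the step dt plays the role of a SimulationParameter. `euler_final_value()` is a hypothetical helper.

```python
import numpy as np


def euler_final_value(tau, dt, t_end, drive=1.0, x0=0.0):
    """Integrate tau * dx/dt = -x + drive with forward Euler up to t_end."""
    n_steps = int(round(t_end / dt))
    x = x0
    for _ in range(n_steps):
        x = x + dt / tau * (-x + drive)
    return x


tau = 20e-3     # Parameter: part of the network configuration
t_end = 100e-3  # simulate 100 ms

# - dt is a SimulationParameter: changing it should barely change the result
x_coarse = euler_final_value(tau, dt=1e-3, t_end=t_end)
x_fine = euler_final_value(tau, dt=0.1e-3, t_end=t_end)
x_exact = 1.0 - np.exp(-t_end / tau)  # analytic solution for x0 = 0, drive = 1

print(f"dt = 1 ms:   {x_coarse:.4f}")
print(f"dt = 0.1 ms: {x_fine:.4f}")
print(f"exact:       {x_exact:.4f}")
```

Both step sizes end within about 0.001 of the analytic value 1 - exp(-5). In that sense the network configuration (tau) stays valid whatever dt is, as long as dt is small compared with tau.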

Building a network with Module s

To build a complex network in Rockpool, you need to define your own Module subclass. Module takes care of many things for you, allowing you to define a network architecture without much overhead.

At minimum, you need to define a Module.__init__() method, which specifies network parameters (e.g. weights) and whichever submodules your network requires. The submodules take over the job of defining their own parameters and states.

You also need to define a Module.evolve() method, which contains the "plumbing" of your network. This method specifies how data is passed into your network, between submodules, and out again.

We’ll build a simple FFwd layer containing some weights and a set of neurons.

Note that this simple example doesn't properly return the updated module state or the recorded state. The next section shows how to do this.

[5]:
# - Build a simple network
from rockpool.nn.modules import Module
from rockpool.parameters import Parameter
from rockpool.nn.modules import RateJax


class ffwd_net(Module):
    # - Provide an `__init__` method to specify required parameters and modules
    #   Here you check, define and initialise whatever parameters and
    #   state you need for your module.
    def __init__(
        self,
        shape,
        *args,
        **kwargs,
    ):
        # - Call superclass initialisation
        #   This is always required for a `Module` class
        super().__init__(shape=shape, *args, **kwargs)

        # - Specify weights attribute
        #   We need a weights matrix for our input weights.
        #   We specify the shape explicitly, and provide an initialisation function.
        #   We also specify a family for the parameter, "weights". This is used to
        #   query parameters conveniently, and is a good idea to provide.
        self.w_ffwd = Parameter(
            shape=self.shape,
            init_func=lambda s: np.zeros(s),
            family="weights",
        )

        # - Specify and add a submodule
        #   These will be the neurons in our layer, to receive the weighted
        #   input signals. This sub-module will be automatically configured
        #   internally, to specify the required state and parameters
        self.neurons = RateJax(self.shape[-1])

    # - The `evolve` method contains the internal logic of your module
    #   `evolve` takes care of passing data in and out of the module,
    #   and between sub-modules if present.
    def evolve(self, input_data, *args, **kwargs):
        # - Pass input data through the input weights
        x = input_data @ self.w_ffwd

        # - Pass the signals through the neurons
        x, _, _ = self.neurons(x)

        # - Return the module output
        return x, {}, {}
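
The first line of `evolve()` relies on NumPy matrix multiplication broadcasting over leading axes. A `(batches, T, Nin)` input multiplied by an `(Nin, Nout)` weight matrix gives `(batches, T, Nout)`, which matches the `RateJax(self.shape[-1])` submodule. Here is a NumPy-only check with arbitrary sizes:

```python
import numpy as np

batches, T, n_in, n_out = 2, 5, 4, 6

w_ffwd = np.zeros((n_in, n_out))  # same zero initialisation as `init_func` above

# - Batched input: (batches, T, Nin) @ (Nin, Nout) -> (batches, T, Nout)
batched = np.random.rand(batches, T, n_in)
print((batched @ w_ffwd).shape)  # (2, 5, 6)

# - Unbatched input: (T, Nin) @ (Nin, Nout) -> (T, Nout)
unbatched = np.random.rand(T, n_in)
print((unbatched @ w_ffwd).shape)  # (5, 6)

# - With all-zero weights, the neurons initially receive zero weighted input
print(np.all(batched @ w_ffwd == 0))  # True
```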

Writing an evolve() method that returns state and record

To adhere to the Module API, your Module.evolve() method must return the updated set of states after evolution, and must support recording internal states if requested. The example below replaces the Module.evolve() method for the network above, illustrating how to conveniently do this.

[6]:
def evolve(self, input_data, record: bool = False, *args, **kwargs):
    # - Initialise state and record dictionaries
    new_state = {}
    recorded_state = {}

    # - Pass input data through the input weights
    x = input_data @ self.w_ffwd

    # - Add an internal signal record to the record dictionary
    if record:
        recorded_state["weighted_input"] = x

    # - Pass the signals through the neurons, passing through the `record` argument
    x, submod_state, submod_record = self.neurons(x, record=record)

    # - Store the submodule state under the submodule's attribute name
    new_state["neurons"] = submod_state

    # - Include the submodule's recorded state, nested in the same way
    recorded_state["neurons"] = submod_record

    # - Return the module output
    return x, new_state, recorded_state

Inspecting a Module

You can examine the internal parameters and state of a Module using a set of convenient inspection methods parameters(), state() and simulation_parameters().

params: dict = mod.parameters(family: str = None)
state: dict = mod.state(family: str = None)
simulation_parameters: dict = mod.simulation_parameters(family: str = None)

In each case the method returns a nested dictionary containing all registered attributes of that type for the module and all its submodules. Passing a family name restricts the result to attributes of that family.
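The nesting of the returned dictionaries, and the effect of the optional family argument, can be illustrated in plain Python. This is a hypothetical sketch, not Rockpool's implementation. The `(value, family)` tuples and the `filter_family()` helper are assumptions chosen for illustration.

```python
def filter_family(tree, family=None):
    """Collect values from a nested registry, optionally keeping one family only.

    `tree` maps attribute names either to `(value, family)` tuples (leaves)
    or to nested dicts (submodules). Submodules with no matches are dropped.
    """
    result = {}
    for name, item in tree.items():
        if isinstance(item, dict):
            # - Recurse into a submodule, keeping the nesting
            matches = filter_family(item, family)
            if matches:
                result[name] = matches
        else:
            value, item_family = item
            if family is None or item_family == family:
                result[name] = value
    return result


# - Hypothetical registry mirroring `ffwd_net`; "biases" and
#   "thresholds" are made-up family names for this illustration
registry = {
    "w_ffwd": ([[0.0] * 6 for _ in range(4)], "weights"),
    "neurons": {
        "tau": ([0.02] * 6, "taus"),
        "bias": ([0.0] * 6, "biases"),
        "threshold": ([0.0] * 6, "thresholds"),
    },
}

print(filter_family(registry, "taus"))  # {'neurons': {'tau': [0.02, 0.02, 0.02, 0.02, 0.02, 0.02]}}
print(list(filter_family(registry)))    # ['w_ffwd', 'neurons']
```

The real `my_mod.parameters("taus")` and `my_mod.parameters("weights")` calls behave the same way. Matches stay nested under their submodule name, and submodules with no matches do not appear.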

[7]:
# - Build a module for our network
my_mod = ffwd_net((4, 6))
print(my_mod)
ffwd_net  with shape (4, 6) {
    RateJax 'neurons' with shape (6,)
}
[8]:
# - Show module parameters
print("Parameters:", my_mod.parameters())
Parameters:
{
    'w_ffwd': array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]]),
    'neurons': {
        'tau': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32),
        'bias': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32),
        'threshold': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)
    }
}
[9]:
# - Show module state
print("State:", my_mod.state())
State:
{
    'neurons': {
        'rng_key': DeviceArray([1251626347,  511538859], dtype=uint32),
        'x': DeviceArray([0., 0., 0., 0., 0., 0.], dtype=float32)
    }
}
[10]:
# - Return parameters from particular families
print("Module time constants:", my_mod.parameters("taus"))
print("Module weights:", my_mod.parameters("weights"))
Module time constants:
{'neurons': {'tau': DeviceArray([0.02, 0.02, 0.02, 0.02, 0.02, 0.02], dtype=float32)}}
Module weights:
{
    'w_ffwd': array([[0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0.]])
}

You can, of course, access any attribute of a Module directly using standard Python "dot" syntax:

[11]:
# - Access parameters directly
print(".w_ffwd:", my_mod.w_ffwd)
print(".neurons.tau:", my_mod.neurons.tau)
.w_ffwd: [[0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0.]]
.neurons.tau: [0.02 0.02 0.02 0.02 0.02 0.02]

Module API reference

Every Module provides the following attributes:

Attribute        Description
class_name       The name of the subclass
name             The attribute name that this module was assigned to. Will be None for a base-level module
full_name        The class name and module name together. Useful for printing
spiking_input    If True, this module expects spiking input. Otherwise the input is real-valued
spiking_output   If True, this module produces spiking output. Otherwise the module outputs floating-point values
shape            The dimensions of the module. Can have any number of entries, for complex modules. shape[0] is the input dimensionality; shape[-1] is the output dimensionality
size_in          The number of input channels the module expects
size_out         The number of output channels the module produces
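
Following the shape row in the table above, size_in and size_out can be read off a module's shape. The helper below is hypothetical. It is applied to the shapes of the two modules built earlier on this page, plus an assumed three-entry shape:

```python
def sizes_from_shape(shape):
    """Return (size_in, size_out), using shape[0] as input and shape[-1] as output."""
    shape = tuple(shape)
    return shape[0], shape[-1]


print(sizes_from_shape((4,)))        # (4, 4)  e.g. Rate(4)
print(sizes_from_shape((4, 6)))      # (4, 6)  e.g. ffwd_net((4, 6))
print(sizes_from_shape((4, 10, 2)))  # (4, 2)  hypothetical module with an internal size of 10
```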

Every Module provides the following methods:

Method                   Description
parameters()             Return a nested dictionary of module parameters, optionally restricting the search to a particular family of parameters such as "weights"
state()                  Return a nested dictionary of module state
simulation_parameters()  Return a nested dictionary of module simulation parameters
modules()                Return a list of submodules of this module
attributes_named()       Search for and return nested attributes matching a particular name
set_attributes()         Set the parameter values for this module and its nested submodules
reset_state()            Reset the state of this module and its nested submodules
reset_parameters()       Reset the parameters of this module and its nested submodules to their initialisation defaults
_auto_batch()            Utility method to assist with handling batched data
timed()                  Convert this module to the high-level TimedModule API