This page was generated from docs/in-depth/api-functional.ipynb.

[𝝺] Low-level functional API

Rockpool Modules and the JaxModule base class support a functional form for manipulating parameters and for evolution. This is particularly important when using Jax, since this library requires a functional programming style.

Functional evolution

First let’s set up a module to play with:

[1]:
# - Switch off warnings
import warnings

warnings.filterwarnings("ignore")

# - Rockpool imports
from rockpool.nn.modules import RateJax

# - Other useful imports
import numpy as np

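try:
    # - Use rich for nicer printing, if it is installed
    from rich import print
except ImportError:
    # - Otherwise fall back to the built-in print
    pass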

# - Construct a module
N = 3
mod = RateJax(N)

Now if we evolve the module, we get the outputs we expect:

[2]:
# - Set up some input
T = 10
input = np.random.rand(T, N)
output, new_state, record = mod(input)
[3]:
print("output:", output)
output: [[[0.01076346 0.02204564 0.02956902]
  [0.04010231 0.06479559 0.07025937]
  [0.0432203  0.06293815 0.11122491]
  [0.06027619 0.08317991 0.1307195 ]
  [0.05882018 0.09205481 0.15850767]
  [0.08602361 0.12148949 0.16591538]
  [0.10797263 0.11710763 0.18505278]
  [0.13634151 0.13502166 0.20233528]
  [0.15542242 0.15269144 0.20374872]
  [0.16027772 0.15767853 0.19997193]]]
[4]:
print("new_state:", new_state)
new_state:
{
    'x': DeviceArray([0.16027772, 0.15767853, 0.19997193], dtype=float32),
    'rng_key': DeviceArray([2469880657, 3700232383], dtype=uint32)
}
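Functional evolution means that state goes in as an argument and the updated state comes out as a return value. Nothing is modified in place. The sketch below illustrates that pattern with a hypothetical pure function `evolve()` that implements a simple leaky rate neuron. It is not Rockpool's `RateJax` implementation. The names `evolve`, `tau` and `dt`, and the update rule itself, are assumptions made for illustration.

```python
import numpy as np


def evolve(params, state, inputs):
    """Hypothetical pure evolution function for a leaky rate neuron.

    Returns (outputs, new_state, record) without modifying `state`.
    """
    x = state["x"]
    alpha = params["dt"] / params["tau"]
    outputs = []
    for inp in inputs:  # iterate over time steps (rows of `inputs`)
        # - Out-of-place forward-Euler update: the caller's array is never modified
        x = x + alpha * (-x + inp)
        outputs.append(np.maximum(x, 0.0))  # rectified output
    new_state = {"x": x}
    record = {}
    return np.stack(outputs), new_state, record


# - Usage
N, T = 3, 10
params = {"tau": np.full(N, 0.02), "dt": 1e-3}
state = {"x": np.zeros(N)}
outputs, new_state, _ = evolve(params, state, np.random.rand(T, N))
print(state["x"])      # the original state is unchanged: [0. 0. 0.]
print(new_state["x"])  # the updated state is returned instead
```

The caller decides what to do with `new_state`. This is the same contract as the `(output, new_state, record)` tuple returned by the module call above.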

So far so good. The issue with Jax is that jit-compiled modules and functions cannot have side effects. In Rockpool, however, evolution almost always has side effects, because it updates the internal state variables of each module.
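To see why side effects are a problem, consider what `jax.jit` does. It traces the Python function once to build a compiled computation. Later calls with the same input shapes and dtypes reuse that compiled version without running the Python body again. Any Python-level side effect therefore happens only during tracing. The minimal sketch below is plain Jax and independent of Rockpool; the `call_log` list is purely illustrative.

```python
import jax
import jax.numpy as jnp

call_log = []  # a Python-side container that the function mutates


@jax.jit
def double(x):
    call_log.append("traced")  # side effect: only runs while Jax traces the function
    return 2.0 * x


x = jnp.ones(3)
double(x)  # first call: traced and compiled, so the side effect runs
double(x)  # second call: cached compiled version, the side effect does NOT run
print(len(call_log))  # 1
```

A state update written as an in-place attribute assignment inside a jitted evolution would be lost in the same way. This is why, in the functional style, the updated state is returned rather than written in place.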

In the case of the evolution above, we can see that the internal state was not updated during evolution:

[5]:
print("mod.state:", mod.state())
print(mod.state()["x"], " != ", new_state["x"])
mod.state:
{
    'rng_key': DeviceArray([ 237268104, 2681681569], dtype=uint32),
    'x': DeviceArray([0., 0., 0.], dtype=float32)
}
[0. 0. 0.]  !=  [0.16027772 0.15767853 0.19997193]

The correct way to resolve this is to assign new_state back to the module after each evolution:

[6]:
mod = mod.set_attributes(new_state)
print(mod.state()["x"], " == ", new_state["x"])
[0.16027772 0.15767853 0.19997193]  ==  [0.16027772 0.15767853 0.19997193]
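Assigning the returned state matters whenever a signal is processed in several pieces, for example chunk by chunk in a streaming setting. The sketch below uses a hypothetical pure `evolve()` function, a simple leaky integrator that is not Rockpool's implementation. It shows two things. Carrying the state forward from one chunk to the next reproduces a single long evolution. Discarding the state does not.

```python
import numpy as np


def evolve(state, inputs, alpha=0.05):
    """Hypothetical pure leaky integrator: returns (outputs, new_state)."""
    x = state
    outputs = []
    for inp in inputs:
        x = x + alpha * (-x + inp)  # out-of-place Euler update
        outputs.append(x)
    return np.stack(outputs), x


rng = np.random.default_rng(0)
signal = rng.random((20, 3))
x0 = np.zeros(3)

# - Reference: evolve the whole signal in one call
full_out, _ = evolve(x0, signal)

# - Chunked, threading the returned state into the next call
out_a, state = evolve(x0, signal[:10])
out_b, state = evolve(state, signal[10:])
threaded = np.concatenate([out_a, out_b])

# - Chunked, but restarting from the initial state (returned state discarded)
out_b_stale, _ = evolve(x0, signal[10:])
stale = np.concatenate([out_a, out_b_stale])

print(np.allclose(full_out, threaded))  # True
print(np.allclose(full_out, stale))     # False
```

With a Rockpool module, the equivalent loop body is `output, new_state, _ = mod(chunk)` followed by `mod = mod.set_attributes(new_state)`.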

You will have noticed the functional form of the call to set_attributes() above. This is addressed in the next section.

Functional state and attribute setting

Direct attribute assignment works at the top level, using standard Python syntax:

[7]:
new_tau = mod.tau * 0.4
mod.tau = new_tau
print(new_tau, " == ", mod.tau)
[0.008 0.008 0.008]  ==  [0.008 0.008 0.008]

A functional form is also supported, via the set_attributes() method. This returns a copy of the module (and its submodules) with the updated attributes, which replaces the “old” module:

[8]:
params = mod.parameters()
params["tau"] = params["tau"] * 3.0

# - Note the functional calling style
mod = mod.set_attributes(params)

# - check that the attribute was set
print(params["tau"], " == ", mod.tau)
[0.024 0.024 0.024]  ==  [0.024 0.024 0.024]
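The copy-and-replace behaviour of `set_attributes()` can be illustrated with a hypothetical toy class. This is only a sketch of the idea. It assumes that attributes are stored as ordinary instance attributes. Rockpool's actual implementation also handles submodules and differs in detail.

```python
import copy

import numpy as np


class ToyModule:
    """Hypothetical module with a functional `set_attributes()`."""

    def __init__(self, shape):
        self.shape = shape
        self.tau = np.full(shape, 0.02)

    def set_attributes(self, new_attributes):
        # - Shallow-copy the object, then rebind attributes on the copy only
        new_mod = copy.copy(self)
        for name, value in new_attributes.items():
            setattr(new_mod, name, value)
        return new_mod


old = ToyModule(3)
new = old.set_attributes({"tau": old.tau * 3.0})
print(old.tau)  # [0.02 0.02 0.02]  (the original is untouched)
print(new.tau)  # [0.06 0.06 0.06]
```

A shallow copy is enough in this toy because attributes are rebound to new arrays rather than mutated in place. The old and new objects therefore never see each other's updates.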

Functional module reset

Resetting the module state and parameters must also be done in the functional form, assigning the returned module back to mod:

[9]:
# - Reset the module state
mod = mod.reset_state()

# - Reset the module parameters
mod = mod.reset_parameters()

Jax flattening

JaxModule provides the methods tree_flatten() and tree_unflatten(), which are required to serialise and deserialise modules for Jax compilation and execution.

If you write a JaxModule subclass, it will be automatically registered with Jax as a pytree. You shouldn’t need to override tree_flatten() or tree_unflatten() in your modules.
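For readers unfamiliar with pytrees, the sketch below shows the general mechanism. It uses Jax's public `jax.tree_util.register_pytree_node_class` decorator on a hypothetical `Leaky` class, which is not a Rockpool class. `tree_flatten()` splits the object into array leaves plus static auxiliary data, and `tree_unflatten()` rebuilds it. Once the class is registered, its objects can be passed straight into a jitted function. Rockpool performs this registration for `JaxModule` subclasses automatically, so you would not normally write it yourself.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Leaky:
    """Hypothetical container holding one parameter and one state array."""

    def __init__(self, tau, x):
        self.tau = tau
        self.x = x

    def tree_flatten(self):
        children = (self.tau, self.x)  # array leaves that Jax traces
        aux_data = None  # no static (non-array) data in this toy example
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def step(mod, inp):
    # - Return a *new* object carrying the updated state
    new_x = mod.x + (1e-3 / mod.tau) * (-mod.x + inp)
    return Leaky(mod.tau, new_x)


mod = Leaky(jnp.full(3, 0.02), jnp.zeros(3))
mod = step(mod, jnp.ones(3))
print(jax.tree_util.tree_leaves(mod))
```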

Flattening and unflattening require that your __init__() method be callable with only a shape as input. The shape should be sufficient to specify the network architecture of your module and all of its submodules.

If that isn’t the case, then you may need to override tree_flatten() and tree_unflatten().
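The reason for the shape-only constructor requirement can be sketched as follows. This hypothetical scheme illustrates the idea and is not Rockpool's actual code. `tree_flatten()` keeps the shape as static auxiliary data and the parameters as leaves. `tree_unflatten()` then rebuilds the module by calling `cls(shape)` and restoring the leaves.

```python
import numpy as np


class ShapeOnlyModule:
    """Hypothetical module whose architecture is fully defined by `shape`."""

    def __init__(self, shape):
        self.shape = shape
        self.tau = np.full(shape, 0.02)  # default parameter values

    def tree_flatten(self):
        children = (self.tau,)  # numerical leaves
        aux_data = self.shape   # static data needed to rebuild the object
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = cls(aux_data)    # relies on a shape-only constructor
        (obj.tau,) = children  # restore the flattened leaves
        return obj


mod = ShapeOnlyModule(3)
mod.tau = mod.tau * 0.4
children, aux = mod.tree_flatten()
rebuilt = ShapeOnlyModule.tree_unflatten(aux, children)
print(rebuilt.shape, rebuilt.tau)
```

If the constructor in this toy also required another argument, such as a weight matrix, the call `cls(aux_data)` would raise a `TypeError`. That is the kind of case in which overriding the flattening methods becomes necessary.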