This page was generated from docs/tutorials/jax_lif_sgd.ipynb.

⚡️ Training a spiking network with Jax

This tutorial demonstrates using Rockpool and a Jax-accelerated LIF feed-forward neuron layer to perform gradient descent training of all network parameters. The result is a trained spiking layer which can generate a pre-defined signal from a noisy spiking input.

Requirements and housekeeping

This example requires the Rockpool package from SynSense, as well as jax and its dependencies.

import jax

# - Switch off warnings
import warnings

warnings.filterwarnings("ignore")

# - Rockpool imports
from rockpool import TSEvent, TSContinuous
from rockpool.nn.modules import LIFJax, LinearJax, ExpSynJax
from rockpool.nn.modules.jax.jax_lif_ode import LIFODEJax
from rockpool.nn.combinators import Sequential
from rockpool.parameters import Constant

# - Typing
from typing import Callable, Dict, Tuple
import types

# - Numpy
import numpy as np
import copy

# - Pretty printing
from rich import print

from tqdm.autonotebook import tqdm

# - Plotting imports and config
import sys
!{sys.executable} -m pip install --quiet matplotlib
import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams["figure.figsize"] = [12, 4]
plt.rcParams["figure.dpi"] = 300

Signal generation from frozen noise task

We will use a single feed-forward layer of spiking neurons to convert a chosen pattern of random input spikes over time, into a pre-defined temporal signal with complex dynamics.

The network architecture is strictly feedforward, but the spiking neurons nevertheless contain temporal dynamics in their synaptic and membrane signals, with explicit time constants.

Each of the Nin input channels will carry an independent Poisson spike train, with mean rate spiking_prob/dt. A single output channel should generate a chirp signal whose frequency increases over time, controlled by chirp_freq_factor. You can play with these parameters below.

# - Define input and target
Nin = 200
dt = 1e-3
chirp_freq_factor = 10
dur_input = 1000e-3

# - Generate a time base
T = int(np.round(dur_input / dt))
timebase = np.linspace(0, (T - 1) * dt, T)

# - Generate a chirp signal as a target
chirp = np.atleast_2d(np.sin(timebase * 2 * np.pi * (timebase * chirp_freq_factor))).T
target_ts = TSContinuous(timebase, chirp, periodic=True, name="Target chirp")

# - Generate a Poisson frozen random spike train
spiking_prob = 0.01
input_sp_raster = np.random.rand(T, Nin) < spiking_prob
input_sp_ts = TSEvent.from_raster(
    input_sp_raster, name="Input spikes", periodic=True, dt=dt
)
# - Plot the input and target signals
input_sp_ts.plot()
(target_ts * Nin / 2 + Nin / 2).plot(color="orange", lw=2)
plt.title("Input and target");

LIF neuron

The spiking neuron we will use is a leaky integrate-and-fire spiking neuron (“LIF” neuron). This neuron receives input spike trains \(S_{in}(t) = \sum_j\delta(t-t_j)\), which are integrated via weighted exponential synapses. Synaptic currents are then integrated into a neuron state (“membrane potential”) \(V_{mem}\).

The neuron obeys the dynamics

\[\tau_{mem}\cdot\dot{V}_{mem} + V_{mem} = {I}_{syn} + I_{bias} + \sigma\zeta(t)\]
\[\tau_{syn}\cdot\dot{I}_{syn} + I_{syn} = 0\]
\[I_{syn} += W_{in} \cdot S_{in}(t)\]

Where \(\tau_{mem}\) and \(\tau_{syn}\) are membrane and synaptic time constants; \(I_{bias}\) is a constant bias current for each neuron; \(\sigma\zeta(t)\) is a white noise process with std. dev. \(\sigma\).

Output spikes are generated when \(V_{mem}\) crosses the firing threshold \(V_{th} = 0\). This process generates a spike train \(S(t)\) as a series of delta functions, and causes a subtractive reset of \(V_{mem}\):

\[V_{mem} > V_{th} \rightarrow S(t) = H(V_{mem}(t)), V_{mem} = V_{mem} - 1\]

The analog output signal is generated using a surrogate

\[U(t) = \tanh(V_{mem}(t) + 1) / 2 + 0.5\]

The output of the network \(o(t)\) is therefore given by

\[o(t) = W_{out} \cdot S(t)\]

For more detail, see the documentation for the Jax module LIFJax.
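To make the dynamics above concrete, here is a minimal forward-Euler sketch of a single LIF time step. This is an illustrative helper only, not Rockpool's implementation; the parameter values and function name are assumptions.

```python
import numpy as np

def lif_euler_step(v, i_syn, spikes_in, w_in,
                   tau_mem=20e-3, tau_syn=10e-3, i_bias=0.0, dt=1e-3, v_th=0.0):
    """One illustrative forward-Euler step of the LIF dynamics."""
    # - Weighted input spikes are integrated into the synaptic current
    i_syn = i_syn + w_in.T @ spikes_in

    # - Synaptic current decays: tau_syn * dI/dt + I = 0
    i_syn = i_syn * (1.0 - dt / tau_syn)

    # - Membrane integrates: tau_mem * dV/dt + V = I_syn + I_bias
    v = v + dt / tau_mem * (i_syn + i_bias - v)

    # - Spike generation and subtractive reset when V crosses the threshold
    spikes_out = (v > v_th).astype(float)
    v = v - spikes_out
    return v, i_syn, spikes_out
```

Each call advances the neuron state by one `dt`; looping over a spike raster reproduces the qualitative behaviour of the layer.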

Build a network

The network architecture is a single feedforward layer, with weighted spiking inputs and outputs. Spiking is generated via a function that provides a surrogate gradient in the backwards pass. This permits propagation of an error gradient through the layer, making gradient-descent training possible.
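As a sketch of how such a surrogate works in Jax (illustrative only; Rockpool's actual surrogate function may differ), a hard Heaviside spike function can be given a smooth custom gradient with `jax.custom_vjp`:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def spike(v):
    # - Forward pass: hard threshold (Heaviside)
    return jnp.where(v > 0.0, 1.0, 0.0)

def spike_fwd(v):
    # - Save the membrane value for the backward pass
    return spike(v), v

def spike_bwd(v, g):
    # - Backward pass: use the gradient of a smooth surrogate instead of the
    #   true (zero almost everywhere) derivative of the Heaviside
    return (g / (1.0 + jnp.abs(v)) ** 2,)

spike.defvjp(spike_fwd, spike_bwd)
```

The forward pass still emits hard 0/1 spikes, but `jax.grad` sees a smooth pseudo-derivative, so error gradients can flow through the spiking nonlinearity.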

For this regression task we will also use an exponential synapse layer to perform temporal smoothing of the output. Regressing to a smooth target is much easier with a continuous output signal than with the spike deltas alone.
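The smoothing performed by the exponential synapse can be sketched as a leaky accumulator run over the spike raster. This is an illustrative helper, not `ExpSynJax` itself; the function and parameter names are assumptions.

```python
import numpy as np

def exp_smooth(spike_raster, tau=10e-3, dt=1e-3):
    """Convolve a (T, N) spike raster with an exponential kernel (sketch)."""
    # - Per-step decay factor of the exponential kernel
    alpha = np.exp(-dt / tau)
    out = np.zeros(spike_raster.shape, dtype=float)
    acc = np.zeros(spike_raster.shape[1])
    for t, s in enumerate(spike_raster):
        acc = alpha * acc + s  # decay the trace, then integrate incoming spikes
        out[t] = acc
    return out
```

A single spike thus leaves an exponentially decaying trace, turning the discrete spike train into a continuous signal the MSE loss can regress against.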

# - Network size
N = 50
Nout = 1
input_scale = 1.
# - Generate a network using the sequential combinator
modFFwd = Sequential(
    LinearJax((Nin, N)),
    LIFJax(N, dt=dt),
    ExpSynJax(N, dt=dt),
    LinearJax((N, Nout)),
)
print(modFFwd)

JaxSequential  with shape (200, 1) {
    LinearJax '0_LinearJax' with shape (200, 50)
    LIFJax '1_LIFJax' with shape (50, 50)
    ExpSynJax '2_ExpSynJax' with shape (50,)
    LinearJax '3_LinearJax' with shape (50, 1)
}
Simulate initial state of network

If we simulate the untrained network with our random input spikes, we don’t expect anything sensible to come out. Let’s do this, and take a look at how the network behaves.

# - Randomise the network state

# - Evolve with the frozen noise spiking input
tsOutput, new_state, record_dict = modFFwd(input_sp_raster * input_scale, record=True)

# - Plot the analog output
plt.plot(timebase, np.squeeze(tsOutput));

We can also examine the internal state of the network, by interrogating record_dict:

# - Make a function that converts ``record_dict`` entries to time series and plots them
def plot_record_dict(rd):
    Isyn_ts = TSContinuous.from_clocked(
        rd["1_LIFJax"]["isyn"][0, :, :, 0], dt, name="Synaptic currents $I_{syn}$"
    )
    Vmem_ts = TSContinuous.from_clocked(
        rd["1_LIFJax"]["vmem"][0], dt, name="Membrane potential $V_{mem}$"
    )
    spikes_ts = TSEvent.from_raster(
        rd["1_LIFJax"]["spikes"][0], dt, name="LIF layer spikes"
    )

    # - Plot the internal activity of selected neurons
    plt.figure()
    Isyn_ts.plot(stagger=1.1, skip=5)

    plt.figure()
    Vmem_ts.plot(stagger=1.1, skip=5)

    plt.figure()
    spikes_ts.plot()

plot_record_dict(record_dict)


Training the network

In order to train the network we need to define a loss function to optimise. This function accepts a set of parameters, the network, and the input and target for a trial, and computes an error (“loss”) for the trial. The loss is computed by comparing the network output to the target using the mean-squared error.

Usually you would add regularisation terms to the loss function, to make sure parameters don’t grow too large; to encourage low firing rates; etc. Generally you would want to also place bounds on the time constants, to prevent them becoming too small and causing numerical instability. See 🏃🏽‍♀️ Training a Rockpool network with Jax for more information.

# - Import the convenience functions
from rockpool.training.jax_loss import bounds_cost, make_bounds, bounds_clip

# - Generate a set of pre-configured bounds
lower_bounds, upper_bounds = make_bounds(modFFwd.parameters())
print("lower_bounds: ", lower_bounds, "upper_bounds: ", upper_bounds)
lower_bounds:  {
    '0_LinearJax': {'weight': -inf},
    '1_LIFJax': {'bias': -inf, 'tau_mem': -inf, 'tau_syn': -inf, 'threshold': -inf},
    '2_ExpSynJax': {'tau': -inf},
    '3_LinearJax': {'weight': -inf}
} upper_bounds:  {
    '0_LinearJax': {'weight': inf},
    '1_LIFJax': {'bias': inf, 'tau_mem': inf, 'tau_syn': inf, 'threshold': inf},
    '2_ExpSynJax': {'tau': inf},
    '3_LinearJax': {'weight': inf}
}
# - Impose a lower bound for the time constants
lower_bounds["1_LIFJax"]["tau_syn"] = 11 * dt
lower_bounds["1_LIFJax"]["tau_mem"] = 11 * dt
lower_bounds['1_LIFJax']['threshold'] = 0.1
lower_bounds["2_ExpSynJax"]["tau"] = 11 * dt
print("lower_bounds:", lower_bounds)
lower_bounds: {
    '0_LinearJax': {'weight': -inf},
    '1_LIFJax': {'bias': -inf, 'tau_mem': 0.011, 'tau_syn': 0.011, 'threshold': 0.1},
    '2_ExpSynJax': {'tau': 0.011},
    '3_LinearJax': {'weight': -inf}
}
import rockpool.training.jax_loss as l
import jax.numpy as jnp

# - Define a loss function
def losses(params, net, input, target):
    # - Reset the network state
    net = net.reset_state()

    # - Clip the parameters to bounds
    params_clip = bounds_clip(params, lower_bounds, upper_bounds)

    # - Apply the parameters
    net = net.set_attributes(params_clip)

    # - Evolve the network
    output, _, states = net(input, record=True)

    # - Calculate parameter bounds
    bounds = bounds_cost(params, lower_bounds, upper_bounds)

    # - Add an L2 norm to the parameters
    l2 = l.l2sqr_norm(params)

    # - Add cost to states (disabled here)
    act_sum = 0.  # jnp.mean(states['1_LIFJax']['isyn'] ** 2)

    # - Add cost to output
    out_sum = jnp.mean(output ** 2)

    # - Return the loss
    return jnp.array([l.mse(output, target), bounds / 10, 100. * l2, act_sum * 10., out_sum / 100.])

def loss(params, net, input, target):
    return losses(params, net, input, target).sum()

Below we define a training loop that uses a gradient-descent optimisation algorithm (“Adam”, provided by Jax) to iteratively optimise the network parameters. We keep track of the loss value for each iteration for later visualisation.

# - Useful imports
from tqdm.autonotebook import tqdm
from copy import deepcopy
from itertools import count

# -- Import an optimiser to use and initalise it
import jax
from jax.example_libraries.optimizers import adam

# - Get the optimiser functions
init_fun, update_fun, get_params = adam(1e-4)

# - Initialise the optimiser with the initial parameters
params0 = copy.deepcopy(modFFwd.parameters())
opt_state = init_fun(modFFwd.parameters())

# - Get a compiled value-and-gradient function
loss_vgf = jax.jit(jax.value_and_grad(loss))
loss_gf = jax.jit(jax.grad(loss))
loss_f = jax.jit(loss)

# - Compile the optimiser update function
update_fun = jax.jit(update_fun)

# - Record the loss values over training iterations
loss_t = []
grad_t = []
losses_t = []

num_epochs = 10000

# - Test and pre-compile the loss functions
loss_gf(params0, modFFwd, input_sp_raster * input_scale, chirp);
loss_f(params0, modFFwd, input_sp_raster * input_scale, chirp);
loss_vgf(params0, modFFwd, input_sp_raster * input_scale, chirp);
# - Loop over iterations
i_trial = count()

for _ in tqdm(range(num_epochs)):
    # - Get parameters for this iteration
    params = get_params(opt_state)

    # - Get the loss value and gradients for this iteration
    loss_val, grads = loss_vgf(params, modFFwd, input_sp_raster * input_scale, chirp)

    # - Update the optimiser
    opt_state = update_fun(next(i_trial), grads, opt_state)

    # - Keep track of the losses
    loss_t.append(loss_val)
    losses_t.append(losses(params, modFFwd, input_sp_raster * input_scale, chirp))
# - Plot the loss
plt.plot(np.array(losses_t))
plt.xlabel("Training iteration")
plt.ylabel("Loss")
plt.legend(['mse', 'bounds', 'l2', 'act', 'out'])
plt.title("Training progress");

Plot the output of the trained network

The MSE loss decreased — so far, so good. But what has the network learned?

# - Simulate with trained parameters
modFFwd = modFFwd.set_attributes(get_params(opt_state))
modFFwd = modFFwd.reset_state()
output_ts, _, record_dict = modFFwd(input_sp_raster * input_scale, record=True)

# - Compare the output to the target
plt.plot(np.squeeze(output_ts))
plt.plot(chirp, lw=3)
plt.legend(["Output", "Target"])
plt.title("Output vs target")

# - Plot the internal state of selected neurons
plot_record_dict(record_dict)
Plot the network parameters

Let’s see how much the network parameters changed. Since the initial parameter set was random, we’ll plot the difference between the trained and initial parameters \(\theta^* - \theta\).

modIn = modFFwd[0]
modLIF = modFFwd[1]
modOut = modFFwd[3]
# - Plot the change in input weights
w_diff = modIn.weight - params0["0_LinearJax"]["weight"]
lim = np.max(np.abs(w_diff))
plt.imshow(w_diff, aspect="auto")
plt.title("Input weight change $w^*_{in}-w_{in}$")
plt.clim([-lim, lim])

# - Plot the change in output weights
plt.stem(np.squeeze(modOut.weight - params0["3_LinearJax"]["weight"]))
plt.title("Output weight change $w^*_{out}-w_{out}$")

# - Plot the distribution of final time constants
plt.hist(np.array(modLIF.tau_mem.flatten()) * 1e3, 20)
plt.xlabel("$\\tau_{mem}$ (ms)")
plt.title("Histogram of membrane time constants $\\tau_{mem}$")

plt.hist(np.array(modLIF.tau_syn.flatten()) * 1e3, 20)
plt.xlabel("$\\tau_{syn}$ (ms)")
plt.title("Histogram of synaptic time constants $\\tau_{syn}$")

# - Plot the distribution of final biases
plt.hist(np.array(modLIF.bias.flatten()), 20)
plt.xlabel("Bias value $I_{bias}$")
plt.title("Histogram of neuron biases $I_{bias}$");

The power of automatic differentiation is that almost for free, we get to optimise not just the weights, but all time constants and biases simultaneously. And we didn’t have to compute the gradients by hand!

As a sanity check, let’s see how the trained network responds if we give it a different random noise input.

spiking_prob = 0.01
sp_rand_ts = np.random.rand(T, Nin) < spiking_prob

# - Simulate with trained parameters
modFFwd = modFFwd.set_attributes(get_params(opt_state))
modFFwd = modFFwd.reset_state()
output_ts, _, record_dict = modFFwd(sp_rand_ts * input_scale, record=True)

# - Compare the output to the target
plt.plot(np.squeeze(output_ts))
plt.plot(chirp, lw=3)
plt.legend(["Output", "Target"])
plt.title("Output vs target")

# - Plot the internal state of selected neurons
plot_record_dict(record_dict)

As expected, the network doesn’t do anything sensible with data it has never seen.


This approach can be used identically to train recurrent spiking networks, as well as multi-layer (i.e. deep) networks.
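For example, a recurrent LIF layer differs from the feed-forward case only in that the layer's own spikes are fed back through a recurrent weight matrix; because the whole rollout remains differentiable, the same value-and-gradient training applies unchanged. The sketch below is illustrative only (the function names and parameters are assumptions, not Rockpool's API):

```python
import jax
import jax.numpy as jnp
from functools import partial

def rec_lif_step(params, carry, s_in, tau_mem=20e-3, tau_syn=10e-3, dt=1e-3):
    # - One step of a recurrent LIF layer: own spikes fed back through w_rec
    w_in, w_rec = params
    v, i_syn, s_prev = carry
    i_syn = i_syn + w_in.T @ s_in + w_rec.T @ s_prev
    i_syn = i_syn * (1.0 - dt / tau_syn)
    v = v + dt / tau_mem * (i_syn - v)
    s = (v > 0.0).astype(jnp.float32)
    return (v - s, i_syn, s), s

def run_rec_lif(params, spikes_in):
    # - Roll the layer over time with lax.scan; the rollout is differentiable,
    #   so jax.value_and_grad works on it just as in the feed-forward case
    n = params[0].shape[1]
    carry0 = (jnp.zeros(n), jnp.zeros(n), jnp.zeros(n))
    _, s_out = jax.lax.scan(partial(rec_lif_step, params), carry0, spikes_in)
    return s_out
```

A hard threshold is used here for brevity; in practice you would substitute a surrogate-gradient spike function, as discussed above, so that gradients flow through the recurrent weights as well.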