Spiking Neuron Model

class snnpytorch.neuron.spiking_neuron.SpikingNeuronLayer(num_neurons=100, spiking_threshold=1, membrane_potential_decay_factor=0.1)

Bases: torch.nn.modules.module.Module

A layer of leaky spiking neurons.

The iterative neuron model is inspired by the following papers:

‘Enabling Deep Spiking Neural Networks with Hybrid Conversion and Spike Timing Dependent Backpropagation’ by Rathi et al., ICLR 2020, https://openreview.net/forum?id=B1xSperKvH

‘Spatio-Temporal Backpropagation for Training High-Performance Spiking Neural Networks’ by Wu et al., Frontiers in Neuroscience, 2018, https://www.frontiersin.org/articles/10.3389/fnins.2018.00331/full

In this PyTorch implementation, the membrane voltage (u) of a leaky neuron is updated in discrete timesteps. At each iteration, the new membrane voltage is the sum of the input current (i) and the decayed previous membrane voltage. For neurons that spiked at the previous iteration, the spiking threshold (v) is then subtracted from this voltage (a soft reset). Mathematically,

\[ u^{t} = \lambda u^{t-1} + i - v\,o^{t-1}, \qquad o^{t-1} = \begin{cases} 1 & \text{if } u^{t-1} > v \\ 0 & \text{otherwise} \end{cases} \]

where \(\lambda\) is the membrane potential decay (leak) factor and t is the current timestep (iteration).
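A minimal sketch of the update rule above for a single neuron driven by a constant current. It is written directly from the equations, not taken from the snnpytorch source; the helper name `lif_step` and the values λ = 0.5, v = 1.0, i = 0.6 are illustrative assumptions. λ = 0.5 is used instead of the default 0.1 because with λ = 0.1 this input settles at 0.6 / 0.9 ≈ 0.67 and never reaches the threshold.

```python
import torch


def lif_step(u_prev, o_prev, i, lam, v):
    """One discrete update of the leaky membrane equation (hypothetical helper).

    u_prev: membrane voltage u^{t-1}
    o_prev: spikes o^{t-1} (1.0 where u^{t-1} > v, else 0.0)
    i:      input current at this iteration
    lam:    decay factor lambda
    v:      spiking threshold
    """
    u = lam * u_prev + i - v * o_prev  # leak, integrate input, soft reset
    o = (u > v).float()                # o^t, used as o_prev at the next step
    return u, o


lam, v = 0.5, 1.0
u = torch.zeros(1)
o = torch.zeros(1)
voltages, spikes = [], []
for t in range(5):
    u, o = lif_step(u, o, torch.tensor([0.6]), lam, v)
    voltages.append(u.item())
    spikes.append(int(o.item()))

print([round(x, 4) for x in voltages])  # [0.6, 0.9, 1.05, 0.125, 0.6625]
print(spikes)                           # [0, 0, 1, 0, 0]
```

The voltage climbs toward the threshold and exceeds it at the third step; that spike subtracts v from the following update, which leaves 0.125 instead of 1.125.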

Parameters
  • num_neurons – Number of neurons in the layer

  • spiking_threshold – Spiking threshold voltage (v) of the neurons

  • membrane_potential_decay_factor – Membrane potential decay factor (\(\lambda\) in the update equation)

check_for_spikes() → torch.Tensor

Check for membrane voltages which exceed the spiking threshold.

Returns

Binary tensor, 1 for spiked neurons, 0 for non-spiked neurons

compute_membrane_potentials(x: torch.Tensor) → None

Update neuron membrane potentials.

Parameters

x – Input synaptic current

forward(x: torch.Tensor) → torch.Tensor

Forward pass for this spiking neuron layer.

Parameters

x – Synaptic current input

Returns

Binary tensor, 1 for spiked neurons, 0 for non-spiked neurons

initialize_states(layer_shape=None, model_device='cuda:0') → None

Initialize the layer's state variables: the neuron membrane potentials and the spiked-neuron indicators.

Parameters
  • layer_shape – (batch_size, num_neurons)

  • model_device – ‘cpu’ or ‘cuda:0’
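The methods above can be tied together as in the following sketch. `SpikingNeuronLayerSketch` is a hypothetical re-implementation written from this page's descriptions and the update equation; it is not the library's source, and the attribute names, the strict `>` comparison, and the fallback shape used when `layer_shape` is None are assumptions. In this sketch, `forward` first updates the membrane potentials (applying the soft reset from the previous spikes) and then checks for new spikes.

```python
import torch
import torch.nn as nn


class SpikingNeuronLayerSketch(nn.Module):
    """Hypothetical sketch of the documented interface, not the snnpytorch source."""

    def __init__(self, num_neurons=100, spiking_threshold=1.0,
                 membrane_potential_decay_factor=0.1):
        super().__init__()
        self.num_neurons = num_neurons
        self.spiking_threshold = spiking_threshold
        self.membrane_potential_decay_factor = membrane_potential_decay_factor
        self.membrane_potential = None  # u, shape (batch_size, num_neurons)
        self.neuron_spiked = None       # o, same shape, values 0.0 or 1.0

    def initialize_states(self, layer_shape=None, model_device="cuda:0"):
        # Assumption: fall back to a batch of one when no shape is given.
        if layer_shape is None:
            layer_shape = (1, self.num_neurons)
        self.membrane_potential = torch.zeros(layer_shape, device=model_device)
        self.neuron_spiked = torch.zeros(layer_shape, device=model_device)

    def check_for_spikes(self):
        # 1.0 where u > v, 0.0 elsewhere
        return (self.membrane_potential > self.spiking_threshold).float()

    def compute_membrane_potentials(self, x):
        # u^t = lambda * u^{t-1} + i - v * o^{t-1}
        self.membrane_potential = (
            self.membrane_potential_decay_factor * self.membrane_potential
            + x
            - self.spiking_threshold * self.neuron_spiked
        )

    def forward(self, x):
        self.compute_membrane_potentials(x)
        self.neuron_spiked = self.check_for_spikes()
        return self.neuron_spiked


# Use the GPU only if one is present; the documented default is 'cuda:0'.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
layer = SpikingNeuronLayerSketch(num_neurons=4, membrane_potential_decay_factor=0.5).to(device)
layer.initialize_states(layer_shape=(2, 4), model_device=device)
current = torch.full((2, 4), 0.6, device=device)
# One forward call per timestep; result has shape (timesteps, batch, neurons)
spike_train = torch.stack([layer(current) for _ in range(5)])
print(spike_train[:, 0, 0].tolist())  # [0.0, 0.0, 1.0, 0.0, 0.0]
```

Because the documented default is model_device='cuda:0', the usage above passes 'cpu' explicitly when no GPU is available. Note that the hard threshold has zero gradient almost everywhere; the training methods in the cited papers rely on approximate (surrogate) spike derivatives, which this sketch omits.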