SpikingSSMs: Learning Long Sequences with Sparse and Parallel Spiking State Space Models
Shuaijie Shen*, Chao Wang*, Renzhuo Huang, Yan Zhong, Qinghai Guo, Zhichao Lu, Jianguo Zhang, Luziwei Leng
Paper: https://arxiv.org/abs/2408.14909
This repository provides the official implementations and experiments for SDN (Surrogate Dynamic Network) and SpikingSSMs (Spiking State Space Models).
The code is organized around two core components:
- SDN: A lightweight module for simulating spiking neuron dynamics.
- SpikingSSMs: A novel architecture combining spiking neural networks with state space models for long-sequence tasks.
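As a rough illustration of the two ingredients SpikingSSMs combine, the sketch below evaluates a diagonal linear state space model in two equivalent ways: step by step (recurrent) and all at once as a causal convolution with the kernel K_k = Σ c·a^k·b. It then thresholds the output into binary spikes. This is not the paper's layer. The function names, the vectors `a`, `b`, `c`, and the plain threshold (which has no leak or reset) are hypothetical simplifications. Real spiking neurons add leak and reset, which breaks the convolutional shortcut. That sequential dependency is the kind of neuron dynamics SDN is described above as simulating.

```python
import numpy as np


def ssm_recurrent(u, a, b, c):
    """Run a diagonal linear SSM step by step (hypothetical toy model).

    x_k = a * x_{k-1} + b * u_k   (a, b, c: vectors of length state_dim)
    y_k = sum(c * x_k)
    """
    x = np.zeros_like(a)
    ys = []
    for u_k in u:
        x = a * x + b * u_k
        ys.append(np.dot(c, x))
    return np.array(ys)


def ssm_convolutional(u, a, b, c):
    """Same output computed in parallel with the kernel K_k = sum(c * a**k * b)."""
    length = len(u)
    powers = a[None, :] ** np.arange(length)[:, None]  # shape [L, state_dim]
    kernel = powers @ (c * b)                          # shape [L]
    # Causal convolution: y_k = sum_{j<=k} K_{k-j} * u_j
    return np.convolve(u, kernel)[:length]


def heaviside_spikes(y, v_th=1.0):
    """Binary spikes from the SSM output (plain threshold, no leak or reset)."""
    return (y >= v_th).astype(np.int8)


rng = np.random.default_rng(0)
u = rng.normal(size=256)
a = np.array([0.9, 0.5, 0.1])   # per-state decay factors (hypothetical values)
b = np.ones(3)
c = np.array([0.5, -0.3, 1.0])
y_rec = ssm_recurrent(u, a, b, c)
y_conv = ssm_convolutional(u, a, b, c)
print(np.allclose(y_rec, y_conv))
print(heaviside_spikes(y_conv)[:16])
```

The convolutional form is what lets linear SSMs train in parallel over long sequences. The recurrent form is what you would run step by step at inference time.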
Requirements:
- Python 3.8+
- PyTorch ≥1.10
- loguru
Install via conda/pip:
# PyTorch with CUDA 11.8
conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia
# Loguru
pip install loguru
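After installing, a quick check that the environment meets the listed requirements (Python 3.8+, PyTorch ≥1.10, loguru) can save a failed run later. The sketch below is a hypothetical helper, not part of this repository. It reads installed versions with the standard library's `importlib.metadata` and uses a deliberately simple version parser that ignores local build tags such as `+cu118` and pre-release suffixes.

```python
import sys
from importlib import metadata


def parse_version(text):
    """Turn '2.1.0+cu118' or '1.10.2' into a comparable tuple such as (2, 1, 0)."""
    public = text.split("+", 1)[0]  # drop local build tag like '+cu118'
    parts = []
    for piece in public.split("."):
        digits = ""
        for ch in piece:  # keep only the leading digits, e.g. '0rc1' -> '0'
            if ch.isdigit():
                digits += ch
            else:
                break
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True if version string `installed` is >= `minimum`."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))  # pad so (1, 10) compares equal to (1, 10, 0)
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Collect human-readable problems with the current environment."""
    problems = []
    if sys.version_info < (3, 8):
        problems.append("Python 3.8+ is required")
    for package, minimum in [("torch", "1.10"), ("loguru", "0")]:
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package} is not installed")
            continue
        if not meets_minimum(installed, minimum):
            problems.append(f"{package} {installed} < {minimum}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("OK" if not issues else "\n".join(issues))
```

Comparing integer tuples rather than raw strings matters here: as strings, "1.9" would sort after "1.10".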
Clone the official S4 repository and install dependencies:
git clone https://github.com/state-spaces/s4.git
cd s4
# Follow S4's installation instructions
- Generate training data:
python generate.py dataset
- Generate test data:
python generate.py dataset -n test
Dataset Structure (dataset/):
training-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt
test-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt
Each file contains:
- input: Input current tensor (shape: [50000, 1024])
- mem: Attenuated membrane potential (shape: [50000, 1024])
- spike: Spike train (shape: [50000, 1024])
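To make the relationship between the three tensors concrete, the sketch below simulates a leaky integrate-and-fire (LIF) neuron on Gaussian input currents of length 1024, mirroring the `T1024`, `N(0.0,1.0)` and `tau_0.2` parts of the file names. It is not `generate.py`. Several details are assumptions: `tau` is treated as a multiplicative decay factor, the threshold is 1.0, spikes trigger a hard reset to 0, and `mem` is recorded before the reset. The repository's data may record the membrane potential at a different point, so check `generate.py` for the actual definition.

```python
import numpy as np


def simulate_lif(inputs, tau=0.2, v_th=1.0):
    """Simulate a leaky integrate-and-fire neuron along the last axis.

    inputs: array of shape [N, T] (input current per time step).
    Returns (mem, spike), both shaped [N, T]. Assumption: `mem` is the
    membrane potential *before* the spike/reset at each step.
    """
    n, t_steps = inputs.shape
    mem = np.zeros((n, t_steps), dtype=inputs.dtype)
    spike = np.zeros((n, t_steps), dtype=inputs.dtype)
    u = np.zeros(n, dtype=inputs.dtype)  # membrane state carried across steps
    for t in range(t_steps):
        u = tau * u + inputs[:, t]             # leaky integration (decay factor tau)
        s = (u >= v_th).astype(inputs.dtype)   # fire when the threshold is reached
        mem[:, t] = u
        spike[:, t] = s
        u = u * (1.0 - s)                      # hard reset to 0 after a spike
    return mem, spike


rng = np.random.default_rng(0)
inputs = rng.normal(0.0, 1.0, size=(4, 1024)).astype(np.float32)  # N(0.0, 1.0)
mem, spike = simulate_lif(inputs)
print(inputs.shape, mem.shape, spike.shape)
```

The Python loop over time steps is exactly the sequential bottleneck that a learned surrogate such as SDN is meant to avoid during training.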
python train.py \
--training 'dataset/training-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt' \
--test 'dataset/test-mem-T1024-N(0.0,1.0)-50000-tau_0.2.pt' \
--save exp1
Logs and checkpoints will be saved in exp1/.
Optimize SDN for inference:
jupyter notebook convert.ipynb # Follow interactive instructions
- Clone and setup S4:
git clone https://github.com/state-spaces/s4.git
cd s4
# Install S4 dependencies (refer to their documentation)
- Integrate our components:
cp -r /path/to/this/repo/src ./src
cp -r /path/to/this/repo/models ./models
cp -r /path/to/this/repo/configs ./configs
- Run CIFAR-10 experiment:
python -m train experiment=spikingssm/cifar
If you use this work in your research, please cite:
@misc{shen2024spikingssms,
title={SpikingSSMs: Learning Long Sequences with Sparse and Parallel Spiking State Space Models},
author={Shuaijie Shen and Chao Wang and Renzhuo Huang and Yan Zhong and Qinghai Guo and Zhichao Lu and Jianguo Zhang and Luziwei Leng},
year={2024},
eprint={2408.14909},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2408.14909}
}
This project is licensed under the MIT License.