lindermanlab / s5
License: MIT License
Thanks for the amazing paper. I have a technical question: if a dataset contains samples with irregular time sampling, and the irregularity is not consistent across samples (i.e., sample 1 could be sampled at different timepoints than sample 2), would S5 be able to model the data well?
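One possible angle (my own sketch, not from the paper): since S5 discretizes the continuous-time system with a timestep, one could in principle recompute the zero-order-hold discretization with a per-observation dt. A minimal sketch, assuming a diagonal continuous-time state matrix Lambda and input matrix B_tilde as in S5:

```python
import jax.numpy as jnp

def discretize_zoh(Lambda, B_tilde, dt):
    """Zero-order-hold discretization with a per-step timestep dt.

    Lambda:  (P,) complex diagonal of the continuous-time state matrix
    B_tilde: (P, H) complex input matrix
    dt:      timestep for this observation (can differ per step and per sample)
    """
    Lambda_bar = jnp.exp(Lambda * dt)
    # ZOH input matrix: Lambda^{-1} (Lambda_bar - I) B_tilde, elementwise for a diagonal Lambda
    B_bar = ((Lambda_bar - 1.0) / Lambda)[:, None] * B_tilde
    return Lambda_bar, B_bar
```

For small dt this reduces to the Euler step, since exp(Lambda * dt) ≈ 1 + Lambda * dt.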
Hi! I'm trying to increase the batch size for training on CIFAR-10 to 1500. However, the GPU then runs out of memory, and I'm wondering if there's a solution for this, since I'm planning on using S5 on tasks that will involve very large inputs. Here's the configuration (shell script used to run the experiment):
python run_train.py --C_init=lecun_normal --batchnorm=True --bidirectional=True \
  --blocks=3 --bsz=600 --clip_eigs=True --d_model=512 --dataset=lra-cifar-classification \
  --epochs=250 --jax_seed=16416 --lr_factor=4.5 --n_layers=6 --opt_config=BfastandCdecay \
  --p_dropout=0.1 --ssm_lr_base=0.001 --ssm_size_base=384 --warmup_end=1 --weight_decay=0.07
Here's the full error message:
2024-04-06 22:22:12.820520: W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 35.65GiB (38280953856 bytes) by rematerialization; only reduced to 36.48GiB (39170521860 bytes)
2024-04-06 22:22:35.625179: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 37.88GiB (rounded to 40675475712) requested by op
Traceback (most recent call last):
  File "Path/S5/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "Path/S5/s5/train.py", line 172, in train
    state, train_loss, step = train_epoch(state,
  File "Path/S5/s5/train_helpers.py", line 344, in train_epoch
    state, loss = train_step(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1349, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 40675475656 bytes.
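One memory-saving option (not something the repo provides out of the box; a minimal sketch assuming a generic loss_fn(params, x, y) -> scalar) is gradient accumulation: split the large batch into microbatches and average their gradients, trading compute time for peak GPU memory:

```python
import jax
import jax.numpy as jnp

def accumulated_grads(loss_fn, params, batch, n_micro):
    """Average gradients over n_micro microbatches instead of one big batch.

    loss_fn(params, x, y) -> scalar is a hypothetical stand-in for the model loss;
    the batch size must be divisible by n_micro.
    """
    xs, ys = batch
    micro_x = xs.reshape((n_micro, -1) + xs.shape[1:])
    micro_y = ys.reshape((n_micro, -1) + ys.shape[1:])
    grad_fn = jax.grad(loss_fn)

    def step(acc, xy):
        x, y = xy
        g = grad_fn(params, x, y)
        # Accumulate gradients leaf-by-leaf across microbatches
        return jax.tree_util.tree_map(jnp.add, acc, g), None

    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (micro_x, micro_y))
    return jax.tree_util.tree_map(lambda g: g / n_micro, total)
```

With equal-sized microbatches and a mean-reduced loss, this matches the full-batch gradient while only materializing one microbatch's activations at a time.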
While studying the difference between the experimental Hyena-S5 model (development branch) and H3, I've noticed that the S5SSM filter comes with a GeLU activation:
S5/configs/hyena_S5/wikitext_S5.yaml
Line 65 in 008bd54
The filter finishes with this GeLU activation, which means the activated value gets passed straight to the inner product. This seems strange compared to other approaches, where the activation is typically followed by further linear layers (cf. num_inner_mlps, default 2). Is it intentional that the GeLU output is passed straight to the inner product?
First, thank you very much for the nice JAX implementation of Hyena as well as the recursive Hyena-S5!
As I'm trying to reproduce the results, I've run into several environment issues.
Would it be possible to share the CUDA version and the Python package configuration used to reproduce the results?
Thanks!
Hi,
I'm interested in a recursive variant of the S5 approach, since I want to apply it to inherently sequential/interleaved tasks such as control. I think S5 could be a great fit there, and the single state space strikes me as more general than the S4 formulation. I intend to combine this as a linear pre-filter feeding into a nonlinear gated unit, along the lines of the original HiPPO paper, though I expect many variants are possible.
It occurs to me that this functionality should be trivial to build on top of the parallel scan implementation provided in this repo; the code below looks to me like it should 'just work', but comments and suggestions are very welcome.
from jax import numpy as np  # matches the repo's convention of importing jax.numpy as np

from s5.ssm import S5SSM


class S5SSMRecursive(S5SSM):
    def __call__(self, x, u):
        # One step of the discretized recurrence: x_k = Lambda_bar * x_{k-1} + B_bar @ u_k
        C = 2 * self.C_tilde if self.conj_sym else self.C_tilde
        x = self.Lambda_bar * x + self.B_bar @ u
        y = (C @ x).real + self.D * u
        return x, y

    def init_carry(self, key=None):
        # Complex-valued zero state of size P
        x = np.zeros(self.P)
        return x + 1j * x
First, I must say I love seeing new work in this area :)
The S4 repository is not designed to be installed: there are standalone versions of the models that require copy-pasting, custom installation, and compilation.
I would love it if, by design, you added a pyproject.toml or setup.py, and ideally even released this as a PyPI package.
I believe it would increase adoption and make S4 (and S5) more popular among researchers.
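For reference, a minimal pyproject.toml sketch (the package name, version, and dependency list here are my assumptions, not the repo's):

```toml
[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"

[project]
name = "s5-jax"          # hypothetical package name
version = "0.1.0"
description = "S5: Simplified State Space Layers for Sequence Modeling"
dependencies = ["jax", "flax", "optax"]
```

With that in place, `pip install .` (or a PyPI release) would replace the copy-paste workflow.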
Dear authors,
Thank you for your great work.
I was wondering: how hard would it be to make the S5 model input-dependent, like Mamba? On the matrices B, C, and delta, but also even on A?
If you did this with the current implementation, would it be drastically slower?
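For what it's worth, an input-dependent timestep alone already makes the transition input-dependent while keeping the recurrence linear, so a parallel scan still applies. A minimal sketch (my own, assuming a diagonal real Lambda and precomputed per-step inputs Bu, not the repo's apply_ssm):

```python
import jax
import jax.numpy as jnp

def selective_scan(Lambda, delta, Bu):
    """Linear recurrence x_k = exp(delta_k * Lambda) * x_{k-1} + Bu_k, where the
    per-step timestep delta_k is input-dependent (e.g. produced by a small net).
    The recurrence is still first-order linear, so it parallelizes with an
    associative scan.

    Lambda: (P,) real diagonal state matrix (a simplifying assumption here)
    delta:  (L,) per-step timesteps
    Bu:     (L, P) projected inputs B_bar @ u_k
    """
    Lambda_bars = jnp.exp(delta[:, None] * Lambda[None, :])  # (L, P)

    def combine(left, right):
        A_l, b_l = left
        A_r, b_r = right
        # Compose the affine maps: x -> A_r * (A_l * x + b_l) + b_r
        return A_r * A_l, A_r * b_l + b_r

    _, x = jax.lax.associative_scan(combine, (Lambda_bars, Bu))
    return x  # (L, P): state after each step, starting from x_0 = 0
```

Making A itself a function of the input per step fits the same pattern, as long as each step's transition stays diagonal; fully dense input-dependent A would break the cheap elementwise composition.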
Hi, we've had the pleasure of working with S5 in our research project. Most of the code works as expected. However, the cosine annealing lr schedule doesn't look like what I expected. Digging into the details, I found that the step count already advances during the warmup, so the cosine decay starts counting from args.warmup_end instead of 0 (rather than covering num_epochs - args.warmup_end epochs of decay from zero). The result is the strange curve shown in the attached image. A quick fix would be to reset step = 0 when epoch == args.warmup_end.
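For illustration, a minimal warmup + cosine schedule sketch (my own, not the repo's train helpers) where the decay counter starts at zero after warmup, i.e. the reset described above:

```python
import math

def lr_at(step, warmup_steps, total_steps, base_lr, min_lr=0.0):
    """Linear warmup followed by cosine decay.

    The decay phase counts steps from zero: the step counter is effectively
    reset at warmup end, so the cosine spans exactly the remaining steps.
    """
    if step < warmup_steps:
        # Linear ramp from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    decay_step = step - warmup_steps            # the "reset"
    decay_total = total_steps - warmup_steps
    frac = 0.5 * (1.0 + math.cos(math.pi * decay_step / decay_total))
    return min_lr + (base_lr - min_lr) * frac
```

Without the subtraction of warmup_steps, the cosine starts partway down its curve and never reaches min_lr by total_steps, which matches the distorted curve described above.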
I'm currently looking to reproduce the results from Section 6.3 of the S5 paper, but it seems the experiment is missing from both the S5 model and the dataloading code.
Would it be correct to 'hijack' the log_step parameter (
Line 216 in bdb9eda
) with the log(dt) of the observations?

Hi,
First thanks for this awesome work and repo :) I'm trying to reproduce the results on PathX using this codebase, and I'm noticing high variance across seeds. I ran 10 different seeds (jax_seed=[0, 1, 2, 3, 4 ,5, 6, 7, 8, 9]). 6 of them worked as expected but 4 of them failed to learn anything. It seems that in your paper you take the average over 3 seeds, so I think this level of variance is strange. Any idea what could be wrong? Or is this level of variance expected?
Also, I have some questions about the code:
- map_nested_fn only acts on leaf nodes in the tree. Is this a bug, or the intended behavior?
- Why is broadcast_dims=[0] used in dropout, for example here?
Thanks in advance!
First of all, thank you for the well-organized repo! Apart from the jax installation caveat you mentioned, it is very straightforward to run the experiments.
However, since I am new to jax, it is not clear how to run multi-GPU training. With the script provided, it seems only one GPU is active, with minimal memory used by the others.
Is there any additional step I have to take to run multi-GPU training?
Thanks in advance!
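In case it helps, the standard JAX pattern is data parallelism via jax.pmap: replicate the parameters across devices, shard the batch along a leading device axis, and average gradients with lax.pmean. A minimal sketch, using a hypothetical loss_fn rather than the repo's train_step:

```python
import jax
import jax.numpy as jnp

def make_pmap_grad_step(loss_fn):
    """Return a pmapped function computing device-averaged gradients.

    loss_fn(params, x, y) -> scalar is a hypothetical stand-in for the model loss.
    """
    def step(params, x, y):
        grads = jax.grad(loss_fn)(params, x, y)
        # Average gradients across all devices participating in the pmap
        return jax.lax.pmean(grads, axis_name="batch")
    return jax.pmap(step, axis_name="batch")
```

The batch then needs shape (num_devices, per_device_bsz, ...), and the parameters must carry a replicated leading device axis (e.g. via jax.device_put_replicated).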
Hi, congrats on such an elegant and no doubt very high-impact piece of work. The JAX code in your paper is much appreciated! However, I was curious whether you have plans for a more end-to-end JAX implementation on GitHub. Right now I'm looking at porting the normal-HiPPO init to JAX, for instance; it seems doable, but the easier it is to reproduce, the better.
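In case it is useful, here's a minimal sketch of the normal-HiPPO (HiPPO-N) construction, following the standard HiPPO-LegS formulas (my own port, not the authors' code): the LegS matrix plus the rank-one correction P P^T yields a normal matrix whose eigenvalues all have real part -1/2.

```python
import numpy as np

def make_hippo_normal(N):
    """Normal part of the HiPPO-LegS matrix: S = A + P P^T.

    A is the HiPPO-LegS state matrix; adding the rank-one term P P^T
    leaves S = -I/2 + (skew-symmetric), i.e. a normal matrix whose
    eigenvalues all have real part -1/2.
    """
    q = np.sqrt(1.0 + 2.0 * np.arange(N))
    A = -(np.tril(np.outer(q, q)) - np.diag(np.arange(N)))  # HiPPO-LegS matrix
    P = np.sqrt(np.arange(N) + 0.5)                         # low-rank factor
    return A + np.outer(P, P)

S = make_hippo_normal(8)
Lambda = np.linalg.eigvals(S)  # diagonal init: all real parts equal -1/2
```

The diagonalization of S then provides the Lambda and eigenvector basis used to initialize a diagonal state-space layer; since this runs once at init time, plain NumPy (rather than jax.numpy) is fine here.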