
s5's Issues

Irregular time sampling question

Thanks for the amazing paper. I have a technical question about it: if a dataset contains samples with irregular time sampling, and the irregularity is not consistent across samples (i.e., sample 1 could be sampled at different time points than sample 2), would S5 be able to model the data well?
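
For context, my understanding (a rough sketch, not the repo's exact code) is that the recurrence S5 computes is associative in per-step quantities, so in principle each step k could be discretized with its own time gap Delta_k taken from that sample's timestamps:

import jax

def binary_operator(elem_i, elem_j):
    # Compose two steps of the affine recurrence x -> a * x + b (diagonal a).
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j

def apply_scan(Lambda_bar_per_step, Bu_per_step):
    # Both inputs have shape (L, P); row k is assumed to have been discretized
    # with that sample's own Delta_k, which the scan itself never constrains.
    _, xs = jax.lax.associative_scan(binary_operator,
                                     (Lambda_bar_per_step, Bu_per_step))
    return xs  # hidden states x_1 ... x_L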

Out of memory when batch size is large

Hi! I'm trying to increase the batch size for training on CIFAR-10 to 1500. However, the GPU then runs out of memory. I'm wondering whether there's a solution for this, since I'm planning to use S5 on tasks that will involve very large inputs. Here's the configuration (shell script for running the experiment):
python run_train.py --C_init=lecun_normal --batchnorm=True --bidirectional=True \
    --blocks=3 --bsz=600 --clip_eigs=True --d_model=512 --dataset=lra-cifar-classification \
    --epochs=250 --jax_seed=16416 --lr_factor=4.5 --n_layers=6 --opt_config=BfastandCdecay \
    --p_dropout=0.1 --ssm_lr_base=0.001 --ssm_size_base=384 --warmup_end=1 --weight_decay=0.07
Here's the full error message:
2024-04-06 22:22:12.820520: W external/xla/xla/service/hlo_rematerialization.cc:2218] Can't reduce memory use below 35.65GiB (38280953856 bytes) by rematerialization; only reduced to 36.48GiB (39170521860 bytes)
2024-04-06 22:22:35.625179: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 37.88GiB (rounded to 40675475712) requested by op
Traceback (most recent call last):
  File "Path/S5/run_train.py", line 101, in <module>
    train(parser.parse_args())
  File "Path/S5/s5/train.py", line 172, in train
    state, train_loss, step = train_epoch(state,
  File "Path/S5/s5/train_helpers.py", line 344, in train_epoch
    state, loss = train_step(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/pjit.py", line 1143, in _pjit_call_impl_python
    return compiled.unsafe_call(*args), compiled
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "Path/miniconda3/envs/s5/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1349, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 40675475656 bytes.
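
One workaround I'm considering (a sketch assuming the optax-based training setup used here; untested) is gradient accumulation, which trades compute for memory by splitting each large batch into micro-batches:

import optax

# Placeholder hyperparameters, not the values from the run above.
base_opt = optax.adamw(learning_rate=1e-3, weight_decay=0.07)

# Accumulate gradients over 5 micro-batches of 300 to emulate a batch of 1500;
# parameters only change on every 5th call to update().
tx = optax.MultiSteps(base_opt, every_k_schedule=5).gradient_transformation()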

Hyena-S5 SSM has a strange activation setup

While studying the difference between the experimental Hyena-S5 model (development branch) and H3, I've noticed that the S5SSM filter comes with a GeLU activation:

activation: "gelu"

The filter finishes with this GeLU activation, which means the activated value gets passed straight to the inner product. This seems strange compared to other approaches:

  • H3 has no activation within the filter.
  • Hyena uses an MLP whose depth is configured by num_inner_mlps (default 2).

Is it intentional that the GeLU output is passed straight to the inner product?
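
To make the comparison concrete, here is a schematic sketch of the two setups (filt and gate are placeholder names for the filter branch and the gating branch; this is not the repo's code):

from jax.nn import gelu

def hyena_s5_style(filt, gate):
    # The activated filter output feeds the multiplicative interaction directly.
    return gelu(filt) * gate

def h3_style(filt, gate):
    # No activation inside the filter branch.
    return filt * gate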

[Development branch] Request to update the development requirements.txt

First, thank you very much for the nice JAX implementation of Hyena as well as the recursive Hyena-S5!

As I'm trying to reproduce the results, I've run into several environment issues.
Would it be possible to share the CUDA version and the Python package configuration used to reproduce the results?

Thanks!

Recursive S5 implementation

Hi,

I'm interested in a recursive variant of the S5 approach, since I want to apply it to inherently sequential/interleaved tasks such as control. I think S5 could be a great fit there, and the single state space strikes me as more general than the S4 formulation. I intend to combine it as a linear pre-filter feeding into a nonlinear gated unit, along the lines of the original HiPPO paper, though I expect many variants are possible.

It occurs to me that this functionality should be trivial to build on top of the parallel scan implementation provided in this repo; the code below looks to me like it should 'just work', but any comments or suggestions would be very welcome.

import jax.numpy as np
from s5.ssm import S5SSM

class S5SSMRecursive(S5SSM):

    def __call__(self, x, u):
        # One recurrent step: x_k = Lambda_bar * x_{k-1} + B_bar @ u_k,
        # y_k = Re(C x_k) + D * u_k (C is doubled when conjugate symmetry is used).
        C = 2 * self.C_tilde if self.conj_sym else self.C_tilde
        x = self.Lambda_bar * x + self.B_bar @ u
        y = (C @ x).real + self.D * u
        return x, y

    def init_carry(self, key=None):
        # Complex zero state of size P.
        x = np.zeros(self.P)
        return x + 1j * x
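
For completeness, here is an (untested) usage sketch of driving this step function over a sequence with jax.lax.scan, assuming ssm is an already-initialized/bound S5SSMRecursive instance:

import jax

def run_sequence(ssm, u_seq):
    # u_seq: (L, H) inputs; the complex state is threaded through as the carry.
    def step(x, u):
        return ssm(x, u)  # (new_state, output) matches lax.scan's carry/output API
    x0 = ssm.init_carry()
    _, y_seq = jax.lax.scan(step, x0, u_seq)
    return y_seq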

design request: installable library

First, I must say I love seeing new work in this area :)

I have a request to make:

The S4 repository is not designed to be installed: there are standalone versions of the models, but using them requires copy-pasting, custom installation, and compilation.

I would love it if, by design, you would add a pyproject.toml or setup.py and ideally even release this as a PyPI package.
I believe it would increase usage and make S4 (S5) more popular among researchers.
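
For reference, a minimal setup.py sketch of what this could look like; the package name, version, and dependency list below are placeholders, not taken from the repository:

from setuptools import setup, find_packages

setup(
    name="s5-jax",  # placeholder name
    version="0.0.1",
    packages=find_packages(include=["s5", "s5.*"]),
    install_requires=["jax", "flax", "optax"],  # assumed core dependencies
)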

Mamba + S5?

Dear authors,
thank you for your great work.

I was wondering: how hard would it be to make the S5 model input-dependent like Mamba, i.e. on the matrices B, C, and delta, and perhaps even on A?

If you do this with the current implementation, would it be drastically slower?
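
To make the question concrete, here is a rough sketch of what input-dependent (Mamba-style) parameters could look like on top of a diagonal SSM; the Dense projections, the softplus for delta, and the shapes are illustrative assumptions, not the S5 or Mamba implementation:

import jax
import flax.linen as nn

class SelectiveParams(nn.Module):
    # Hypothetical module: predicts per-step delta, B, and C from the input.
    P: int  # state size

    @nn.compact
    def __call__(self, u):                            # u: (L, H)
        delta = jax.nn.softplus(nn.Dense(self.P)(u))  # per-step timescales, (L, P)
        B_t = nn.Dense(self.P)(u)                     # per-step B rows, (L, P)
        C_t = nn.Dense(self.P)(u)                     # per-step C rows, (L, P)
        return delta, B_t, C_t

Since the discretized transition then varies per step, the recurrence would presumably have to be evaluated with a per-step scan rather than a fixed convolution kernel.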

Bug with cosine annealing schedule

Hi, we've had the pleasure of working with S5 in our research project. Most of the code works as expected. However, the cosine annealing LR schedule doesn't look like what I expected. Digging into the details, I found that the step count already advances during warmup, which results in:

  • the cosine being evaluated at step args.warmup_end instead of 0
  • the cosine decaying to its minimum value at num_epochs - args.warmup

The result is the strange curve shown in the image below. A quick fix would be to reset step = 0 if epoch == args.warmup_end.

[Screenshot from 2024-02-20: resulting learning-rate curve]
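
For reference, a standalone sketch of the behaviour I expected (not the repo's code): linear warmup followed by a cosine whose phase starts at zero when warmup ends:

import math

def lr_schedule(epoch, lr_max, lr_min, warmup_end, num_epochs):
    if epoch < warmup_end:
        return lr_max * (epoch + 1) / warmup_end                 # linear warmup
    t = (epoch - warmup_end) / max(1, num_epochs - warmup_end)   # cosine phase in [0, 1]
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t))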

Reproducing "6.3 Variable Observation Interval"

I'm currently looking to reproduce the results from Section 6.3 of the S5 paper, but it seems the experiment is missing from both the S5 model and the data loading.

Would it be correct to 'hijack' the log_step parameter (self.log_step = self.param("log_step", ... at S5/s5/ssm.py, line 216 in bdb9eda) and substitute it with the log(dt) of the observations?
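
Concretely, the substitution I have in mind would look roughly like this, written as a standalone sketch rather than a patch to ssm.py (the ZOH form follows the paper; the shapes and names are my assumptions):

import jax.numpy as np

def discretize_zoh_per_step(Lambda, B_tilde, dt_obs):
    # Lambda: (P,) complex diagonal; B_tilde: (P, H); dt_obs: (L,) observed time gaps.
    Lambda_bar = np.exp(Lambda[None, :] * dt_obs[:, None])                 # (L, P)
    B_bar = ((Lambda_bar - 1.0) / Lambda[None, :])[:, :, None] * B_tilde   # (L, P, H)
    return Lambda_bar, B_bar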

Reproducing Results on PathX

Hi,

First, thanks for this awesome work and repo :) I'm trying to reproduce the results on PathX using this codebase, and I'm noticing high variance across seeds. I ran 10 different seeds (jax_seed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]): 6 of them worked as expected, but 4 failed to learn anything. Since the paper averages over 3 seeds, this level of variance seems surprising. Any idea what could be wrong, or is this level of variance expected?

Also, I have some questions about the code:

  1. Here it seems that you intend to include the LayerNorm/BatchNorm parameters in the SSM parameter group. However, this has no effect, since map_nested_fn only acts on leaf nodes of the tree (see the sketch below). Is this a bug, and which behavior is intended?
  2. Is there a reason you use broadcast_dims=[0] in dropout, for example here?

Thanks in advance!
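
Regarding question 1, here is a simplified illustration of what I mean (the generic label-by-leaf-key helper, not necessarily the exact code in this repo): the label function only ever sees leaf parameter names, so a check against an intermediate module name such as "norm" never fires.

def map_nested_fn(fn):
    # Recursively apply fn(key, value) at the leaves of a nested param dict.
    def map_fn(nested_dict):
        return {
            k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()
        }
    return map_fn

# Hypothetical label function: "norm" is a module name, never a leaf key,
# so its leaves (e.g. "scale", "bias") fall through to the "regular" group.
label_fn = map_nested_fn(
    lambda k, _: "ssm" if k in ("Lambda_re", "Lambda_im", "norm") else "regular"
)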

Multi-GPU training

First of all, thank you for the well-organized repo! Apart from the JAX installation you mentioned, it is very straightforward to run the experiments.

However, since I am new to JAX, it is not clear how to run multi-GPU training. With the script provided, it seems only one GPU is active, with minimal memory used by the other GPUs.

Is there any additional step I need to take to run multi-GPU training?

Thanks in advance!
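
For reference, the generic JAX data-parallel pattern I'm aware of is to replicate the train step with jax.pmap and shard the batch across devices; here is a self-contained toy sketch (not this repo's train_step):

import jax
import jax.numpy as np

def train_step(params, batch):
    # Toy loss so the sketch runs on its own; a real step would compute the
    # model loss and apply the optimizer update.
    def loss_fn(p):
        return np.mean((batch["x"] @ p - batch["y"]) ** 2)
    grads = jax.grad(loss_fn)(params)
    grads = jax.lax.pmean(grads, axis_name="devices")  # average grads across devices
    return params - 0.01 * grads, loss_fn(params)

p_train_step = jax.pmap(train_step, axis_name="devices")

n_dev = jax.local_device_count()
params = np.zeros((n_dev, 8))       # one replicated copy of the params per device
batch = {
    "x": np.ones((n_dev, 4, 8)),    # leading axis is the device axis
    "y": np.ones((n_dev, 4)),
}
params, loss = p_train_step(params, batch)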

JAX implementation

Hi, congrats on such an elegant and no doubt very high-impact piece of work. The JAX code in your paper is much appreciated! However, I was curious whether you have plans for a more end-to-end JAX implementation on GitHub. Right now I'm looking at porting the normal HiPPO init to JAX, for instance; it seems doable, but the easier it is to reproduce, the better, I suppose.
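
For anyone attempting the same port, here is a rough sketch of the HiPPO-LegS matrix and a normal (HiPPO-N) variant in JAX, adapted from the publicly available S4 reference code; treat the details as unverified:

import jax.numpy as np
from jax.numpy.linalg import eigh

def make_HiPPO(N):
    # Negated HiPPO-LegS matrix.
    p = np.sqrt(1 + 2 * np.arange(N))
    A = p[:, None] * p[None, :]
    A = np.tril(A) - np.diag(np.arange(N))
    return -A

def make_Normal_HiPPO(N):
    # Add the rank-1 term so the result is normal, then diagonalize it.
    A = make_HiPPO(N)
    p = np.sqrt(np.arange(N) + 0.5)
    S = A + p[:, None] * p[None, :]
    Lambda_imag, V = eigh(S * -1j)                   # S is skew-symmetric off the diagonal
    Lambda = np.mean(np.diagonal(S)) + 1j * Lambda_imag
    return Lambda, V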
