
haoxiangsnr / spiking-fullsubnet

43 stars · 3 watchers · 8 forks · 155.54 MB

Official repository of Spiking-FullSubNet, the Intel N-DNS Challenge Algorithmic Track Winner.

Home Page: https://haoxiangsnr.github.io/spiking-fullsubnet/

License: MIT License

Languages: Python 99.38%, Jupyter Notebook 0.62%
Topics: neuromorphic-computing, noise-reduction, speech-denoising, speech-enhancement, neuromorphic-audio-processing

spiking-fullsubnet's Introduction



Spiking-FullSubNet

Intel N-DNS Challenge Algorithmic Track Winner
Explore the docs »

View Demo · Report Bug · Request Feature

About The Project


We are proud to announce that Spiking-FullSubNet has emerged as the winner of Intel N-DNS Challenge Track 1 (Algorithmic). Please refer to our brief write-up here for more details. This repository serves as the official home of the Spiking-FullSubNet implementation. Here, you will find:

  • A PyTorch-based implementation of the Spiking-FullSubNet model.
  • Scripts for training the model and evaluating its performance.
  • Pre-trained models in the model_zoo directory, ready to be fine-tuned further on other datasets.

Updates

[2024-02-26] Currently, our repo contains two versions of the code:

  1. The frozen version, which serves as a backup of the code used in the previous competition. Due to a restructuring of the audiozen directory, this version can no longer be used directly for inference. If you need to verify the experimental results from that time, please refer to this specific commit: 38fe020; there you will find everything you need. After switching to this commit, you can place the checkpoints from the model_zoo into the exp directory and use -M test for inference or -M train to retrain the model.

  2. The latest version of the code, which has undergone some restructuring and optimization to make it easier for readers to follow. We have also introduced accelerate to support better training practices. You should be able to follow the instructions in the help documentation to run the training code directly. The pre-trained model checkpoints and a more detailed paper will be released by next weekend, so please stay tuned.

Documentation

See the Documentation for installation and usage. Our team is actively working to improve the documentation. Please feel free to raise an issue or submit a pull request if you have suggestions for enhancements.

License

All the code in this repository is released under the MIT License; see the LICENSE file for more details.

spiking-fullsubnet's People

Contributors

haoxiangsnr · lava-nc-user


spiking-fullsubnet's Issues

Some Reproducibility problems

Hello. I have found a couple of reproducibility issues. Some are easier to fix than others.

In recipes/intel_ndns/spiking_fullsubnet/dataloader.py, the code reads self.noisy_files = glob.glob(root + "noisy/**.wav"). On my machine, "noisy/*.wav" works instead.
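As a possible workaround, a minimal sketch of a more portable pattern (the dataset root below is hypothetical, and whether the wav files sit in subdirectories is an assumption):

    import glob
    import os

    root = "datasets/training_set/"  # hypothetical dataset root

    # Joining the path explicitly avoids a missing separator when root has no
    # trailing slash, and "*" already matches every .wav directly under noisy/.
    noisy_files = glob.glob(os.path.join(root, "noisy", "*.wav"))

    # If the wav files are nested in subdirectories, "**" only acts as a
    # recursive wildcard when recursive=True is passed.
    noisy_files_nested = glob.glob(os.path.join(root, "noisy", "**", "*.wav"), recursive=True)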

In recipes/intel_ndns/spiking_fullsubnet_freeze_phase/trainer.py the code will not execute because the following import statement is broken

from audiozen.trainer_backup.base_trainer_gan_accelerate_ddp_validate import BaseTrainer

BaseTrainer does not seem to exist in the repository; however, Trainer is present.
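A possible workaround, sketched under the assumption that the Trainer in audiozen/trainer.py (the module referenced in the traceback of the issue further below) exposes the interface this recipe expects:

    # In recipes/intel_ndns/spiking_fullsubnet_freeze_phase/trainer.py, replace the broken
    #   from audiozen.trainer_backup.base_trainer_gan_accelerate_ddp_validate import BaseTrainer
    # with the class that still exists in the repository:
    from audiozen.trainer import Trainer as BaseTrainer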

IntelSISNR is also missing from the repository.

Lastly, recipes/intel_ndns/spiking_fullsubnet/exp appears to be missing the checkpoint files, so you cannot resume training or test the model.

Counting LSTM operations

I noticed that this network model contains an LSTM structure. How are the neuron operations and synaptic operations of the LSTM determined?
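Not an official answer, but a generic back-of-the-envelope sketch of how operations are usually counted for a standard LSTM layer; the N-DNS challenge's own SynOps/NeuronOps accounting may well differ:

    def lstm_ops_per_timestep(input_size: int, hidden_size: int) -> dict:
        """Rough per-timestep operation counts for one standard LSTM layer."""
        # The four gates each compute W_x x_t + W_h h_{t-1} + b, giving
        # 4 * hidden_size * (input_size + hidden_size) multiply-accumulates,
        # the usual basis for a synaptic-operation style count.
        macs = 4 * hidden_size * (input_size + hidden_size)
        # The elementwise updates of c_t and h_t add roughly five operations
        # per hidden unit, the usual basis for a neuron-operation style count.
        elementwise = 5 * hidden_size
        return {"synaptic_macs": macs, "neuron_elementwise_ops": elementwise}

    # Example: input_size=64, hidden_size=128 -> 4 * 128 * (64 + 128) = 98,304 MACs per frame.
    print(lstm_ops_per_timestep(64, 128))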

Obtaining the datasets

As an audio-processing enthusiast, I am trying to reproduce your team's training process. To reproduce it more faithfully, I would like to obtain the training and validation sets your team used, but unfortunately I am not a student at The Hong Kong Polytechnic University, so I cannot access the dataset server you provided. Is there another way to download the training and validation sets your team used?

Training is very slow

Hi, I set batch_size=12 with 4-second audio clips on a 4090 GPU, and training is very slow. Is it because the GSN itself requires a lot of compute?

Model files and model usage

Hi, the best directory under model_zoo contains multiple pkl files. Which one is the trained model file? And once I have the trained model, how do I use or call it to denoise a noisy audio clip, so that I can get a more direct sense of the model's denoising quality?
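Not an authoritative answer, but a small diagnostic sketch for telling the model weights apart from the auxiliary training state that Accelerate saves alongside them; every filename below is an assumption, so adjust them to whatever is actually inside checkpoints/best:

    import torch

    ckpt_dir = "model_zoo/intel_ndns/spike_fsb/baseline_m/checkpoints/best"  # hypothetical path

    for name in ["pytorch_model.bin", "optimizer.bin", "scheduler.bin", "random_states_0.pkl"]:
        try:
            obj = torch.load(f"{ckpt_dir}/{name}", map_location="cpu")
        except FileNotFoundError:
            continue
        if isinstance(obj, dict) and obj and all(torch.is_tensor(v) for v in obj.values()):
            # A flat mapping from parameter names to tensors is the model state_dict.
            print(f"{name}: looks like model weights, first keys: {list(obj)[:5]}")
        else:
            print(f"{name}: auxiliary training state ({type(obj).__name__})")

For denoising a clip end to end, the run.py entry point with -M test and --ckpt_path (as in the command quoted in the next issue) appears to be the intended route.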

Missing and Unexpected keys in state_dict

I am trying to run inference using baseline_m.toml; however, the given checkpoint files seem to have the wrong state keys.

Given the following command:

 accelerate launch /mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/run.py -C /mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/baseline_m.toml -M test --ckpt_path /mnt/c/Users/madha/code/spiking-fullsubnet/model_zoo/intel_ndns/spike_fsb/baseline_m/checkpoints/best

I am getting the following error:

02-26 19:19:32: Initialized logger with log file in /mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/exp/baseline_m.
Loading dataset from /mnt/c/Users/madha/code/spiking-fullsubnet/datasets/validation_set/...
Found 3243 files.
02-26 19:19:37: Configuration file is saved to /mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/exp/baseline_m/config__2024_02_26--19_19_36.toml.
02-26 19:19:37: Environment information:
- `Accelerate` version: 0.27.2
- Platform: Linux-5.10.102.1-microsoft-standard-WSL2-x86_64-with-glibc2.31
- Python version: 3.10.13
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.2.1 (True)
- System RAM: 12.26 GB
- GPU Available: True
- GPU IDs: 1
- GPU type: NVIDIA T500
02-26 19:19:37:
 ==========================================================================================
Layer (type:depth-idx)                                            Param #
==========================================================================================
OptimizedModule                                                   --
├─SpikingFullSubNet: 1-1                                          --
│    └─SequenceModel: 2-1                                         --
│    │    └─LayerNorm: 3-1                                        128
│    │    └─StackedGSU: 3-2                                       330,240
│    │    └─Linear: 3-3                                           20,544
│    │    └─Identity: 3-4                                         --
│    └─SubbandModel: 2-2                                          --
│    │    └─ModuleList: 3-5                                       603,500
==========================================================================================
Total params: 954,412
Trainable params: 954,412
Non-trainable params: 0
==========================================================================================
Using device: 0
02-26 19:19:38: Begin testing...
02-26 19:19:38: Loading states from /mnt/c/Users/madha/code/spiking-fullsubnet/model_zoo/intel_ndns/spike_fsb/baseline_m/checkpoints/best
Traceback (most recent call last):
  File "/mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/run.py", line 151, in <module>
    run(config, args.resume)
  File "/mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/run.py", line 97, in run
    trainer.test(test_dataloaders, config["meta"]["ckpt_path"])
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/c/Users/madha/code/spiking-fullsubnet/audiozen/trainer.py", line 537, in test
    self._load_checkpoint(ckpt_path)
  File "/mnt/c/Users/madha/code/spiking-fullsubnet/audiozen/trainer.py", line 225, in _load_checkpoint
    self.accelerator.load_state(ckpt_path, map_location="cpu")
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/accelerate/accelerator.py", line 2922, in load_state
    load_accelerator_state(
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/accelerate/checkpointing.py", line 205, in load_accelerator_state
    models[i].load_state_dict(state_dict, **load_model_func_kwargs)
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2153, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SpikingFullSubNet:
        Missing key(s) in state_dict: "fb_model.pre_layer_norm.weight", "fb_model.pre_layer_norm.bias", "fb_model.proj.weight", "fb_model.proj.bias", "sb_model.sb_models.0.pre_layer_norm.weight", "sb_model.sb_models.0.pre_layer_norm.bias", "sb_model.sb_models.0.proj.weight", "sb_model.sb_models.0.proj.bias", "sb_model.sb_models.1.pre_layer_norm.weight", "sb_model.sb_models.1.pre_layer_norm.bias", "sb_model.sb_models.1.proj.weight", "sb_model.sb_models.1.proj.bias", "sb_model.sb_models.2.pre_layer_norm.weight", "sb_model.sb_models.2.pre_layer_norm.bias", "sb_model.sb_models.2.proj.weight", "sb_model.sb_models.2.proj.bias".
        Unexpected key(s) in state_dict: "fb_model.fc_output_layer.weight", "fb_model.fc_output_layer.bias", "sb_model.sb_models.0.fc_output_layer.weight", "sb_model.sb_models.0.fc_output_layer.bias", "sb_model.sb_models.1.fc_output_layer.weight", "sb_model.sb_models.1.fc_output_layer.bias", "sb_model.sb_models.2.fc_output_layer.weight", "sb_model.sb_models.2.fc_output_layer.bias".
Traceback (most recent call last):
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1023, in launch_command
    simple_launcher(args)
  File "/home/madhav/miniconda3/envs/spiking-fullsubnet/lib/python3.10/site-packages/accelerate/commands/launch.py", line 643, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/madhav/miniconda3/envs/spiking-fullsubnet/bin/python', '/mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/run.py', '-C', '/mnt/c/Users/madha/code/spiking-fullsubnet/recipes/intel_ndns/spiking_fullsubnet/baseline_m.toml', '-M', 'test', '--ckpt_path', '/mnt/c/Users/madha/code/spiking-fullsubnet/model_zoo/intel_ndns/spike_fsb/baseline_m/checkpoints/best']' returned non-zero exit status 1.

Is additional configuration required to run inference on this model?
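Not an official answer, but one way to see the mismatch concretely is to load the checkpoint with strict=False and compare key names; the weights filename and the model constructor below are placeholders rather than the repository's actual API:

    import torch

    # Hypothetical filename; use whatever weights file sits inside checkpoints/best.
    state_dict = torch.load(
        "model_zoo/intel_ndns/spike_fsb/baseline_m/checkpoints/best/pytorch_model.bin",
        map_location="cpu",
    )

    model = build_spiking_fullsubnet_from_toml("baseline_m.toml")  # placeholder constructor

    # strict=False reports the mismatch without raising, so the checkpoint layout
    # can be compared against the refactored module names directly.
    result = model.load_state_dict(state_dict, strict=False)
    print("missing:", result.missing_keys)
    print("unexpected:", result.unexpected_keys)

The unexpected *.fc_output_layer.* keys in the error above match the competition-era architecture, while the refactored model expects *.pre_layer_norm.* and *.proj.*, so these checkpoints may need the frozen commit (38fe020) mentioned in the Updates section until matching checkpoints are released.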

About Test Sets

The officially provided method of downloading the test set no longer works. When trying to download the test set via Git LFS, it reports the error:
“Fetching test_set_1.zip
fatal: could not resolve origin/test_set_1”
Would you be willing to share the datasets you have already downloaded? I would appreciate it!
