
WaveTransformer Repository

Welcome to the repository of the paper *WaveTransformer: A Novel Architecture for Audio Captioning Based on Learning Temporal and Time-Frequency Information*. If you want to reproduce the results of the paper and know what you are doing, jump ahead: get the pre-trained weights from here and run the inference code as shown here. If you want to re-train WaveTransformer, use the master branch, as it has the code based on the most up-to-date version of PyTorch.

There is also an online demo of WaveTransformer.

If you need help using WaveTransformer, please read the following instructions.

How do I use WaveTransformer

Setting up the environment

To start using the audio captioning WaveTransformer, you first have to set up the code. Please note that the code in this repository is tested with Python 3.6 and 3.7.

To set-up the code, you have to do the following:

1. Clone this repository.
2. Install dependencies

Use the following command to clone this repository at your terminal:

$ git clone git@github.com:haantran96/wavetransformer.git

To install the dependencies, you can use pip. It is advisable to run this system in a virtual environment to avoid package conflicts:

$ pip install -r requirement_pip.txt

Dataset setup

Please go to DCASE2020's baseline repository, section Preparing the data, to download and set up the data.

Create a dataset

To create the dataset, you can either run the script processes/dataset.py using the command:

$ python processes/dataset.py

or run the system through the main.py script. In either case, the dataset creation will start.

You can choose whether to validate the data by altering the validate_dataset parameter in the settings/dataset_creation.yaml file.

The result of the dataset creation process will be the creation of the directories:

1. `data/data_splits`,
2. `data/data_splits/development`,
3. `data/data_splits/evaluation`, and
4. `data/pickles`

The directories in data/data_splits hold the input and output examples for the optimization and assessment of the baseline DNN. The data/pickles directory holds the pickle files with the frequencies of the words and characters (so one can use weights in the objective function) and the correspondence of words and characters with indices.
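The frequency pickles can be used to derive per-word weights for the objective function, so that rare words contribute more to the loss. A minimal sketch of how such weights might be computed (the file name `words_frequencies.p` comes from `dirs_and_files.yaml`; the assumption that the pickle holds a plain `{word: count}` dict is ours, so check the file produced by dataset creation):

```python
import pickle

def load_word_weights(freq_pickle_path):
    """Load a {word: count} mapping and return inverse-frequency weights.

    Assumes the pickle holds a plain dict of word -> occurrence count;
    rare words get larger weights, frequent words smaller ones.
    """
    with open(freq_pickle_path, 'rb') as f:
        frequencies = pickle.load(f)
    total = sum(frequencies.values())
    return {word: total / count for word, count in frequencies.items()}
```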

Note: Once you have created the dataset, there is no need to create it every time. That is, after you create the dataset using the baseline system, you can set

workflow:
  dataset_creation: No

at the settings/main_settings.yaml file.

Using the pre-trained weights for inference

The pre-trained weights are stored in the outputs/models directory. Please note that the pre-trained weights are different for each model.

In the settings folder, there are the following files:

  1. dirs_and_files.yaml: Stores the locations of the relevant files. For example:
root_dirs:
  outputs: 'outputs'
  data: 'data'
# -----------------------------------
dataset:
  development: &dev 'development'
  evaluation: &eva 'evaluation'
  validation: &val 'validation'
  features_dirs:
    output: 'data_splits'
    development: *dev
    evaluation: *eva
    validation: *val
  audio_dirs:
    downloaded: 'clotho_audio_files'
    output: 'data_splits_audio'
    development: *dev
    evaluation: *eva
    validation: *val
  annotations_dir: 'clotho_csv_files'
  pickle_files_dir: 'pickles'
  files:
    np_file_name_template: 'clotho_file_{audio_file_name}_{caption_index}.npy'
    words_list_file_name: 'words_list.p'
    words_counter_file_name: 'words_frequencies.p'
    characters_list_file_name: 'characters_list.p'
    characters_frequencies_file_name: 'characters_frequencies.p'
    validation_files_file_name: 'validation_file_names.p'
# -----------------------------------
model:
  model_dir: 'models'
  checkpoint_model_name: 'model_name.pt'
  pre_trained_model_name: 'best_model_name.pt'
# -----------------------------------
logging:
  logger_dir: 'logging'
  caption_logger_file: 'caption_file.txt'

The most important entries are features_dirs/output and model, as you must specify the locations of the data and model paths accordingly. Note: by default, the code will save the current best model as best_checkpoint_model_name.pt, so it is advisable to always set model/pre_trained_model_name to best_checkpoint_model_name.pt.

  2. main_settings.yaml: As mentioned, if you have already created the dataset, please set dataset_creation: No. For inference, please set dnn_training: No as shown below:
workflow:
  dataset_creation: No
  dnn_training: No
  dnn_evaluation: Yes
# ---------------------------------
dataset_creation_settings: !include dataset_creation.yaml
# -----------------------------------
feature_extraction_settings: !include feature_extraction.yaml
# -----------------------------------
dnn_training_settings: !include method.yaml
# -----------------------------------
dirs_and_files: !include dirs_and_files.yaml
# EOF
  3. method.yaml: contains the various hyperparameters. These are the settings for the best models:
model: !include model.yaml
# ----------------------
data:
  input_field_name: 'features'
  output_field_name: 'words_ind'
  load_into_memory: No
  batch_size: 12 
  shuffle: Yes
  num_workers: 30
  drop_last: Yes
  use_multiple_mode: No
  use_validation_split: Yes 
# ----------------------
training:
  nb_epochs: 300
  patience: 10
  loss_thr: !!float 1e-4
  optimizer:
    lr: !!float 1e-3
  grad_norm:
    value: !!float 1.
    norm: 2
  force_cpu: No
  text_output_every_nb_epochs: !!int 10
  nb_examples_to_sample: 100
  use_class_weights: Yes
  use_y: Yes
  clamp_value_freqs: -1  # -1 is for ignoring
  # EOF
  4. model.yaml: The settings are different for each model. However, to run inference, this line should be set to Yes: use_pre_trained_model: Yes

Please use the corresponding model settings file for reference. For example, the encoder settings include:

  inner_kernel_size_encoder: 5
  inner_padding_encoder: 2
  pw_kernel_encoder: 5
  pw_padding_encoder: 2
  merge_mode_encoder: 'mean'
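Note that the `!include` directives in `main_settings.yaml` are not part of standard YAML; the repository resolves them with a custom loader. A minimal sketch of how such a constructor can be registered with PyYAML (the class and function names here are illustrative, not the repository's own code):

```python
import os
import yaml

class IncludeLoader(yaml.SafeLoader):
    """SafeLoader that resolves `!include other.yaml` relative to the current file."""

    def __init__(self, stream):
        # Directory of the file being parsed; '.' if the stream has no name.
        self._root = os.path.dirname(getattr(stream, 'name', '.'))
        super().__init__(stream)

def _include(loader, node):
    """Load the referenced YAML file and splice its contents in place."""
    path = os.path.join(loader._root, loader.construct_scalar(node))
    with open(path) as f:
        return yaml.load(f, IncludeLoader)

IncludeLoader.add_constructor('!include', _include)
```

With this in place, `yaml.load(open('settings/main_settings.yaml'), IncludeLoader)` would return one nested dictionary with the included files merged in.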

Finally, to run the whole inference code:

python main.py -c main_settings -j $ID

main_settings should be the same name as your main_settings.yaml file, without the extension.
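A minimal sketch of how `main.py` might parse these flags (hypothetical; the repository's actual argument names may differ): `-c` selects the settings file, `-j` tags the run with a job identifier.

```python
import argparse

def parse_arguments(args=None):
    """Parse the -c/--config and -j/--job-id flags used in the command above."""
    parser = argparse.ArgumentParser(description='WaveTransformer entry point')
    parser.add_argument('-c', '--config', default='main_settings',
                        help='Settings file name, without the .yaml extension')
    parser.add_argument('-j', '--job-id', default='0',
                        help='Identifier used to tag this run')
    return parser.parse_args(args)
```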

Re-training WaveTransformer

The process for re-training is the same as for inference. However, you must make the following changes:

  1. main_settings.yaml: As mentioned, if you have already created the dataset, please set dataset_creation: No. For training, please set dnn_training: Yes as shown below:
workflow:
  dataset_creation: No
  dnn_training: Yes
  dnn_evaluation: Yes
  2. method.yaml: make changes to the indicated hyperparameters as needed.
  3. model.yaml: this line should be set to "No" to train from scratch:

use_pre_trained_model: No

If you wish to continue training, you can also set use_pre_trained_model to Yes.
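The `patience`, `loss_thr`, and `nb_epochs` entries in method.yaml suggest an early-stopping scheme for training: stop once the loss has not improved by at least `loss_thr` for `patience` consecutive epochs. A minimal sketch of that logic (illustrative only, not the repository's implementation):

```python
def stopping_epoch(losses, nb_epochs=300, patience=10, loss_thr=1e-4):
    """Return the epoch index at which training would stop.

    `losses` stands in for per-epoch validation losses; in the real
    system each value would come from evaluating the model after an epoch.
    """
    best_loss = float('inf')
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses[:nb_epochs]):
        if best_loss - loss > loss_thr:   # improved by more than the threshold
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch              # early stop
    return min(len(losses), nb_epochs) - 1  # ran to the end
```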

Acknowledgement

The implementation of the codebase is adapted (with some modifications) from the following works:

  1. For WaveNet implementation: https://www.kaggle.com/c/liverpool-ion-switching/discussion/145256
  2. For Transformer implementation: https://nlp.seas.harvard.edu/2018/04/03/attention.html
  3. For beam search decoding: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding
  4. For Depthwise separable convolution implementation: https://github.com/dr-costas/dnd-sed
