
pytorch-rnn-tutorial

PyTorch tutorial on using RNNs and Encoder-Decoder RNNs for time series forecasting and hyperparameter tuning

Some blabber

This package resulted from my effort to write a simple PyTorch-based ML package that uses recurrent neural networks (RNNs) to forecast a given time series.

You must be wondering why you should bother with this package, since there is a lot of material on the internet on this topic. Well, I have traveled the internet lanes, and I was really frustrated by how scattered the information is. It took a lot of effort to collect all the relevant parts and build this package.

I had only a basic background in ML and zero knowledge of PyTorch (using Keras doesn't prepare you for PyTorch 😛) when I started writing this package. But that actually ended up being a blessing in disguise. Since I was starting from scratch, I was able to write the code in a way that was intuitive and easy to understand for people who are new to the subject.

So if you're feeling lost and frustrated, give this package a try. It might just help you understand not only RNNs, but PyTorch as well. And who knows, you might even have a little fun along the way.

Code Functionalities

  1. Many-to-One prediction using PyTorch's vanilla versions of RNN, LSTM, and GRU.
  2. Many-to-Many (or Seq2Seq) prediction using Encoder-Decoder architecture; base units could be RNN, LSTM, or GRU.
  3. Hyperparameter tuning! It uses the Optuna library for that.
  4. Save PyTorch models, as well as reload and train them further.
  5. Works on any univariate, equispaced time series.
  6. Can use GPUs.
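As a rough illustration of items 1 and 6, here is a minimal sketch of a Many-to-One forecaster that wraps PyTorch's built-in `nn.RNN`, `nn.LSTM`, or `nn.GRU` and runs on a GPU when one is available. The class name `ManyToOneRNN` and its parameters are hypothetical and not taken from `RNN_Vanilla.py`; the sketch only assumes input windows shaped `(batch, seq_len, 1)`.

```python
import torch
import torch.nn as nn


class ManyToOneRNN(nn.Module):
    """Hypothetical sketch: reads a window of past values, predicts the next one."""

    def __init__(self, cell: str = "LSTM", hidden_size: int = 32, num_layers: int = 1):
        super().__init__()
        # Pick PyTorch's built-in recurrent layer by name: nn.RNN, nn.LSTM, or nn.GRU.
        rnn_cls = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[cell]
        # input_size=1 because the series is univariate; batch_first gives (batch, seq, feature).
        self.rnn = rnn_cls(input_size=1, hidden_size=hidden_size,
                           num_layers=num_layers, batch_first=True)
        # Map the last hidden state to a single predicted value.
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, 1) -> out: (batch, seq_len, hidden_size)
        out, _ = self.rnn(x)
        # Keep only the final time step: this is what makes it "many-to-one".
        return self.head(out[:, -1, :])  # (batch, 1)


# Use the GPU if PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ManyToOneRNN(cell="GRU").to(device)
x = torch.randn(8, 20, 1, device=device)  # 8 windows of 20 time steps each
y_hat = model(x)
print(y_hat.shape)  # torch.Size([8, 1])
```

For item 4, PyTorch's standard pattern is `torch.save(model.state_dict(), path)` followed later by `model.load_state_dict(torch.load(path))`; whether the package stores models in exactly this form is not stated here.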

Usage

The best way to figure out how to use this package is to check out the example notebooks in the Notebooks folder.

I have also made a sample notebook available in Google Colab!

Code Structure

I have structured the code so that different operations are abstracted away in Python classes. Here is a brief summary:

  • Model: Directory - contains classes that define the RNN models. RNN_Vanilla.py defines the Many-to-One RNN, the traditional kind. EncDec.py defines the Encoder-Decoder class, which uses traditional RNN units as the Encoder and Decoder modules and combines them to provide a one-shot Many-to-Many prediction.

  • Notebooks: Directory - example notebooks that demonstrate how to use the code on a sample time series consisting of multi-frequency sine waves. It also contains a notebook that demonstrates how to perform hyperparameter tuning with Optuna.

  • Saved_models: Directory, empty - used to store the output from the Create_and_Train.py file.

  • Utils: Directory - contains all the class files that handle data prep, training, testing, validation, and prediction.

    • Trainer.py contains the training loop, a test function to run the model on test data, as well as functions to make predictions.
    • SeqData.py creates a sequenced dataset, in torch tensor format, from a 1D NumPy time series.
    • Create_and_Train.py is THE main file which creates a model (using the classes in the Model directory), runs the epoch loop, saves PyTorch models and train-test loss curves.
  • imports.py file is used by the notebooks present in the Notebooks folder.

  • requirements.txt file can be used in conjunction with pip to install the required packages.
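The sequencing step described for SeqData.py can be sketched with a sliding window over the series. The function `make_windows` and its `n_in`/`n_out` parameters are hypothetical names, not the package's API; the sketch assumes a univariate, equispaced NumPy array. Setting `n_out=1` yields Many-to-One targets, while `n_out>1` yields the multi-step targets an Encoder-Decoder would predict in one shot.

```python
import numpy as np
import torch


def make_windows(series: np.ndarray, n_in: int, n_out: int = 1):
    """Hypothetical sketch: slice a 1D series into (input, target) windows.

    n_out=1 gives Many-to-One targets; n_out>1 gives Many-to-Many (Seq2Seq) targets.
    """
    n_samples = len(series) - n_in - n_out + 1
    if n_samples < 1:
        raise ValueError("series is too short for the requested window sizes")
    # Each input window is n_in consecutive values; its target is the next n_out values.
    X = np.stack([series[i : i + n_in] for i in range(n_samples)])
    Y = np.stack([series[i + n_in : i + n_in + n_out] for i in range(n_samples)])
    # Add a trailing feature axis so inputs match batch_first RNN layers: (N, n_in, 1).
    X = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)
    Y = torch.tensor(Y, dtype=torch.float32)  # (N, n_out)
    return X, Y


# Toy multi-frequency sine wave sampled on an equispaced grid.
t = np.linspace(0, 20, 500)
series = np.sin(t) + 0.5 * np.sin(3 * t)
X, Y = make_windows(series, n_in=30, n_out=5)
print(X.shape, Y.shape)  # torch.Size([466, 30, 1]) torch.Size([466, 5])
```

Consecutive windows overlap by all but one step, so a series of length T yields T - n_in - n_out + 1 samples.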

Limitations

I haven't generalized the code to multivariate time series data, for the sake of simplicity. But it is relatively easy to do. If you're interested, open an issue in the repo's Issues section and we can collaborate!
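To give a sense of why the multivariate extension is considered relatively easy, here is a minimal sketch under stated assumptions: the data is a `(T, n_features)` NumPy array, every feature is fed to the model, and only the first column is forecast one step ahead. `make_windows_mv` is a hypothetical helper, not part of the package; on the model side, the main change is that the recurrent layer's `input_size` equals the number of features instead of 1.

```python
import numpy as np
import torch
import torch.nn as nn


def make_windows_mv(data: np.ndarray, n_in: int):
    """Hypothetical sketch: windows over a (T, n_features) array.

    Inputs use every feature; the target is the next value of feature 0.
    """
    n_samples = data.shape[0] - n_in
    X = np.stack([data[i : i + n_in] for i in range(n_samples)])  # (N, n_in, F)
    Y = data[n_in:, :1]  # (N, 1): feature 0 at the step right after each window
    return torch.tensor(X, dtype=torch.float32), torch.tensor(Y, dtype=torch.float32)


T, n_features = 200, 3
data = np.random.default_rng(0).normal(size=(T, n_features))  # synthetic stand-in data
X, Y = make_windows_mv(data, n_in=24)

# The key model change: input_size is the number of features rather than 1.
lstm = nn.LSTM(input_size=n_features, hidden_size=16, batch_first=True)
head = nn.Linear(16, 1)
out, _ = lstm(X)
pred = head(out[:, -1, :])  # many-to-one: use only the last time step
print(X.shape, Y.shape, pred.shape)  # torch.Size([176, 24, 3]) torch.Size([176, 1]) torch.Size([176, 1])
```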

Note

I also recommend checking out my colleague's implementation of RNNs in PyTorch.

Contributors

rakesh-yadav
