

pytorch-widedeep

A flexible package to use Deep Learning with tabular data, text and images using wide and deep models.

Documentation: https://pytorch-widedeep.readthedocs.io

Companion posts and tutorials: infinitoml

Experiments and comparison with LightGBM: TabularDL vs LightGBM

The content of this document is organized as follows:

  1. Introduction
  2. The deeptabular component
  3. Installation
  4. Quick start (TL;DR)

Introduction

pytorch-widedeep is based on Google's Wide and Deep Algorithm (Wide & Deep Learning for Recommender Systems).

In general terms, pytorch-widedeep is a package to use deep learning with tabular data. In particular, it is intended to facilitate the combination of text and images with corresponding tabular data using wide and deep models. With that in mind, there are a number of architectures that can be implemented with just a few lines of code. The main components of those architectures are shown in the Figure below:

The dashed boxes in the figure represent optional, overall components, and the dashed lines/arrows indicate the corresponding connections, depending on whether or not certain components are present. For example, the dashed blue lines indicate that the deeptabular, deeptext and deepimage components are connected directly to the output neuron or neurons (depending on whether we are performing a binary classification or regression, or a multi-class classification) if the optional deephead is not present. Finally, the components within the faded-pink rectangle are concatenated.

Note that it is not possible to illustrate all the possible architectures and components available in pytorch-widedeep in one Figure. Therefore, for more details on possible architectures (and more) please see the documentation, or the Examples folder and the notebooks there.

In math terms, and following the notation in the paper, the expression for the architecture without a deephead component can be formulated as:

$$pred = \sigma \left( W_{wide}^{T} [x, \phi(x)] + W_{deeptabular}^{T} a_{deeptabular}^{(l_f)} + W_{deeptext}^{T} a_{deeptext}^{(l_f)} + W_{deepimage}^{T} a_{deepimage}^{(l_f)} + b \right)$$
where $W$ are the weight matrices applied to the wide model and to the final activations of the deep models, $a$ are these final activations (the superscript $l_f$ denotes the last layer), and $\phi(x)$ are the cross product transformations of the original features $x$. In case you are wondering what "cross product transformations" are, here is a quote taken directly from the paper: "For binary features, a cross-product transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if the constituent features (“gender=female” and “language=en”) are all 1, and 0 otherwise".
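As a tiny, illustrative example (the feature names come from the quote above, not from the package):

# a cross-product transformation of binary features is their AND:
# it is 1 only if all constituent features are 1
gender_female = 1  # "gender=female"
language_en = 1    # "language=en"
crossed = int(gender_female and language_en)  # -> 1; any 0 yields 0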

While if there is a deephead component, the previous expression turns into:

$$pred = \sigma \left( W_{wide}^{T} [x, \phi(x)] + W_{deephead}^{T} a_{deephead}^{(l_f)} + b \right)$$

I recommend using the wide and deeptabular models in pytorch-widedeep. However, it is very likely that users will want to use their own models for the deeptext and deepimage components. That is perfectly possible, as long as the custom models have an attribute called output_dim with the size of the last layer of activations, so that WideDeep can be constructed. Again, examples on how to use custom components can be found in the Examples folder. Just in case, pytorch-widedeep includes standard text (stack of LSTMs) and image (pre-trained ResNets or stack of CNNs) models.
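For example, a custom deeptext component could look like the following minimal sketch. The module itself, its vocabulary size and its dimensions are illustrative assumptions; the only hard requirement is the output_dim attribute:

import torch
from torch import nn

class MyDeepText(nn.Module):
    def __init__(self, vocab_size=5000, embed_dim=64, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # required attribute: the size of the last layer of activations
        self.output_dim = hidden_dim

    def forward(self, X):
        # X is a (batch, seq_len) tensor of token indices
        _, (h, _) = self.rnn(self.embedding(X))
        return h[-1]

# the custom component can then be passed to WideDeep, e.g.
# model = WideDeep(wide=wide, deeptabular=deeptabular, deeptext=MyDeepText())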

The deeptabular component

It is important to emphasize that each individual component, wide, deeptabular, deeptext and deepimage, can be used independently and in isolation. For example, one could use only wide, which is simply a linear model. In fact, one of the most interesting functionalities in pytorch-widedeep is the use of the deeptabular component on its own, i.e. what one might normally refer to as Deep Learning for Tabular Data (a minimal sketch of this is shown after the model lists below). Currently, pytorch-widedeep offers the following models for that component:

  1. TabMlp: a simple MLP that receives embeddings representing the categorical features, concatenated with the continuous features.
  2. TabResnet: similar to the previous model but the embeddings are passed through a series of ResNet blocks built with dense layers.
  3. TabNet: details on TabNet can be found in TabNet: Attentive Interpretable Tabular Learning

And the Tabformer family, i.e. Transformers for Tabular data:

  1. TabTransformer: Details on the TabTransformer can be found in TabTransformer: Tabular Data Modeling Using Contextual Embeddings.
  2. SAINT: Details on SAINT can be found in SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training.
  3. FT-Transformer: Details on the FT-Transformer can be found in Revisiting Deep Learning Models for Tabular Data.
  4. TabFastFormer: adaptation of the FastFormer for tabular data. Details on the FastFormer can be found in FastFormers: Highly Efficient Transformer Models for Natural Language Understanding.
  5. TabPerceiver: adaptation of the Perceiver for tabular data. Details on the Perceiver can be found in Perceiver: General Perception with Iterative Attention.

Note that while there are scientific publications for the TabTransformer, SAINT and FT-Transformer, the TabFastFormer and TabPerceiver are our own adaptations of the FastFormer and Perceiver for tabular data.

For details on these models and their options please see the examples in the Examples folder and the documentation.
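As an illustration, here is a minimal sketch of using the deeptabular component in isolation, mirroring the quick start example below. The column choices and dimensions are illustrative, and any of the models listed above could take TabMlp's place (each with its own model-specific arguments):

import pandas as pd
from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep

df = pd.read_csv("data/adult/adult.csv.zip")  # illustrative dataset
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)

tab_preprocessor = TabPreprocessor(
    embed_cols=[("education", 16), ("occupation", 16)],
    continuous_cols=["age", "hours-per-week"],
)
X_tab = tab_preprocessor.fit_transform(df)

deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=["age", "hours-per-week"],
)

# WideDeep acts as a thin wrapper: with only deeptabular passed, this is
# Deep Learning for Tabular Data on its own
model = WideDeep(deeptabular=deeptabular)
trainer = Trainer(model, objective="binary")
trainer.fit(X_tab=X_tab, target=df["income_label"].values, n_epochs=1)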

Installation

Install using pip:

pip install pytorch-widedeep

Or install directly from GitHub:

pip install git+https://github.com/jrzaurin/pytorch-widedeep.git

Developer Install

# Clone the repository
git clone https://github.com/jrzaurin/pytorch-widedeep
cd pytorch-widedeep

# Install in dev mode
pip install -e .

Important note for Mac users: at the time of writing the latest torch release is 1.9. Some issues that appeared in previous versions when running on Mac persist in this release, and the data-loaders will not run in parallel. In addition, since Python 3.8, the multiprocessing library's default start method on Mac changed from 'fork' to 'spawn'. This also affects the data-loaders (for any torch version), which will not run in parallel. Therefore, for Mac users I recommend using Python 3.7 and torch <= 1.6 (with the corresponding, consistent version of torchvision, e.g. 0.7.0 for torch 1.6). I do not want to force this versioning in the setup.py file since I expect these issues to be fixed in the future. Therefore, after installing pytorch-widedeep via pip or directly from GitHub, downgrade torch and torchvision manually:

pip install pytorch-widedeep
pip install torch==1.6.0 torchvision==0.7.0

None of these issues affect Linux users.
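You can check which start method your interpreter uses with the following snippet (illustrative; on macOS with Python >= 3.8 it prints 'spawn', on Linux 'fork'):

import multiprocessing as mp

# default start method: 'spawn' on macOS since Python 3.8, 'fork' on Linux
print(mp.get_start_method())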

Quick start

Binary classification with the adult dataset using Wide and TabMlp and default settings.

Building a wide (linear) and deep model with pytorch-widedeep:

import pandas as pd
import numpy as np
import torch
from sklearn.model_selection import train_test_split

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import WidePreprocessor, TabPreprocessor
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.metrics import Accuracy

# the following 4 lines are not directly related to ``pytorch-widedeep``. I
# assume you have downloaded the dataset and placed it in a dir called
# data/adult/
df = pd.read_csv("data/adult/adult.csv.zip")
df["income_label"] = (df["income"].apply(lambda x: ">50K" in x)).astype(int)
df.drop("income", axis=1, inplace=True)
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df.income_label)

# prepare wide, crossed, embedding and continuous columns
wide_cols = [
    "education",
    "relationship",
    "workclass",
    "occupation",
    "native-country",
    "gender",
]
cross_cols = [("education", "occupation"), ("native-country", "occupation")]
embed_cols = [
    ("education", 16),
    ("workclass", 16),
    ("occupation", 16),
    ("native-country", 32),
]
cont_cols = ["age", "hours-per-week"]
target_col = "income_label"

# target
target = df_train[target_col].values

# wide
wide_preprocessor = WidePreprocessor(wide_cols=wide_cols, crossed_cols=cross_cols)
X_wide = wide_preprocessor.fit_transform(df_train)
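# wide_dim is the number of unique feature values (including the crossed
# columns) produced by the WidePreprocessor; Wide is a linear model over them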
wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)

# deeptabular
tab_preprocessor = TabPreprocessor(embed_cols=embed_cols, continuous_cols=cont_cols)
X_tab = tab_preprocessor.fit_transform(df_train)
deeptabular = TabMlp(
    mlp_hidden_dims=[64, 32],
    column_idx=tab_preprocessor.column_idx,
    embed_input=tab_preprocessor.embeddings_input,
    continuous_cols=cont_cols,
)

# wide and deep
model = WideDeep(wide=wide, deeptabular=deeptabular)

# train the model
trainer = Trainer(model, objective="binary", metrics=[Accuracy])
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    n_epochs=5,
    batch_size=256,
    val_split=0.1,
)

# predict
X_wide_te = wide_preprocessor.transform(df_test)
X_tab_te = tab_preprocessor.transform(df_test)
preds = trainer.predict(X_wide=X_wide_te, X_tab=X_tab_te)
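
# (illustrative) for the binary objective, trainer.predict returns class
# labels, so they can be scored directly against the test target
from sklearn.metrics import accuracy_score
print(accuracy_score(df_test[target_col].values, preds))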

# Save and load

# Option 1: this will also save training history and lr history if the
# LRHistory callback is used
trainer.save(path="model_weights", save_state_dict=True)

# Option 2: save as any other torch model
torch.save(model.state_dict(), "model_weights/wd_model.pt")

# From here onwards, Options 1 and 2 are the same. I assume the user has
# prepared the data and defined the new model components:
# 1. Build the model
model_new = WideDeep(wide=wide, deeptabular=deeptabular)
model_new.load_state_dict(torch.load("model_weights/wd_model.pt"))

# 2. Instantiate the trainer
trainer_new = Trainer(
    model_new,
    objective="binary",
)

# 3. Either start the fit or directly predict
preds = trainer_new.predict(X_wide=X_wide, X_tab=X_tab)

Of course, one can do much more. See the Examples folder, the documentation or the companion posts for a better understanding of the content of the package and its functionalities.

Testing

pytest tests

How to Contribute

Check CONTRIBUTING page.

Acknowledgments

This library borrows from a series of other libraries, so I think it is only fair to mention them here in the README (specific mentions are also included in the code).

The Callbacks and Initializers structure and code is inspired by the torchsample library, which is itself partially inspired by Keras.

The TextProcessor class in this library uses fastai's Tokenizer and Vocab. The code at utils.fastai_transforms is a minor adaptation of their code so that it functions within this library. In my experience their Tokenizer is the best in class.

The ImageProcessor class in this library uses code from the fantastic Deep Learning for Computer Vision (DL4CV) book by Adrian Rosebrock.
