
hihaluemen / class-based-regularization-effects-in-nlp


This project forked from apasunuri/class-based-regularization-effects-in-nlp


An experimental study of the effects of neural network regularization on class-based accuracies in multiclass classification, specifically for sequential models and language data.

Jupyter Notebook 100.00%

class-based-regularization-effects-in-nlp's Introduction

Class Dependent Effects of Regularization in Sequential Models and Language Data

Various forms of regularization are used to prevent overfitting in nearly every widely used neural network model today. Recent work, specifically the paper The Effects of Regularization and Data Augmentation are Class Dependent by Balestriero, Bottou, and LeCun, has shown that while regularization can improve overall accuracy in image classification problems, the accuracy of certain classes is drastically lowered, even with uninformed regularizers like weight decay. This study explores whether these class-specific biases caused by regularization are also present in Natural Language Processing (NLP) classification tasks. It tests sequential models of varying complexity, including RNN, LSTM, and pretrained BERT, on datasets with different numbers of classes, and trains each model with different types of uninformed regularization. The experiments show empirically that more complex models, such as LSTMs and BERT, trained and finetuned on datasets with many classes are more prone to class biases.
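To make the distinction between overall accuracy and class-specific accuracy concrete, here is a minimal sketch in plain Python. It is not taken from the notebooks; the function names and the label lists are hypothetical and chosen only for illustration.

```python
from collections import defaultdict


def per_class_accuracy(y_true, y_pred):
    """Map each true class label to the fraction of its examples predicted correctly."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        if t == p:
            correct[t] += 1
    return {label: correct[label] / total[label] for label in total}


def overall_accuracy(y_true, y_pred):
    """Fraction of all examples predicted correctly, ignoring class membership."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical labels, made up for illustration (not results from the study).
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2]
y_pred = [0, 0, 0, 0, 1, 1, 1, 1, 0, 1]

print(overall_accuracy(y_true, y_pred))    # 0.8
print(per_class_accuracy(y_true, y_pred))  # {0: 1.0, 1: 1.0, 2: 0.0}
```

In this toy case the overall accuracy looks healthy while class 2 is never predicted correctly, which is the kind of effect a per-class breakdown exposes.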

Code

The code for the study is organized into two directories, BERT and RNN-LSTM.

The BERT directory contains two files, bert_train.ipynb and bert_evaluate.ipynb. The bert_train.ipynb notebook loads a pretrained BERT model and the corresponding datasets, then finetunes the model for the task of Masked Language Modeling at different levels of L2 and Dropout regularization. The bert_evaluate.ipynb notebook loads the saved finetuned BERT models, evaluates the class-specific test accuracies of each model, and plots the class accuracies at the different levels of regularization.

The RNN-LSTM directory contains three files, rnn_lstm_train.ipynb, rnn_lstm_evaluate.ipynb, and rnn_lstm_plots.ipynb. The rnn_lstm_train.ipynb notebook loads and preprocesses the datasets, initializes the RNN and LSTM models, and trains them for multiclass classification at different levels of L1, L2, Dropout, and DropConnect regularization. The rnn_lstm_evaluate.ipynb notebook loads the saved RNN and LSTM models and evaluates the class-specific test accuracies of each model at the different levels of regularization. The rnn_lstm_plots.ipynb notebook generates plots of the class-specific test accuracies for the different models.

I wrote the code for the RNN and LSTM models, and my project collaborator, Noah McDermott, wrote the code for the BERT model.
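For readers unfamiliar with the four regularizers named above, the following is a rough sketch of what each one does, written in plain Python so it has no dependencies. It is not the notebooks' implementation: the function names, the `lam` and `p` parameters, and the toy weights are all assumptions made for illustration, and the DropConnect version omits any rescaling or inference-time handling a real implementation might use.

```python
import random


def l1_penalty(weights, lam):
    """L1 term added to the loss: lam * sum of absolute weight values."""
    return lam * sum(abs(w) for row in weights for w in row)


def l2_penalty(weights, lam):
    """L2 term added to the loss: lam * sum of squared weight values."""
    return lam * sum(w * w for row in weights for w in row)


def dropout(activations, p, rng):
    """Inverted dropout: zero each activation with probability p, rescale the rest."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = 1.0 - p
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


def dropconnect_matvec(weights, x, p, rng):
    """DropConnect: zero each individual weight with probability p, then compute W @ x."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    keep = 1.0 - p
    out = []
    for row in weights:
        masked = [w if rng.random() < keep else 0.0 for w in row]
        out.append(sum(w * xi for w, xi in zip(masked, x)))
    return out


# Toy values, hypothetical and for illustration only.
W = [[1.0, -2.0], [0.5, 0.0]]
x = [2.0, 1.0]
rng = random.Random(0)  # fixed seed so repeated runs give the same masks

print(l1_penalty(W, lam=0.1))
print(l2_penalty(W, lam=0.1))
print(dropout(x, p=0.5, rng=rng))
print(dropconnect_matvec(W, x, p=0.5, rng=rng))
```

The key contrast is that Dropout masks activations (whole units), while DropConnect masks individual weights; L1 and L2 instead add a penalty on weight magnitude to the training loss.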
