
Official Code for Beam Tree Recursive Cells (ICML 2023)

Credits:

Requirements

  • torch==1.10.0
  • tqdm==4.62.3
  • jsonlines==2.0.0
  • torchtext==0.8.1
  • ninja==1.10.2
  • typing-extensions==4.5.0
  • psutil==5.8.0

Data Setup

You can check whether the data is set up correctly by comparing it against the directory tree here.

Processing Data

  • Go to preprocess/ and run each preprocessing script to process the corresponding dataset. process_SNLI_addon.py must be run after process_SNLI.py; the other scripts can be run in any order.
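A minimal Python sketch of the ordering rule above. It assumes that every preprocessing script in `preprocess/` matches `process_*.py` and takes no command-line arguments; both are assumptions, so check the scripts themselves. The helper names `preprocess_order` and `run_all_preprocessing` are hypothetical and not part of the repository.

```python
import subprocess
import sys
from pathlib import Path


def preprocess_order(script_names):
    """Order scripts so that process_SNLI_addon.py runs after process_SNLI.py.

    Only that one constraint comes from the README; the alphabetical sort is
    just for a deterministic order.
    """
    names = sorted(script_names)
    addon = "process_SNLI_addon.py"
    if addon in names:
        names.remove(addon)
        names.append(addon)  # the last slot is always after process_SNLI.py
    return names


def run_all_preprocessing(preprocess_dir="preprocess"):
    # Hypothetical helper: assumes each process_*.py is a standalone script
    # meant to be run from inside preprocess/.
    directory = Path(preprocess_dir)
    scripts = [p.name for p in directory.glob("process_*.py")]
    for name in preprocess_order(scripts):
        subprocess.run([sys.executable, name], cwd=directory, check=True)
```

`check=True` stops the loop at the first failing script, so process_SNLI_addon.py is never run on top of a failed process_SNLI.py.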

We share some of the processed data, with its exact splits, here (place the processed_data folder in the outermost project directory).
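A small sketch for confirming that `processed_data` sits in the outermost project directory. The function names are hypothetical, and the check only looks for the folder; it does not validate the splits inside it.

```python
from pathlib import Path


def processed_data_dir(project_root="."):
    """Expected location of processed_data/ (the outermost project directory)."""
    return Path(project_root).resolve() / "processed_data"


def check_processed_data(project_root="."):
    # Only checks that the folder exists where the README says to put it.
    path = processed_data_dir(project_root)
    if not path.is_dir():
        raise FileNotFoundError(f"expected processed data at {path}")
    return path
```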

How to train

Train: python train.py --model=[insert model name] --dataset=[insert dataset name] --times=[insert total runs] --device=[insert device name] --model_type=[classifier/sentence_pair]

  • Check argparser.py for exact options.
  • Model type sentence_pair is for sentence-matching tasks such as NLI; model type classifier is for simple sentence classification tasks.
  • We generally set the total number of runs (--times) to 3.
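A sketch that assembles the training command from the flags documented above, defaulting `--times` to 3. The helper `build_train_command` is hypothetical, and the device string `"cuda"` is an assumption; argparser.py defines which device names are actually accepted.

```python
import shlex
import sys


def build_train_command(model, dataset, device, model_type, times=3):
    """Build the argv list for train.py using the documented flags."""
    if model_type not in ("classifier", "sentence_pair"):
        raise ValueError("model_type must be 'classifier' or 'sentence_pair'")
    return [
        sys.executable, "train.py",
        f"--model={model}",
        f"--dataset={dataset}",
        f"--times={times}",
        f"--device={device}",
        f"--model_type={model_type}",
    ]


# ListOps is a single-sentence task, so it uses the classifier model type.
# "cuda" is an assumed device name; check argparser.py.
cmd = build_train_command("BeamTreeCell", "listopsc", "cuda", "classifier")
print(shlex.join(cmd))
```

Passing `cmd` to `subprocess.run(cmd, check=True)` would start the run; keeping the arguments as a list avoids shell-quoting issues.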

Tree Parsing

  • For tree parsing from a classifier model: python extract_trees_classifier.py --model=[insert model name] --device=[insert device name] --dataset=[insert dataset name]
  • For tree parsing from a sentence-pair matching model: python extract_trees_nli.py --model=[insert model name] --device=[insert device name] --dataset=[insert dataset name]

The inputs to parse can be changed by editing a list inside extract_trees_classifier.py or extract_trees_nli.py (line 66).
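A sketch that picks the tree-extraction script from the model type and builds its command line. Here `model_type` is used only to choose between the two scripts; it is not passed as a flag, since the documented commands do not take one. The helper name is hypothetical.

```python
import sys


def build_extract_command(model, dataset, device, model_type):
    """Pick the tree-extraction script that matches the model type."""
    scripts = {
        "classifier": "extract_trees_classifier.py",
        "sentence_pair": "extract_trees_nli.py",
    }
    if model_type not in scripts:
        raise ValueError(f"unknown model_type: {model_type!r}")
    return [
        sys.executable, scripts[model_type],
        f"--model={model}",
        f"--device={device}",
        f"--dataset={dataset}",
    ]
```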

Dataset Nomenclature

Dataset names in the codebase differ slightly from those in the paper. The mapping below has the form ([codebase dataset name] == [paper dataset name]).

  • listopsc == ListOps
  • listopsd == ListOps-DG
  • listops_ndr50 == ListOps-DG1
  • listops_ndr100 == ListOps-DG2
  • proplogic == Logical Inference (Operator generalization split)
  • proplogic_C == Logical Inference (C-split for systematic generalization)
  • SST2 == SST2
  • SST5 == SST5
  • IMDB == IMDB
  • MNLIdev == MNLI

Dataset names with a speed suffix are used for the stress tests.
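The dataset mapping above can be kept as a lookup table, for example when labelling results or turning a paper name back into a `--dataset` value. This is an illustrative sketch with hypothetical helper names; the speed-suffixed stress-test names are left out because their exact form is not listed here, and unknown names fall through unchanged.

```python
# Codebase dataset name -> paper dataset name (copied from the list above).
DATASET_NAMES = {
    "listopsc": "ListOps",
    "listopsd": "ListOps-DG",
    "listops_ndr50": "ListOps-DG1",
    "listops_ndr100": "ListOps-DG2",
    "proplogic": "Logical Inference (Operator generalization split)",
    "proplogic_C": "Logical Inference (C-split for systematic generalization)",
    "SST2": "SST2",
    "SST5": "SST5",
    "IMDB": "IMDB",
    "MNLIdev": "MNLI",
}

# Paper names are unique, so the mapping can be inverted safely.
CODEBASE_DATASET_NAMES = {paper: code for code, paper in DATASET_NAMES.items()}


def paper_dataset_name(codebase_name):
    """Paper name for a codebase dataset name (unknown names are returned as-is)."""
    return DATASET_NAMES.get(codebase_name, codebase_name)


def codebase_dataset_name(paper_name):
    """Value to pass as --dataset for a dataset name used in the paper."""
    return CODEBASE_DATASET_NAMES[paper_name]
```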

Model Nomenclature

Model names in the codebase differ slightly from those in the paper. The mapping below has the form ([codebase model name] == [paper model name]).

  • RCell == RecurrentGRC
  • BalancedTreeCell == BalancedTreeGRC
  • RandomTreeCell == RandomTreeGRC
  • GoldTreeCell == GoldTreeGRC
  • GumbelTreeLSTM == GumbelTreeLSTM
  • GumbelTreeCell == GumbelTreeGRC
  • MCGumbelTreeCell == MCGumbelTreeGRC
  • CYKCell == CYK-GRC
  • OrderedMemory == Ordered Memory
  • CRvNN == CRvNN
  • CRvNN_worst == CRvNN without halt (during stress test)
  • BSRPCell == BSRP-GRC (beam 5)
  • BigBSRPCell == BSRP-GRC (beam 8)
  • NDR == NDR (Neural Data Router)
  • BeamTreeLSTM == BT-LSTM (beam 5)
  • BeamTreeCell == BT-GRC (beam 5)
  • SmallerBeamTreeCell == BT-GRC (beam 2)
  • DiffBeamTreeCell == BT-GRC + OneSoft (beam 5)
  • SmallerDiffBeamTreeCell == BT-GRC + OneSoft (beam 2)
  • DiffSortBeamTreeCell == BT-GRC + SOFT (beam 5)
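The model mapping can be stored the same way. This sketch copies the list above into a dictionary and adds a hypothetical `models_for_family` helper that collects every codebase model whose paper name starts with a given prefix, such as all BT-GRC variants.

```python
# Codebase model name -> paper model name (copied from the list above).
MODEL_NAMES = {
    "RCell": "RecurrentGRC",
    "BalancedTreeCell": "BalancedTreeGRC",
    "RandomTreeCell": "RandomTreeGRC",
    "GoldTreeCell": "GoldTreeGRC",
    "GumbelTreeLSTM": "GumbelTreeLSTM",
    "GumbelTreeCell": "GumbelTreeGRC",
    "MCGumbelTreeCell": "MCGumbelTreeGRC",
    "CYKCell": "CYK-GRC",
    "OrderedMemory": "Ordered Memory",
    "CRvNN": "CRvNN",
    "CRvNN_worst": "CRvNN without halt (during stress test)",
    "BSRPCell": "BSRP-GRC (beam 5)",
    "BigBSRPCell": "BSRP-GRC (beam 8)",
    "NDR": "NDR (Neural Data Router)",
    "BeamTreeLSTM": "BT-LSTM (beam 5)",
    "BeamTreeCell": "BT-GRC (beam 5)",
    "SmallerBeamTreeCell": "BT-GRC (beam 2)",
    "DiffBeamTreeCell": "BT-GRC + OneSoft (beam 5)",
    "SmallerDiffBeamTreeCell": "BT-GRC + OneSoft (beam 2)",
    "DiffSortBeamTreeCell": "BT-GRC + SOFT (beam 5)",
}

# Every paper name appears once, so the inverse is well defined.
CODEBASE_MODEL_NAMES = {paper: code for code, paper in MODEL_NAMES.items()}


def codebase_model_name(paper_name):
    """Value to pass as --model for a model name used in the paper."""
    return CODEBASE_MODEL_NAMES[paper_name]


def models_for_family(prefix):
    """Codebase names whose paper name starts with prefix, in list order."""
    return [code for code, paper in MODEL_NAMES.items() if paper.startswith(prefix)]
```

For example, `models_for_family("BT-GRC")` returns the five BT-GRC variants (beam 5, beam 2, and the OneSoft and SOFT versions), which could then be passed one by one as `--model` values.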

Citation

@InProceedings{Chowdhury2023beam,
  title = 	 {Beam Tree Recursive Cells},
  author =       {Ray Chowdhury, Jishnu and Caragea, Cornelia},
  booktitle = 	 {Proceedings of the 40th International Conference on Machine Learning},
  year = 	 {2023}
}

For any questions or issues, contact the email associated with the GitHub account.
