
CSC 7343 - Project Report

Group: Muhammad Hussain and André Longon

Pretraining

To deviate from the predictive pretraining objectives of GPT and BERT, we formulated a contrastive self-supervised objective. As in contrastive learning for vision, we generate augmented data points from each batch: for each datapoint (a sequence of tokens), we randomly replace one of the non-reserved tokens, i.e., those representing the antigen and TCR sequences, with the special <mask> token. We perform this augmentation twice to obtain two projections of each datapoint.
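As a minimal sketch, the augmentation might look like the following in PyTorch. The reserved-token ids and the id of <mask> depend on the tokenizer, so the ids below are illustrative placeholders, not the project's actual vocabulary:

```python
import torch

def mask_augment(seq, special_ids, mask_id):
    """Return a copy of `seq` with one random non-reserved token replaced by <mask>."""
    seq = seq.clone()
    # positions of antigen/TCR tokens that may be masked (reserved tokens excluded)
    candidates = [i for i, t in enumerate(seq.tolist()) if t not in special_ids]
    pos = candidates[torch.randint(len(candidates), (1,)).item()]
    seq[pos] = mask_id
    return seq

# illustrative ids: 0 = <pad>, 1 = <sep>, 2 = <mask>; the rest encode residues
tokens = torch.tensor([1, 5, 9, 7, 1, 4, 8, 6, 0])
# masking twice, independently, yields the two projections of one datapoint
view_a = mask_augment(tokens, special_ids={0, 1, 2}, mask_id=2)
view_b = mask_augment(tokens, special_ids={0, 1, 2}, mask_id=2)
```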

We then pass the two projections through the model and extract the first vector of each output sequence. From these vectors we compute the NT-Xent contrastive loss used in SimCLR, where sim(u, v) is the dot product between two normalized vectors (cosine similarity) and τ is the temperature:

$$\ell_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\mathrm{sim}(z_i, z_k)/\tau)}$$
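A minimal PyTorch sketch of this loss, assuming `z_a` and `z_b` hold the first output vectors of the two projected batches; the cross-entropy formulation is one standard way to implement NT-Xent, with `temperature` corresponding to τ:

```python
import torch
import torch.nn.functional as F

def nt_xent_loss(z_a, z_b, temperature=0.5):
    """NT-Xent loss over a batch of N paired embeddings, each of shape (N, d)."""
    n = z_a.size(0)
    z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)  # (2N, d), unit norm
    sim = z @ z.t() / temperature                         # pairwise similarities
    sim.fill_diagonal_(float("-inf"))                     # the 1[k != i] indicator
    # the positive for row i is its counterpart in the other view
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    # cross-entropy over rows reproduces the -log softmax form of the loss
    return F.cross_entropy(sim, targets)
```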

The loss function depends on the batch size, since the batch size determines how many positive and negative pairs are considered in each training iteration: with a larger batch, the average reflects agreement or disagreement across more pairs. To approximate large batches on a limited resource budget, we used gradient accumulation, averaging the gradients of N forward-backward passes and updating the model once per N steps rather than after every pass, as sketched below.
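A sketch of that accumulation schedule, reusing `mask_augment` and `nt_xent_loss` from above; `model`, `optimizer`, `loader`, `special_ids`, and `mask_id` are assumed to be defined elsewhere, and the model is assumed to return one vector per input token:

```python
accum_steps = 8  # illustrative; effective batch = accum_steps * loader batch size
optimizer.zero_grad()
for step, batch in enumerate(loader):
    # two independently masked views of the same batch of (padded) sequences
    view_a = torch.stack([mask_augment(s, special_ids, mask_id) for s in batch])
    view_b = torch.stack([mask_augment(s, special_ids, mask_id) for s in batch])
    z_a = model(view_a)[:, 0]                    # first vector of each output sequence
    z_b = model(view_b)[:, 0]
    loss = nt_xent_loss(z_a, z_b) / accum_steps  # scale so gradients average over N steps
    loss.backward()                              # gradients accumulate in .grad
    if (step + 1) % accum_steps == 0:
        optimizer.step()                         # one update per N forward-backward passes
        optimizer.zero_grad()
```

Note that accumulation averages gradients across steps; each NT-Xent term still contrasts only against the negatives within its own mini-batch, so this approximates rather than exactly reproduces a large-batch update.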

This approach was inspired by contrastive learning in computer vision, where a model is encouraged to embed an image and its noised or blurred counterpart similarly. Also of inspiration is the intuition that readers are somewhat tolerant of omitted letters in text. By training the network to be resilient to augmentations and to push apart distinct data points, we hope it builds some basic knowledge of the data before its supervised fine-tuning. Let us see if we accomplish this.

Results

Without pre-training, the model achieves an average accuracy of 52.4% after training for 3 epochs with a batch size of 1024. With pre-training, it reaches an average accuracy of 64.7% after 3 epochs of pre-training and 3 epochs of fine-tuning, also with a batch size of 1024. Pre-training used SGD with a learning rate of 1e-3; supervised fine-tuning used AdamW with a learning rate of 1e-5.
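For completeness, a minimal sketch of the two optimizer configurations as reported (assuming `model` is the classifier):

```python
import torch

pretrain_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)    # contrastive pre-training
finetune_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # supervised fine-tuning
```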

|                 | Epochs              | Learning rate            | Batch size | Avg accuracy (%) |
|-----------------|---------------------|--------------------------|------------|------------------|
| No pre-training | 3                   | 1e-5                     | 1024       | 52.4             |
| Pre-training    | 3 + 3 (fine-tuning) | 1e-3, 1e-5 (fine-tuning) | 1024       | 64.7             |

Average accuracy is reported from 3-fold cross-validation.

|                 | Iteration #1 | Iteration #2 | Iteration #3 |
|-----------------|--------------|--------------|--------------|
| No pre-training | 57.9         | 56.9         | 42.1         |
| Pre-training    | 60.3         | 66.5         | 67.4         |

Accuracy (%) for each iteration of 3-fold cross-validation.

Discussion

The approach poses challenges in hyperparameter tuning, demands considerable computational resources, and requires large batch sizes. Exploring an effective set of hyperparameters, such as the choice of optimizer, temperature, batch size, and number of gradient accumulation steps, is crucial for a successful implementation. The reliance on large batch sizes adds computational overhead, limiting the scalability of the approach.

While we achieved a performance improvement, it may not be significant enough to justify the additional computation, especially when scaled up to a larger dataset. Further work could explore better augmentation schemes than the masking we presented. Compared with images, it was difficult to conceive of token-level augmentations that made intuitive sense from an embedding-similarity standpoint. Still, the performance boost is encouraging enough to pursue alternative augmentations and even different contrastive loss objectives.

References

Chen, Ting, et al. "A Simple Framework for Contrastive Learning of Visual Representations." arXiv, arxiv.org/pdf/2002.05709v3.pdf. Accessed 4 Dec. 2023.

Adaloglou, Nikolas. "Implementing SimCLR with PyTorch Lightning." AI Summer, theaisummer.com/simclr. Accessed 4 Dec. 2023.

Konkle, T., and G. A. Alvarez. "A Self-Supervised Domain-General Learning Framework for Human Ventral Stream Representation." Nature Communications 13, 491 (2022).
