bohuanglab / cell-e_2

Multimodal encoder-only transformer model for image-based protein predictions

Home Page: https://bohuanglab.github.io/CELL-E_2/

License: MIT License


CELL-E 2: Translating Proteins to Pictures and Back with a Bidirectional Text-to-Image Transformer

This repository is the official implementation of CELL-E 2: Translating Proteins to Pictures and Back with a Bidirectional Text-to-Image Transformer.

[Figure: model architecture]

Requirements

Create a virtual environment and install the required packages via:

pip install -r requirements.txt

Next, install torch == 2.0.0 with the appropriate CUDA version.
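For example, assuming CUDA 11.8 (swap the index URL for your CUDA version):

pip install torch==2.0.0 --index-url https://download.pytorch.org/whl/cu118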

Model Downloads

Models are available on Hugging Face.

We also have two Hugging Face Spaces available where you can run predictions on your own data!

Generating Images

[Figure: image generation examples]

To generate images, set ckpt_path in the config to the path of the saved model. This loading method can be unstable, so refer to Demo.ipynb for an alternative way of loading.

from omegaconf import OmegaConf
from celle_main import instantiate_from_config

# Load the model config; set ckpt_path inside configs/celle.yaml to your saved checkpoint.
configs = OmegaConf.load("configs/celle.yaml")

# Instantiate the model and move it to the desired device (e.g. "cuda" or "cpu").
model = instantiate_from_config(configs.model).to(device)

# sequence: amino-acid string; nucleus: conditioning nucleus image tensor.
model.sample(text=sequence,
             condition=nucleus,
             return_logits=True,
             progress=True)

Sequence Prediction

model.sample_text(condition=nucleus,
                  image=image,
                  return_logits=True,
                  progress=True)
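In both calls, sequence is an amino-acid string and nucleus / image are image tensors already on device. A minimal loading sketch (the preprocessing below is hypothetical; the expected resolution and normalization come from the VQGAN configs, so check Demo.ipynb for the exact pipeline):

import torch
from PIL import Image
from torchvision import transforms

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder preprocessing: resize to 256x256 and convert to a [0, 1] tensor.
to_tensor = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
nucleus = to_tensor(Image.open("nucleus.png")).unsqueeze(0).to(device)  # (1, C, 256, 256)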

Training

Training for CELL-E occurs in 3 stages:

  • Training Protein Threshold Image Encoder
  • Training a Nucleus Image Encoder
  • Training CELL-E Transformer

VQGANs

If using the protein threshold image, set threshold: True for the dataset.
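A sketch of setting this from Python with OmegaConf, using a hypothetical key path (the real nesting depends on the layout of configs/threshold_vqgan.yaml, so adjust it or simply edit the YAML directly):

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/threshold_vqgan.yaml")
# Hypothetical key path; check the dataset section of the config for the real one.
cfg.data.params.train.params.threshold = True
OmegaConf.save(cfg, "configs/threshold_vqgan_custom.yaml")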

We use a slightly modified version of the taming-transformers code.

To train, run the following script:

python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True

Please refer to the original repo for additional flags, such as --gpus.
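For example, to train on the first GPU (the trailing comma follows the PyTorch Lightning flag convention used by taming-transformers; confirm against that repo):

python celle_taming_main.py --base configs/threshold_vqgan.yaml -t True --gpus 0,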

Preparing Dataset

Images

We provide scripts for downloading Human Protein Atlas and OpenCell images in the scripts folder. A data_csv is needed for the dataloader. You must generate a CSV file containing the columns nucleus_image_path, protein_image_path, metadata_path, split (train or val), and sequence (optional). It is assumed that this file exists within the same general data folder as the images and metadata files.
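A minimal sketch of building such a file with pandas (the file names and folder layout below are placeholders):

import pandas as pd

# One row per sample; paths are relative to the data folder that also holds the CSV.
df = pd.DataFrame({
    "nucleus_image_path": ["images/protein_A_nucleus.png"],
    "protein_image_path": ["images/protein_A_target.png"],
    "metadata_path": ["metadata/protein_A.json"],
    "split": ["train"],
    "sequence": ["MKTAYIAKQR"],  # optional column; placeholder sequence
})
df.to_csv("data.csv", index=False)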

Metadata

Metadata is a JSON file which should accompany every protein sequence. If a sequence does not appear in the data_csv, it must appear in metadata.json with a key named protein_sequence.

Adding more information here can be useful for querying individual proteins. These fields can be retrieved via retrieve_metadata, which creates a self.metadata variable within the dataset object.
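A minimal sketch of one metadata file (every field beyond protein_sequence is an optional, hypothetical extra):

import json

metadata = {
    "protein_sequence": "MKTAYIAKQR",   # required key
    "protein_name": "example protein",  # optional extra field for querying
}
with open("metadata/protein_A.json", "w") as f:
    json.dump(metadata, f, indent=2)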

To train, run the following script:

python celle_main.py --base configs/celle.yaml -t True

Specify --gpus in the same format as VQGAN.

CELL-E contains the following options:

  • ckpt_path : Resume previous CELL-E 2 training. Saved model with state_dict
  • vqgan_model_path : Saved protein image model (with state_dict) for protein image encoder
  • vqgan_config_path: Saved protein image model yaml
  • condition_model_path : Saved condition (nucleus) model (with state_dict) for protein image encoder
  • condition_config_path: Saved condition (nucleus) model yaml
  • num_images: 1 if only using protein image encoder, 2 if including condition image encoder
  • image_key: nucleus, target, or threshold
  • dim: Dimension of language model embedding
  • num_text_tokens: total number of tokens in language model (33 for ESM-2)
  • text_seq_len: Total number of amino acids considered
  • depth: Transformer model depth, deeper is usually better at the cost of VRAM
  • heads: number of heads used in multi-headed attention
  • dim_head: size of attention heads
  • attn_dropout: Attention Dropout rate in training
  • ff_dropout: Feed-Forward Dropout rate in training
  • loss_img_weight: Weighting applied to image reconstruction. text weight = 1
  • loss_text_weight: Weighting applied to condition image reconstruction.
  • stable: Normalizes weights (useful when exploding gradients occur)
  • learning_rate: Learning rate for Adam optimizer
  • monitor: Metric used to determine which model checkpoints to save
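These options live under the model section of configs/celle.yaml. A sketch of inspecting and overriding them with OmegaConf, assuming the standard target/params layout consumed by instantiate_from_config:

from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/celle.yaml")
print(OmegaConf.to_yaml(cfg.model.params))  # dim, depth, heads, dim_head, ...

# Example override before launching training with the modified config.
cfg.model.params.depth = 16
OmegaConf.save(cfg, "configs/celle_custom.yaml")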

Citation

Please cite us if you decide to use our code for any part of your research.

@inproceedings{anonymous2023translating,
  title={CELL-E 2: Translating Proteins to Pictures and Back with a Bidirectional Text-to-Image Transformer},
  author={Emaad Khwaja and Yun S. Song and Aaron Agarunov and Bo Huang},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=YSMLVffl5u}
}
