
contrastivevi's People

Contributors

chris522229197, chrisdlin, ethanweinberger

contrastivevi's Issues

batch_key name adata.obs['batch'] is hardcoded

First of all, thank you for the great tool!
A small suggestion: currently, the get_library_log_means_and_var() function in utils.py requires adata.obs to have a column named batch. It would be neat if this could be any column name, specifically the one passed as batch_key to ContrastiveVIModel.setup_anndata().
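A minimal sketch of what a column-agnostic version could compute, assuming (as in scVI when the observed library size is not used) that the library-size prior is parameterized by the per-batch mean and variance of each cell's log total counts. The helper name `library_log_means_and_vars` is hypothetical and is not the function in contrastiveVI's utils.py; it takes the batch labels directly, so any obs column could be passed.

```python
import numpy as np


def library_log_means_and_vars(counts, batch_labels):
    """Hypothetical helper: per-batch mean and variance of log library size.

    counts: dense (n_cells, n_genes) array of raw counts.
    batch_labels: length-n_cells labels from any obs column (e.g. batch_key).
    Assumes every cell has at least one count, so the log is finite.
    """
    counts = np.asarray(counts, dtype=float)
    batch_labels = np.asarray(batch_labels)
    log_library = np.log(counts.sum(axis=1))  # log total counts per cell
    batches = np.unique(batch_labels)
    means = np.array([log_library[batch_labels == b].mean() for b in batches])
    variances = np.array([log_library[batch_labels == b].var() for b in batches])
    # Shape (1, n_batches): one column per batch, in sorted label order
    return batches, means.reshape(1, -1), variances.reshape(1, -1)


counts = np.array([[1, 1], [2, 2], [3, 1], [4, 4]])
batches, means, variances = library_log_means_and_vars(counts, ["a", "a", "b", "b"])
print(batches, means, variances)
```

With an AnnData object and a sparse `adata.X`, a call might look like `library_log_means_and_vars(adata.X.toarray(), adata.obs[batch_key])`, where `batch_key` is whatever column the user chose.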

Issue during training of model

Dear contrastiveVI dev team,

First of all thanks a lot for developing this cool model and congrats on the Nature Methods paper on it!
I was trying to apply it to a MIBI dataset harboring different drug treatment conditions, but unfortunately ran into an issue I can't seem to figure out myself.

I ran the following code (taken from the Alzheimer's example), where treated_control is an AnnData object containing my single-cell data and "Drug" is the condition column.

# imports
import numpy as np
from contrastive_vi.model import ContrastiveVI
from pytorch_lightning.utilities.seed import seed_everything

seed_everything(42) # For reproducibility

treated_control = treated_control.copy()
ContrastiveVI.setup_anndata(treated_control) # setup adata for use with this model

model = ContrastiveVI(
    treated_control,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False
)

background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]

model.train(
    check_val_every_n_epoch=1,
    train_size=0.8,
    background_indices=background_indices,
    target_indices=target_indices,
    use_gpu=False,
    early_stopping=True,
    max_epochs=500,
)
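The index construction in the code above can be checked on its own before training. This sketch uses a made-up label array in place of `treated_control.obs["Drug"]` and confirms that every cell falls into exactly one of the two groups; it does not exercise contrastiveVI itself.

```python
import numpy as np

# Made-up stand-in for treated_control.obs["Drug"]
drug = np.array(["CTRL", "DrugA", "CTRL", "DrugB", "DrugA"])

background_indices = np.where(drug == "CTRL")[0]
target_indices = np.where(drug != "CTRL")[0]

# Each cell should belong to exactly one of the two groups
assert np.intersect1d(background_indices, target_indices).size == 0
assert background_indices.size + target_indices.size == drug.size

print(background_indices, target_indices)
```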

Running model.train, I get the following error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[20], line 4
      1 background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
      2 target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
----> 4 model.train(
      5     check_val_every_n_epoch=1,
      6     train_size=0.8,
      7     background_indices=background_indices,
      8     target_indices=target_indices,
      9     use_gpu=False,
     10     early_stopping=True,
     11     max_epochs=500,
     12 )

File ~\Anaconda3\envs\ST0036\Lib\site-packages\contrastive_vi\model\base\training_mixin.py:88, in ContrastiveTrainingMixin.train(self, background_indices, target_indices, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     77 trainer_kwargs[es] = (
     78     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
     79 )
     80 runner = TrainRunner(
     81     self,
     82     training_plan=training_plan,
   (...)
     86     **trainer_kwargs,
     87 )
---> 88 return runner()

File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainrunner.py:74, in TrainRunner.__call__(self)
     71 if hasattr(self.data_splitter, "n_val"):
     72     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
     75 self._update_history()
     77 # data splitter only gets these attrs after fit

File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735     rank_zero_deprecation(
    736         "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
    737         " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
    738     )
    739     train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
   (...)
    682     **kwargs: keyword arguments to be passed to `trainer_fn`
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1138, in Trainer._run(self, model, ckpt_path)
   1136 self.call_hook("on_before_accelerator_backend_setup")
   1137 self.accelerator.setup_environment()
-> 1138 self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
   1140 # check if we should delay restoring checkpoint till later
   1141 if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1438, in Trainer._call_setup_hook(self)
   1435 self.training_type_plugin.barrier("pre_setup")
   1437 if self.datamodule is not None:
-> 1438     self.datamodule.setup(stage=fn)
   1439 self.call_hook("setup", stage=fn)
   1441 self.training_type_plugin.barrier("post_setup")

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\core\datamodule.py:461, in LightningDataModule._track_data_hook_calls.<locals>.wrapped_fn(*args, **kwargs)
    459     else:
    460         attr = f"_has_{name}_{stage}"
--> 461         has_run = getattr(obj, attr)
    462         setattr(obj, attr, True)
    464 elif name == "prepare_data":

AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'

I ran the package in a fresh conda environment.
Any ideas where the issue may lie?

Thanks a ton for your help!

Best regards,
Sven
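One plausible reading of the last traceback frame, offered as an assumption rather than a confirmed diagnosis: pytorch_lightning builds the attribute name with `f"_has_{name}_{stage}"`, and the error shows `stage` formatted as `TrainerFn.FITTING` instead of its value `fit`. Starting with Python 3.11, f-strings format members of a `str`-mixin `Enum` as `ClassName.MEMBER` rather than by value, which would produce exactly this name with a pytorch_lightning release written for older Pythons. The sketch below reproduces only that formatting difference with a hypothetical stand-in enum; it does not import pytorch_lightning.

```python
import sys
from enum import Enum


class TrainerFn(str, Enum):
    """Hypothetical stand-in for pytorch_lightning's TrainerFn (a str-mixin Enum)."""

    FITTING = "fit"


stage = TrainerFn.FITTING

# Built the same way as the last traceback frame: f"_has_{name}_{stage}"
attr_from_format = f"_has_setup_{stage}"
# Using the member's value gives the same name on every Python version
attr_from_value = f"_has_setup_{stage.value}"

print(sys.version_info[:2], attr_from_format, attr_from_value)
```

If this is the cause, an environment with Python 3.10 or earlier, or a combination of scvi-tools and pytorch_lightning versions that supports Python 3.11, may avoid the error; this has not been verified here.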

Dealing with multiple conditions and details on model parameters

Good morning,

I'd like to try your tool. I don't have much experience with scVI and mainly work with the SingleCellExperiment class in Bioconductor, so it would be great to have your insights on how best to run it.

1). Multiple conditions:
The dataset I have is not unlike the MIX-Seq one, but I'm dealing with multiple conditions, not just two. Would you run them all together, or separately and pairwise (control vs. treatment 1, then control vs. treatment 2, etc.)?

2). Raw or normalized counts:
Does the tool require raw or normalized counts? (assuming raw, but wanted to make sure).

3). Model parameters:
How do parameters such as the number of layers (n_layers), n_hidden, and batch size affect the output? I tried running my dataset with the parameters below (all conditions vs. DMSO). The model finished training, but the UMAP of the salient space was one big blob.

## Initialize raw counts
adata.raw = adata
adata.layers["counts"] = adata.X.toarray() # keep raw counts for scdef
model = ContrastiveVI(
    adata,
    n_batch = 0, # no batch correction
    n_layers = 1, 
    n_hidden = 128,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False
)
background_indices = np.where(adata.obs["Treatment"] == "DMSO")[0]
target_indices = np.where(adata.obs["Treatment"] != "DMSO")[0]
model.train(
    check_val_every_n_epoch = 1,
    train_size = 0.8, # 0 to 1
    background_indices = background_indices,
    target_indices = target_indices,
    use_gpu = True,
    early_stopping = True,
    max_epochs = 1000,
    batch_size = 512
)
## Convert things back into R to continue with SCE class
reducedDim(sce, "salient") <- scd$obsm[["salient_rep"]] # assign as dimensional reduction to sce
set.seed(100)
sce <- runUMAP(sce, dimred = "salient", n_neighbors = 5, name = paste0("UMAP_salient"), BPPARAM = mcparam) # assuming euclidean distance works as a metric
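For the pairwise option in question 1, a minimal sketch of how the cells for each control-vs-treatment model could be selected. The helper `pairwise_subsets` is hypothetical. It assumes each pair would be trained as a separate model on its own subset (e.g. `adata[keep].copy()`), so it returns, per treatment, the positions of the cells to keep plus background and target indices relative to that subset.

```python
import numpy as np


def pairwise_subsets(labels, control):
    """Hypothetical helper: one (keep, background, target) triple per treatment.

    keep indexes the full dataset; background and target index the subset,
    which is what a separate model trained on adata[keep] would expect.
    """
    labels = np.asarray(labels)
    subsets = {}
    for treatment in np.unique(labels):
        if treatment == control:
            continue
        keep = np.where((labels == control) | (labels == treatment))[0]
        sub = labels[keep]
        background = np.where(sub == control)[0]  # positions within the subset
        target = np.where(sub == treatment)[0]
        subsets[str(treatment)] = (keep, background, target)
    return subsets


treatment = ["DMSO", "drug1", "drug2", "DMSO", "drug1"]
for name, (keep, bg, tg) in pairwise_subsets(treatment, "DMSO").items():
    print(name, keep.tolist(), bg.tolist(), tg.tolist())
```

Running all conditions together instead would keep a single model, with every non-DMSO cell in the target set, as in the code above.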

Many thanks : )
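On question 2, scVI-family models are generally fit to raw counts, and the snippet above keeps them in `adata.layers["counts"]`. A quick heuristic check that a matrix still holds raw counts might look like this. `looks_like_raw_counts` is a hypothetical helper, not a contrastiveVI function, and passing the check does not prove the values are unnormalized.

```python
import numpy as np


def looks_like_raw_counts(matrix):
    """Heuristic: raw counts are non-negative whole numbers, while
    log-normalized or scaled values usually are not."""
    values = np.asarray(matrix, dtype=float)
    return bool(np.all(values >= 0) and np.all(values == np.round(values)))


raw = np.array([[0, 3, 1], [5, 0, 2]])
# Library-size normalization to 1e4 followed by log1p, for comparison
lognorm = np.log1p(raw / raw.sum(axis=1, keepdims=True) * 1e4)
print(looks_like_raw_counts(raw), looks_like_raw_counts(lognorm))
```

For a SciPy sparse matrix `X`, checking only its stored values with `looks_like_raw_counts(X.data)` avoids densifying it.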

Alzheimer's Tutorial

I'm very interested in using your elegant approach and started with the tutorial. Everything works except for these lines:

top_genes_per_cell_type[cell_type] = results_tmp.index[:num_top_genes]
de_results[cell_type] = set(results_tmp.index)

I get errors like this one:

NameError: name 'top_genes_per_cell_type' is not defined

Any suggestions?
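A `NameError` means the name was never assigned before that line ran, so the two dictionaries most likely need to exist before the loop over cell types (for example, the tutorial cell that creates them may have been skipped). A minimal sketch with made-up cell types and gene names, where a plain list stands in for `results_tmp.index`:

```python
# Made-up stand-in for the tutorial's per-cell-type DE results: each list
# plays the role of results_tmp.index, with genes ranked from most significant.
ranked_genes = {
    "cell_type_A": ["gene_1", "gene_2", "gene_3"],
    "cell_type_B": ["gene_4", "gene_5"],
}
num_top_genes = 2

# Both dictionaries must exist before the loop assigns into them;
# leaving these two lines out reproduces the NameError.
top_genes_per_cell_type = {}
de_results = {}

for cell_type, results_tmp_index in ranked_genes.items():
    top_genes_per_cell_type[cell_type] = results_tmp_index[:num_top_genes]
    de_results[cell_type] = set(results_tmp_index)

print(top_genes_per_cell_type)
```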

Salient space and downstream analysis for normal cells?

Hi,

I have a question about downstream analysis of contrastiveVI results. I believed that, for normal cells, all values in the salient latent space were set to zero. However, when I generated a UMAP plot of all cells from all samples, both normal and case, the plot looked very reasonable: the normal cells clustered tightly together rather than mixing with the case cells.

So my question is: after model training, is it impossible or uninformative to run downstream analyses such as connectivity, dimensionality reduction, DEG analysis, or others on the salient space of the normal cells? If so, is the only correct way to use the salient space, and the other salient-related functions, to first extract the case (non-normal) cells from the original data?

It would be great if you could double-check my thoughts or correct me if I have misunderstood some points.

Thank you!
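The second option in the question, keeping only the case (target) cells before downstream analysis, amounts to masking the salient embedding by condition. A minimal sketch with a random stand-in embedding and made-up labels; in a real analysis the embedding would presumably come from the trained model (e.g. its `get_latent_representation` method with the salient representation requested, an assumption to check against the contrastiveVI documentation) and the labels from `adata.obs`.

```python
import numpy as np

# Random stand-in for an (n_cells, n_salient_latent) salient embedding
rng = np.random.default_rng(0)
salient = rng.normal(size=(6, 10))
# Made-up condition label per cell, in the same order as the embedding rows
condition = np.array(["normal", "case", "case", "normal", "case", "normal"])

target_mask = condition != "normal"
salient_target = salient[target_mask]  # rows for case (target) cells only
print(salient_target.shape)
```

The subset could then replace the full embedding as input to neighbor-graph, UMAP, or clustering steps.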

Problem loading model

Hey ContrastiveVI team!

I'm super excited to try out your method! I ran ContrastiveVI and saved the model, assuming it would have the same saving and loading capabilities as scVI:

model.save(out_dir + "my_model/")

without any error. But now that I try to load it using:

model = scvi.model.SCVI.load(out_dir + 'n_latent.10/my_model', adata_cv)

I get the following error:

INFO     File /data/peer/chanj3/HTA.fresh_plasticity.SCLC.120122/out.SCLC.ContrastiveVI.120122/n_latent.10/my_model/model.pt already downloaded
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[28], line 1
----> 1 model = scvi.model.SCVI.load(out_dir + 'n_latent.10/my_model', adata_cv)

File ~/anaconda3/envs/multiome_py_r/lib/python3.10/site-packages/scvi/model/base/_base_model.py:598, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
    596 registry = attr_dict.pop("registry_")
    597 if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
--> 598     raise ValueError(
    599         "It appears you are loading a model from a different class."
    600     )
    602 if _SETUP_ARGS_KEY not in registry:
    603     raise ValueError(
    604         "Saved model does not contain original setup inputs. "
    605         "Cannot load the original setup."
    606     )

ValueError: It appears you are loading a model from a different class.

Is there any way to save and load the ContrastiveVI model? Thanks so much!
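The traceback shows why `scvi.model.SCVI.load` refuses the file: scvi-tools compares the model name stored in the saved registry with `cls.__name__` of the class whose `load` is called (lines 597-600 of `_base_model.py`). A simplified, self-contained mirror of that check is sketched here; the key name and the stored class name `"ContrastiveVI"` are assumptions. It suggests calling `load` on the ContrastiveVI class itself rather than on `SCVI`, though that call is not tested here.

```python
# Simplified mirror of the class check in scvi-tools' BaseModelClass.load
# (per the traceback above); the key name is a stand-in, not scvi's constant.
_MODEL_NAME_KEY = "model_name"


def check_model_class(registry, cls_name):
    """Raise if the saved model was created by a different class."""
    if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls_name:
        raise ValueError("It appears you are loading a model from a different class.")


# Assumed: the registry written by model.save() records the ContrastiveVI class name
saved_registry = {_MODEL_NAME_KEY: "ContrastiveVI"}

check_model_class(saved_registry, "ContrastiveVI")  # same class: no error

try:
    check_model_class(saved_registry, "SCVI")  # what scvi.model.SCVI.load compares
except ValueError as err:
    print("SCVI:", err)

# With contrastive_vi installed, the corresponding (untested) call would be:
# model = ContrastiveVI.load(out_dir + "my_model/", adata=adata_cv)
```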
