suinleelab / contrastivevi
License: BSD 3-Clause "New" or "Revised" License
First of all, thank you for the great tool!
A small suggestion: currently, the get_library_log_means_and_var() function in utils.py requires adata.obs to have a column named batch. It would be neat if this could be any column name, specifically the one specified by batch_key in ContrastiveVIModel.setup_anndata().
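The request above amounts to looking up batch labels through a configurable key instead of a hard-coded column. The body of get_library_log_means_and_var() is not shown here, so what follows is a hypothetical sketch, not the package's code. The names `library_log_means_and_vars` and its `batch_key` argument are illustrative. The sketch assumes the scVI-style convention: per-batch mean and population variance of the log library size, returned with shape (1, n_batch).

```python
import numpy as np


def library_log_means_and_vars(X, obs, batch_key="batch"):
    """Hypothetical helper: per-batch mean/variance of log library size.

    X   : (n_cells, n_genes) NumPy array or SciPy sparse matrix of raw counts.
    obs : pandas DataFrame (e.g. adata.obs) or any mapping from column
          name to per-cell labels; the column is chosen by `batch_key`.
    Assumes every cell has at least one count (log(0) would give -inf).
    """
    batches = np.asarray(obs[batch_key])
    # Library size = total counts per cell; ravel handles sparse (n, 1) sums.
    log_lib = np.log(np.asarray(X.sum(axis=1)).ravel())
    unique_batches = np.unique(batches)  # sorted batch labels
    means = np.array([log_lib[batches == b].mean() for b in unique_batches])
    variances = np.array([log_lib[batches == b].var() for b in unique_batches])
    # scVI-style shape: one row, one column per batch.
    return means.reshape(1, -1), variances.reshape(1, -1)
```

With AnnData this would be called as `library_log_means_and_vars(adata.X, adata.obs, batch_key=...)`, passing the same key given to setup_anndata().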
Hi contrastiveVI team!
I just wanted to let you know that we will soon release a version of scvi-tools that breaks your code. Fortunately, it doesn't look hard to fix. You can read more here:
https://docs.scvi-tools.org/en/0.15.0-beta.0/release_notes/v0.15.0.html
Please feel free to reach out with any questions on our Discourse forum. As a first step, you may consider pinning scvi-tools to <0.15.0, and then making the changes.
Dear contrastiveVI dev team,
First of all thanks a lot for developing this cool model and congrats on the Nature Methods paper on it!
I was trying to apply it to a MIBI dataset harboring different drug treatment conditions, but unfortunately ran into an issue I can't seem to figure out myself.
I ran the following code (adapted from the Alzheimer's example), where treated_control is an AnnData object containing my single-cell data, with "Drug" as the condition column.
# imports
import numpy as np
from contrastive_vi.model import ContrastiveVI
from pytorch_lightning.utilities.seed import seed_everything
seed_everything(42)  # For reproducibility
treated_control = treated_control.copy()
ContrastiveVI.setup_anndata(treated_control)  # setup adata for use with this model
model = ContrastiveVI(
    treated_control,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False,
)
# Control cells form the background; all drug-treated cells are the target.
background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
model.train(
    check_val_every_n_epoch=1,
    train_size=0.8,
    background_indices=background_indices,
    target_indices=target_indices,
    use_gpu=False,
    early_stopping=True,
    max_epochs=500,
)
Running model.train(), I get the following error:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[20], line 4
1 background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
2 target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
----> 4 model.train(
5 check_val_every_n_epoch=1,
6 train_size=0.8,
7 background_indices=background_indices,
8 target_indices=target_indices,
9 use_gpu=False,
10 early_stopping=True,
11 max_epochs=500,
12 )
File ~\Anaconda3\envs\ST0036\Lib\site-packages\contrastive_vi\model\base\training_mixin.py:88, in ContrastiveTrainingMixin.train(self, background_indices, target_indices, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
77 trainer_kwargs[es] = (
78 early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
79 )
80 runner = TrainRunner(
81 self,
82 training_plan=training_plan,
(...)
86 **trainer_kwargs,
87 )
---> 88 return runner()
File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainrunner.py:74, in TrainRunner.__call__(self)
71 if hasattr(self.data_splitter, "n_val"):
72 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
75 self._update_history()
77 # data splitter only gets these attrs after fit
File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
180 if isinstance(args[0], PyroTrainingPlan):
181 warnings.filterwarnings(
182 action="ignore",
183 category=UserWarning,
184 message="`LightningModule.configure_optimizers` returned `None`",
185 )
--> 186 super().fit(*args, **kwargs)
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
735 rank_zero_deprecation(
736 "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
737 " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
738 )
739 train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
741 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
742 )
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
675 r"""
676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
677 as all errors should funnel through them
(...)
682 **kwargs: keyword arguments to be passed to `trainer_fn`
683 """
684 try:
--> 685 return trainer_fn(*args, **kwargs)
686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
687 except KeyboardInterrupt as exception:
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
775 # TODO: ckpt_path only in v1.7
776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
779 assert self.state.stopped
780 self.training = False
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1138, in Trainer._run(self, model, ckpt_path)
1136 self.call_hook("on_before_accelerator_backend_setup")
1137 self.accelerator.setup_environment()
-> 1138 self._call_setup_hook() # allow user to setup lightning_module in accelerator environment
1140 # check if we should delay restoring checkpoint till later
1141 if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1438, in Trainer._call_setup_hook(self)
1435 self.training_type_plugin.barrier("pre_setup")
1437 if self.datamodule is not None:
-> 1438 self.datamodule.setup(stage=fn)
1439 self.call_hook("setup", stage=fn)
1441 self.training_type_plugin.barrier("post_setup")
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\core\datamodule.py:461, in LightningDataModule._track_data_hook_calls.<locals>.wrapped_fn(*args, **kwargs)
459 else:
460 attr = f"_has_{name}_{stage}"
--> 461 has_run = getattr(obj, attr)
462 setattr(obj, attr, True)
464 elif name == "prepare_data":
AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'
I ran the package in a fresh conda environment.
Any ideas where the issue may lie?
Thanks a ton for your help!
Best regards,
Sven
Good morning,
I'd like to try your tool. I don't have much experience with scVI and mainly work with the SingleCellExperiment class in Bioconductor, so it would be great to have your insights on how best to run it.
1) Multiple conditions:
The dataset I have is not unlike the MIX-Seq one, but I'm dealing with multiple conditions, not two. Would you run them all together, or separately pairwise (control vs. treatment 1, then control vs. treatment 2, etc.)?
2) Raw or normalized counts:
Does the tool require raw or normalized counts? (I assume raw, but wanted to make sure.)
3) Model parameters:
How do parameters such as the number of latent layers, n_hidden and batch size affect the output? I tried running my dataset with the parameters below (all conditions vs. DMSO). Training finished, but the UMAP of the salient representation was one big blob.
import numpy as np
from contrastive_vi.model import ContrastiveVI
## Initialize raw counts
adata.raw = adata
adata.layers["counts"] = adata.X.toarray()  # keep raw counts for scdef
ContrastiveVI.setup_anndata(adata)  # register adata with the model, as in the tutorial
model = ContrastiveVI(
    adata,
    n_batch=0,  # no batch correction
    n_layers=1,
    n_hidden=128,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False,
)
background_indices = np.where(adata.obs["Treatment"] == "DMSO")[0]
target_indices = np.where(adata.obs["Treatment"] != "DMSO")[0]
model.train(
    check_val_every_n_epoch=1,
    train_size=0.8,  # 0 to 1
    background_indices=background_indices,
    target_indices=target_indices,
    use_gpu=True,
    early_stopping=True,
    max_epochs=1000,
    batch_size=512,
)
## Convert things back into R to continue with SCE class
reducedDim(sce, "salient") <- scd$obsm[["salient_rep"]] # assign as dimensional reduction to sce
set.seed(100)
sce <- runUMAP(sce, dimred = "salient", n_neighbors = 5, name = paste0("UMAP_salient"), BPPARAM = mcparam) # assuming euclidean distance works as a metric
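Question 1 does not settle whether pairwise runs are preferable. If they are tried, though, the index arrays for each "control vs. treatment" pair can be built from the same condition column. This is a minimal sketch: `pairwise_index_splits` is a hypothetical name. Each pair would then be passed to a separate model.train() call as `background_indices` and `target_indices`, assuming the model only uses the cells listed in those arrays.

```python
import numpy as np


def pairwise_index_splits(conditions, control="DMSO"):
    """Hypothetical helper: map each non-control condition to
    (background_indices, target_indices) into the full cell array."""
    conditions = np.asarray(conditions)
    background = np.where(conditions == control)[0]  # shared control cells
    splits = {}
    for cond in np.unique(conditions):
        if cond == control:
            continue
        target = np.where(conditions == cond)[0]  # cells of one treatment
        splits[str(cond)] = (background, target)
    return splits
```

For example, `pairwise_index_splits(adata.obs["Treatment"], control="DMSO")` returns one entry per treatment. Each entry reuses the same DMSO background, so the pairwise models differ only in their target cells.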
Many thanks : )
I'm very interested in using your elegant approach and started with the tutorial. Everything works except for these lines:
top_genes_per_cell_type[cell_type] = results_tmp.index[:num_top_genes]
de_results[cell_type] = set(results_tmp.index)
I get these types of errors:
NameError: name 'top_genes_per_cell_type' is not defined
Any suggestions?
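A NameError on `top_genes_per_cell_type[cell_type] = ...` means the dictionary was never created before the loop assigned into it. Both dictionaries need to exist first. This is a minimal sketch of that pattern. The tutorial's DE call is not shown here, so `rank_genes_for` is a hypothetical stand-in that returns gene names ordered by significance. A plain list replaces the tutorial's `results_tmp.index`.

```python
def rank_genes_for(cell_type, de_tables):
    """Hypothetical stand-in for the tutorial's per-cell-type DE step."""
    return de_tables[cell_type]


def collect_de_results(cell_types, de_tables, num_top_genes=2):
    # Create both dictionaries *before* the loop; subscript-assigning into
    # an undefined name is what raises the NameError.
    top_genes_per_cell_type = {}
    de_results = {}
    for cell_type in cell_types:
        results_tmp = rank_genes_for(cell_type, de_tables)
        top_genes_per_cell_type[cell_type] = results_tmp[:num_top_genes]
        de_results[cell_type] = set(results_tmp)
    return top_genes_per_cell_type, de_results
```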
Hi,
I have a question about downstream analysis of contrastiveVI results. I believed that the salient latent variables were set to zero for the background (normal) cells. However, when I generated a UMAP plot with all cells and samples, including normal and case samples, the plot looked very reasonable: normal cells clustered tightly together rather than mixing with case cells.
So my question is: after model training, is it impossible, or simply uninformative, to run downstream analyses such as neighbor connectivities, dimensionality reduction, DEG analysis or others on the salient space for all cells? If so, is it correct that the salient space, and the functions related to salient results, should only be used for the case (non-normal) cells after extracting them from the original data?
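Whether the salient space is informative for background cells is the open question here. If only the case cells are to be analyzed, the extraction step is a row subset. This is a minimal sketch, assuming `salient` is a NumPy array of salient embeddings in the same cell order as the condition labels. The function name and the "normal" label are hypothetical.

```python
import numpy as np


def salient_for_targets(salient, conditions, background_label):
    """Keep only target (non-background) rows of a salient embedding.

    salient    : (n_cells, n_salient_latent) array.
    conditions : one label per cell, same order as `salient`.
    Returns the subset embedding and the integer indices of kept cells.
    """
    conditions = np.asarray(conditions)
    is_target = conditions != background_label
    return salient[is_target], np.flatnonzero(is_target)
```

The returned indices can subset the original AnnData (e.g. `adata[idx]`) before building the neighbor graph or UMAP on the case cells only.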
It would be great if you could double-check my thoughts or correct me if I have misunderstood some points.
Thank you!
Hey ContrastiveVI team!
I'm super excited to try out your method! I ran ContrastiveVI and saved the model, assuming it would have the same saving and loading capabilities as scVI:
model.save(out_dir + "my_model/")
This ran without any error. But when I try to load it with:
model = scvi.model.SCVI.load(out_dir + 'n_latent.10/my_model', adata_cv)
I get the following error:
INFO File /data/peer/chanj3/HTA.fresh_plasticity.SCLC.120122/out.SCLC.ContrastiveVI.120122/n_latent.10/my_model/model.pt already downloaded
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[28], line 1
----> 1 model = scvi.model.SCVI.load(out_dir + 'n_latent.10/my_model', adata_cv)
File ~/anaconda3/envs/multiome_py_r/lib/python3.10/site-packages/scvi/model/base/_base_model.py:598, in BaseModelClass.load(cls, dir_path, adata, use_gpu, prefix, backup_url)
596 registry = attr_dict.pop("registry_")
597 if _MODEL_NAME_KEY in registry and registry[_MODEL_NAME_KEY] != cls.__name__:
--> 598 raise ValueError(
599 "It appears you are loading a model from a different class."
600 )
602 if _SETUP_ARGS_KEY not in registry:
603 raise ValueError(
604 "Saved model does not contain original setup inputs. "
605 "Cannot load the original setup."
606 )
ValueError: It appears you are loading a model from a different class.
Is there any way to save and load the ContrastiveVI model? Thanks so much!
Hi, are there any tutorials for totalContrastiveVI? Thanks.