lvm-de-reproducibility's Issues
Some question about reproducing the code in COVID-19 patient data
Hi,
I have some questions about reproducing the results on the COVID-19 patient data. I followed the tutorial (https://github.com/PierreBoyeau/lvm-DE-reproducibility/tree/cb020124611a0e0dcf9393c74a66bae9bc60bfdd/experiments) to install all required packages, and followed the script https://github.com/PierreBoyeau/lvm-DE-reproducibility/blob/cb020124611a0e0dcf9393c74a66bae9bc60bfdd/experiments/run_blish.py to reproduce the results in Figure 5. However, I couldn't train the models (both scVI and scPhere); each one fails with a different error.
When training the scPhere model:
mdl_sph = SCSphereFull(**mdl_sph_kwargs)
mdl_sph, train_sph = load_scvi_model_if_exists(mdl_sph, filename=sph_filename)
trainer_sph = UnsupervisedTrainer(
    model=mdl_sph, gene_dataset=dataset, **trainer_sph_kwargs
)
if train_sph:
    lr_sph = trainer_sph.find_lr()
    logging.info("Using learning rate {}".format(lr_sph))
    mdl_sph = SCSphereFull(**mdl_sph_kwargs)
    trainer_sph = UnsupervisedTrainer(
        model=mdl_sph, gene_dataset=dataset, **trainer_sph_kwargs
    )
    trainer_sph.train(n_epochs=N_EPOCHS, lr=lr_sph)
mdl_sph.eval()
Here is the error:
INFO:root:Using Deep architecture ...
INFO:root:Using deep architecture
/home/icb/weixu.wang/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
INFO:root:Unique optim
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_10598/1878467605.py in <module>
5 )
6 if train_sph:
----> 7 lr_sph = trainer_sph.find_lr()
8 logging.info("Using learning rate {}".format(lr_sph))
9 mdl_sph = SCSphereFull(**mdl_sph_kwargs)
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/inference/trainer.py in find_lr(self, lr, eps)
357 for tensors_list in self.data_loaders_loop():
358 # loss = self.loss(*tensors_list)
--> 359 loss = self.iter_step(tensors_list)
360
361 if torch.isnan(loss).any():
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/inference/trainer.py in iter_step(self, tensors_list)
165 def iter_step(self, tensors_list):
166 if self.optimizer is not None:
--> 167 loss = self.loss(*tensors_list)
168 self.optimizer.zero_grad()
169 loss.backward()
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/inference/inference.py in loss(self, tensors)
98 n_samples=self.k,
99 train_library=self.train_library,
--> 100 beta=self.beta,
101 )
102 return loss.mean()
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/models/scsphere.py in forward(self, x, local_l_mean, local_l_var, batch_index, y, loss, n_samples, train_library, beta)
738 n_samples=n_samples,
739 train_library=train_library,
--> 740 reparam=True,
741 )
742 px_rate = outputs["px_rate"]
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/models/scsphere.py in inference(self, x, batch_index, y, reparam, n_samples, train_library)
572
573 x_ = torch.log1p(x).float()
--> 574 z_post = self.z_encoder(x_, n_samples=n_samples, reparam=reparam)
575 l_post = self.l_encoder(x_, n_samples=n_samples, reparam=reparam)
576 qz_m = z_post["qz_m"]
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/models/scsphere.py in forward(self, input, n_samples, reparam)
74 # else:
75 # z = z_dist.sample(n_samples)
---> 76 z_dist = PowerSpherical(z_mu, z_std.squeeze(-1))
77 if reparam:
78 z = z_dist.rsample((n_samples,))
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/power_spherical-0.1.0-py3.7.egg/power_spherical/distributions.py in __init__(self, loc, scale, validate_args)
154 _JointTSDistribution(
155 MarginalTDistribution(
--> 156 loc.shape[-1], scale, validate_args=validate_args
157 ),
158 HypersphericalUniform(
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/power_spherical-0.1.0-py3.7.egg/power_spherical/distributions.py in __init__(self, dim, scale, validate_args)
92 (dim - 1) / 2 + scale, (dim - 1) / 2, validate_args=validate_args
93 ),
---> 94 transforms=torch.distributions.AffineTransform(loc=-1, scale=2),
95 )
96 self.dim, self.scale = dim, scale
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/distributions/transformed_distribution.py in __init__(self, base_distribution, transforms, validate_args)
78 batch_shape = shape[:cut]
79 event_shape = shape[cut:]
---> 80 super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)
81
82 def expand(self, batch_shape, _instance=None):
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
49 if constraints.is_dependent(constraint):
50 continue # skip constraints that cannot be checked
---> 51 if param not in self.__dict__ and isinstance(getattr(type(self), param), lazy_property):
52 continue # skip checking lazily-constructed args
53 value = getattr(self, param)
AttributeError: type object 'MarginalTDistribution' has no attribute 'dim'
Meanwhile, when I run the scVI model, a different error pops up:
mdl_iw = VAE(n_input=dataset.nb_genes, n_batch=dataset.n_batches, **mdl_iw_kwargs)
mdl_iw, train_iw = load_scvi_model_if_exists(mdl_iw, filename=iw_filename)
trainer_iw = UnsupervisedTrainer(
    model=mdl_iw, gene_dataset=dataset, **trainer_iw_kwargs
)
if train_iw:
    lr_iw = trainer_iw.find_lr()
    mdl_iw = VAE(n_input=dataset.nb_genes, n_batch=dataset.n_batches, **mdl_iw_kwargs)
    trainer_iw = UnsupervisedTrainer(
        model=mdl_iw, gene_dataset=dataset, **trainer_iw_kwargs
    )
    logging.info("Using learning rate {}".format(lr_iw))
    trainer_iw.train(n_epochs=250, lr=lr_iw)
mdl_iw.eval()
The error:
INFO:root:Normal parameterization of the library
INFO:root:Scale decoder with Softmax normalization
/home/icb/weixu.wang/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/utils/data/dataloader.py:557: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
INFO:root:Unique optim
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_13302/2510978555.py in <module>
5 )
6 if train_iw:
----> 7 lr_iw = trainer_iw.find_lr()
8 mdl_iw = VAE(n_input=dataset.nb_genes, n_batch=dataset.n_batches, **mdl_iw_kwargs)
9 trainer_iw = UnsupervisedTrainer(
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/inference/trainer.py in find_lr(self, lr, eps)
357 for tensors_list in self.data_loaders_loop():
358 # loss = self.loss(*tensors_list)
--> 359 loss = self.iter_step(tensors_list)
360
361 if torch.isnan(loss).any():
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/inference/trainer.py in iter_step(self, tensors_list)
165 def iter_step(self, tensors_list):
166 if self.optimizer is not None:
--> 167 loss = self.loss(*tensors_list)
168 self.optimizer.zero_grad()
169 loss.backward()
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/inference/inference.py in loss(self, tensors)
98 n_samples=self.k,
99 train_library=self.train_library,
--> 100 beta=self.beta,
101 )
102 return loss.mean()
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/models/vae.py in forward(self, x, local_l_mean, local_l_var, batch_index, y, loss, n_samples, train_library, beta, multiplicative_std)
439 train_library=train_library,
440 reparam=True,
--> 441 multiplicative_std=multiplicative_std,
442 )
443 px_rate = outputs["px_rate"]
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/models/vae.py in inference(self, x, batch_index, y, reparam, n_samples, train_library, multiplicative_std)
259 )
260
--> 261 z_post = self.z_encoder(x_, y, n_samples=n_samples, reparam=reparam, multiplicative_std=multiplicative_std)
262 log_qz_x_detach = Normal(z_post["q_m"].detach(), z_post["q_v"].sqrt().detach()).log_prob(z_post["latent"]).sum(-1)
263 z_variables = dict(
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1188 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1189 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190 return forward_call(*input, **kwargs)
1191 # Do not call functions when jit is used
1192 full_backward_hooks, non_full_backward_hooks = [], []
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/scvi-0.3.0-py3.7.egg/scvi/models/modules.py in forward(self, x, n_samples, reparam, squeeze, multiplicative_std, *cat_list)
513 if multiplicative_std is not None:
514 q_v = (multiplicative_std ** 2) * q_v
--> 515 dist = Normal(q_m, q_v.sqrt())
516 # dist = Normal(q_m, q_v)
517 # latent = self.reparameterize(q_m, q_v, reparam=reparam)
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
54 else:
55 batch_shape = self.loc.size()
---> 56 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
57
58 def expand(self, batch_shape, _instance=None):
~/miniconda3/envs/ref_dif_testing/lib/python3.7/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
55 if not valid.all():
56 raise ValueError(
---> 57 f"Expected parameter {param} "
58 f"({type(value).__name__} of shape {tuple(value.shape)}) "
59 f"of distribution {repr(self)} "
ValueError: Expected parameter loc (Tensor of shape (25, 128, 10)) of distribution Normal(loc: torch.Size([25, 128, 10]), scale: torch.Size([25, 128, 10])) to satisfy the constraint Real(), but found invalid values:
tensor([[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
...,
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
grad_fn=<ExpandBackward0>)
I think the installation is fine, as all packages appear to run.
Do you have any suggestions for resolving this?
Best Regards,
Ken
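A possible reading of the first traceback, offered as a sketch rather than a confirmed diagnosis: the failing line in torch/distributions/distribution.py looks each constrained parameter up on the class when it is not yet an instance attribute, and MarginalTDistribution assigns self.dim only after the parent initialiser has been called (line 96 in the traceback, after the call that ends at line 95). The toy classes below are hypothetical stand-ins written in plain Python. They are not the torch or power_spherical classes, and the contents of arg_constraints are assumed. They reproduce the same kind of AttributeError and show two ways the toy version avoids it.

```python
class lazy_property:
    """Stand-in for torch's lazy_property marker (illustration only)."""


class ToyDistribution:
    """Toy base class mimicking the validation loop shown in the traceback."""

    arg_constraints = {}

    def __init__(self, validate_args=True):
        if validate_args:
            for param in self.arg_constraints:
                # Same pattern as the failing line: if the parameter is not yet
                # an instance attribute, it is looked up on the class itself.
                if param not in self.__dict__ and isinstance(
                    getattr(type(self), param), lazy_property
                ):
                    continue
                getattr(self, param)  # the real code would validate this value


class BrokenMarginal(ToyDistribution):
    """Calls the parent initialiser before assigning its own attributes."""

    arg_constraints = {"dim": None, "scale": None}  # assumed contents

    def __init__(self, dim, scale, validate_args=True):
        super().__init__(validate_args=validate_args)
        self.dim, self.scale = dim, scale


class ReorderedMarginal(ToyDistribution):
    """Assigns its attributes first, so the class-level lookup is never reached."""

    arg_constraints = {"dim": None, "scale": None}  # assumed contents

    def __init__(self, dim, scale, validate_args=True):
        self.dim, self.scale = dim, scale
        super().__init__(validate_args=validate_args)


def try_build(cls, **kwargs):
    """Return 'ok' if construction succeeds, else the AttributeError message."""
    try:
        cls(dim=10, scale=1.0, **kwargs)
        return "ok"
    except AttributeError as err:
        return str(err)


if __name__ == "__main__":
    print(try_build(BrokenMarginal))
    print(try_build(ReorderedMarginal))
    print(try_build(BrokenMarginal, validate_args=False))
```

In the toy, either assigning the attributes before the parent call or skipping validation avoids the error. Whether one of these, or installing the torch release the repository was developed against, is the right fix for the actual packages is not something this thread confirms.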
"use_observed_batches" in "model.differential_expression" depends on batch_correction
Dear all, dear Dr. Boyeau,
Thank you for the great toolkit, it is exactly what I need to analyse my data.
I am interested in lvm-DE because it not only uses the scVI models I have already trained, but I also believe that the Bayesian approach with optional permutation is well suited to the heterogeneity of my data.
Specifically, I am looking to compare the expression of two groups within a cell type, a scenario covered in the respective publication.
As I have several patients per group, I typically use the batch_correction option.
However, I tried to use use_observed_batches=True to condition the expression per cell on the respective batch/sample, so that I would get a sample-based comparison of expression values.
This gave me the error:
TypeError: scvi.model.base._differential.DifferentialComputation.get_bayes_factors() got multiple values for keyword argument 'use_observed_batches'
Digging a bit into the source code, I found that in scvi-tools/scvi/model/base/_utils.py the _de_core function calls DifferentialComputation.get_bayes_factors (line 250).
In this function call, the option is set as use_observed_batches=not batch_correction.
So, if I understand correctly, whenever batch_correction is True, I cannot condition on batches, and the other way around.
This does seem deliberate, so my question is: how do batch_correction and conditioning on batches relate to each other?
Why does one exclude the other, or is even the opposite of the other?
Is there a way to not condition on batch, and not batch_correct at the same time?
I am trying to understand the background, as I am relatively new to the field and will publish the results, for which I would like to have at least a basic understanding.
Thank you for any help.
Best,
Max
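The TypeError quoted above can be illustrated with a minimal sketch, assuming the wrapper both sets use_observed_batches itself and forwards the caller's extra keyword arguments to the inner function. The two functions below are hypothetical stand-ins in plain Python, not the scvi-tools implementations. They only show why passing the flag explicitly collides with the value the wrapper derives from batch_correction.

```python
def get_bayes_factors(idx1, idx2, use_observed_batches=False, **kwargs):
    """Hypothetical stand-in for the inner call; it only echoes its flag."""
    return {"use_observed_batches": use_observed_batches}


def de_core(idx1, idx2, batch_correction=False, **kwargs):
    """Hypothetical stand-in for the wrapper described in the issue."""
    # The wrapper derives the flag from batch_correction and also forwards
    # any extra keyword arguments supplied by the caller.
    return get_bayes_factors(
        idx1, idx2, use_observed_batches=not batch_correction, **kwargs
    )


def call_and_report(**kwargs):
    """Run the wrapper and return either its result or the error text."""
    try:
        return de_core([0, 1], [2, 3], **kwargs)
    except TypeError as err:
        return "TypeError: {}".format(err)


if __name__ == "__main__":
    # The flag follows batch_correction when it is not passed explicitly.
    print(call_and_report(batch_correction=True))
    print(call_and_report(batch_correction=False))
    # Passing it explicitly supplies the same keyword twice.
    print(call_and_report(batch_correction=True, use_observed_batches=True))
```

Under this assumption the flag is never independently settable through the wrapper: it is always the negation of batch_correction, which matches the behaviour described in the post.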
Question about equation (10)
Hi Pierre,
In equation (10) of your paper, you need to calculate the probability of LFC > delta and LFC < -delta. For two defined clusters A and B, we can estimate the denoised means of A and B as hA and hB with the decoder, as described in the paper. How can we obtain the probability of LFC > delta or LFC < -delta, since we only have one estimated mean each for A and B? Thanks.
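One way such probabilities are commonly approximated, sketched here under the assumption that the decoder can be evaluated on many posterior samples of the latent variables (so that there are many draws of hA and hB rather than a single estimate each): compute the LFC for every pair of draws and take the fraction exceeding the threshold. The function name, the small offset eps, and the synthetic gamma draws are all hypothetical, and this is not confirmed as the paper's exact procedure.

```python
import numpy as np


def lfc_tail_probabilities(h_a, h_b, delta=0.5, eps=1e-8):
    """Monte Carlo estimates of P(LFC > delta) and P(LFC < -delta).

    h_a, h_b: 1-D arrays of paired draws of the denoised means for A and B.
    eps: small offset (an assumption) to avoid taking the log of zero.
    """
    h_a = np.asarray(h_a, dtype=float)
    h_b = np.asarray(h_b, dtype=float)
    lfc = np.log2(h_a + eps) - np.log2(h_b + eps)
    p_up = float(np.mean(lfc > delta))
    p_down = float(np.mean(lfc < -delta))
    return p_up, p_down


if __name__ == "__main__":
    # Synthetic draws standing in for decoder outputs on posterior samples.
    rng = np.random.default_rng(0)
    h_a = rng.gamma(shape=20.0, scale=0.2, size=5000)
    h_b = rng.gamma(shape=20.0, scale=0.1, size=5000)
    p_up, p_down = lfc_tail_probabilities(h_a, h_b, delta=0.5)
    print(p_up, p_down, p_up + p_down)
```

The sum of the two estimates then plays the role of the probability that the absolute LFC exceeds delta.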