Comments (11)
IBM has a pretty good example
https://github.com/IBM/pytorch-seq2seq/blob/master/seq2seq/models/DecoderRNN.py#L108-L164
from moses.
@lilleswing, the code from MOSES you've sent implements teacher forcing. Did you mean that we should add free running for training?
Yes, I misread the code.
It is missing annealing of the teacher forcing ratio (though that was not a component of the initial paper). The initial paper did always use teacher forcing during training and free running during sampling. Annealing would be an improvement over the paper's implementation.
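For reference, annealed teacher forcing (scheduled sampling) can be sketched as below. This is an illustrative toy, not the MOSES API: `decode`, `step_fn`, and the token representation are hypothetical names. With `teacher_forcing_ratio=1.0` you get the always-teacher-forced training described in this thread; with `0.0` you get free running; annealing the ratio toward zero over epochs gives the improvement mentioned above.

```python
import random


def decode(step_fn, start_token, targets, teacher_forcing_ratio):
    """Greedy decode loop: at each step, feed either the ground-truth token
    (teacher forcing) or the model's own previous prediction (free running),
    chosen at random with probability teacher_forcing_ratio."""
    inp, outputs = start_token, []
    for target in targets:
        pred = step_fn(inp)  # model predicts the next token from the input
        outputs.append(pred)
        use_teacher = random.random() < teacher_forcing_ratio
        inp = target if use_teacher else pred
    return outputs
```

Annealing then just means lowering `teacher_forcing_ratio` (e.g., linearly per epoch) so the decoder gradually learns to consume its own predictions.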
Yes, we’ll add free running soon. It will probably be listed as a separate model in the metrics table.
Have you ever tested the reconstruction accuracy of the VAE model? I tested it and the performance is very poor. Here is my testing code; is there any problem? Thanks!
```python
import random

import pandas as pd
import torch
from tqdm import tqdm

from moses.vae import VAE


def read_smiles_csv(path):
    return pd.read_csv(path, usecols=['SMILES'], squeeze=True).astype(str).tolist()


if __name__ == '__main__':
    parser = get_parser()  # get_parser comes from the MOSES sampling script
    config = parser.parse_known_args()[0]
    device = torch.device(config.device)
    if device.type.startswith('cuda'):
        torch.cuda.set_device(device.index or 0)

    # Restore the trained model
    model_config = torch.load(config.config_save)
    model_vocab = torch.load(config.vocab_save)
    model_state = torch.load(config.model_save)
    model = VAE(model_vocab, model_config)
    model.load_state_dict(model_state)
    model = model.to(device)
    model.eval()

    test_data = random.sample(read_smiles_csv('train.csv'), 100)
    NUM_DEC = 500  # decode each latent code 500 times
    num = 0
    for ech in tqdm(test_data):
        tensors = [model.string2tensor(ech.strip(), device=device)]
        z_vecs, _ = model.forward_encoder(tensors)
        res_lst = []
        for _ in range(NUM_DEC):
            res = model.sample(n_batch=z_vecs.size(0), z=z_vecs)
            res_lst.extend(res)
        # Count the molecule as reconstructed if any decode matches it
        if ech in res_lst:
            num += 1
    print("recons num:", num)
    print("reconstruct acc:", num / len(test_data))
```
Hi, @liujunhongznn!
Low reconstruction quality is due to posterior collapse, which frequently happens in VAEs. Since the goal of MOSES is to model the generative distribution as well as possible, posterior collapse is acceptable for this task. If you want to obtain meaningful latent codes, try reducing the KL divergence weight.
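For context, the KL weight mentioned above enters the VAE objective as a simple multiplier on the KL term (a beta-VAE-style sketch; `kl_weight` and `kl_weight_schedule` are illustrative names here, not the actual MOSES config parameters):

```python
def vae_loss(recon_nll, kl, kl_weight=0.05):
    """ELBO-style loss: reconstruction NLL plus a weighted KL term.
    A smaller kl_weight penalizes the KL term less, which lets the
    encoder keep information in z and mitigates posterior collapse."""
    return recon_nll + kl_weight * kl


def kl_weight_schedule(epoch, start=0.0, end=0.05, anneal_epochs=10):
    """Linear KL annealing: warm the weight up from start to end, a
    common trick to delay posterior collapse early in training."""
    if epoch >= anneal_epochs:
        return end
    return start + (end - start) * epoch / anneal_epochs
```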
@danpol Hello! Can you help me with the VAE? I'm mixed up. As you mentioned before, this VAE implementation uses the teacher forcing approach, but I don't see any loops with the decoder (except in eval mode, for generating SMILES). Am I right that it's literally training with teacher forcing = 1, since we don't pass previously predicted tokens (like in seq2seq models)?
Hi, @bokertof! VAE in MOSES uses teacher forcing—we pass the correct token, not the sampled one.
@danpol OK, I got it. Can you tell me the reason for not using the sampled tokens as input? I'm trying to implement a similar net and ran into an issue where the model fed with previously predicted tokens doesn't learn at all.
If you feed sampled tokens, you have to propagate the gradient through the sampling operation (e.g., with REINFORCE), which has notoriously high variance. You could use variance reduction techniques, but that would take the model far from the notion of a "baseline".
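To illustrate the variance issue, here is a minimal score-function (REINFORCE) estimator for a Bernoulli variable; it is a self-contained toy sketch, unrelated to the MOSES codebase:

```python
import random


def reinforce_grad(f, p, n=10000, seed=0):
    """Estimate d/dp E_{x ~ Bernoulli(p)}[f(x)] with the score-function
    (REINFORCE) trick: average f(x) * d/dp log P(x; p) over samples."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = 1 if rng.random() < p else 0
        score = 1.0 / p if x == 1 else -1.0 / (1.0 - p)  # d/dp log P(x; p)
        total += f(x) * score
    return total / n
```

The true gradient is f(1) - f(0). With f(x) = x and p = 0.5 the average converges to 1.0, but each individual sample contributes either 2.0 or 0.0, which is exactly the high variance danpol refers to.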
Thank you so much!