Comments (3)
Hi,
- "we only compute the loss w.r.t y0" refers to the mse loss, the x part is transformed to the regularization term.
- "nll" loss is for tracing the generation quality
- you can regard the "decoder_nll" as a regularization term of the embedding vectors.
- we didn't try this setting, you're free to try it.
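For intuition, the decomposition described above can be sketched as a sum of per-example terms. This is a toy illustration only, not the repository's code: the shapes, the `mean_flat` helper, and the way the terms are combined are assumptions based on this thread.

```python
import torch

def mean_flat(tensor):
    # Average over every dimension except the batch dimension.
    return tensor.mean(dim=list(range(1, tensor.dim())))

# Toy shapes: batch of 2, sequence length 4, embedding size 8.
x_start = torch.randn(2, 4, 8)
model_output = torch.randn(2, 4, 8)

mse = mean_flat((x_start - model_output) ** 2)  # reconstruction term on the embeddings
decoder_nll = torch.zeros(2)  # placeholder: rounding NLL, read here as an embedding regularizer
tT_loss = torch.zeros(2)      # placeholder: prior term at the final timestep

loss = (mse + decoder_nll + tT_loss).mean()  # scalar training loss
```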
from diffuseq.
I see. Thank you!
According to my understanding, the calculation of the MSE loss also involves the x part:
target = x_start
model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
terms["mse"] = mean_flat((target - model_output) ** 2)  # squared error over all positions, no mask
model_out_x_start = self._x0_helper(model_output, x_t, t)['pred_xstart'] # predicted_xstart = model_output
t0_mask = (t == 0)
t0_loss = mean_flat((x_start_mean - model_out_x_start) ** 2)
terms["mse"] = th.where(t0_mask, t0_loss, terms["mse"])  # substitute the t = 0 loss where t == 0
Since predict_xstart in the config.json is True, model_output is actually the estimated x_start. You directly calculate the MSE loss between x_start and model_output without applying input_mask, so the x part is also included in the MSE loss. I printed the result with the following code:
print((target[0] - model_output[0]) ** 2, input_ids_mask[0]) # print the first sentence of one batch
The output is:
tensor([[2.0462e+00, 3.8795e-01, 3.2121e-03, ..., 2.4803e-01, 1.7676e-01,
4.3906e-01],
[4.9620e+00, 5.4831e+00, 1.3603e+00, ..., 4.6070e+00, 3.6652e+00,
4.0038e-01],
[3.2369e-03, 8.0437e-01, 3.5606e-01, ..., 1.2198e-01, 5.0738e-01,
6.2278e-03],
...,
[3.9527e-01, 1.4870e+00, 6.3621e+00, ..., 2.2396e-02, 6.8237e-03,
1.6679e-01],
[6.9479e-01, 3.2860e-02, 7.0464e+00, ..., 1.8594e-01, 3.4218e-01,
7.5128e-03],
[1.6733e+00, 4.8956e-01, 5.6478e+00, ..., 3.3924e-02, 1.5616e-02,
1.0451e-01]], device='cuda:0', grad_fn=<PowBackward0>) tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
We can see that the loss on the x part is nonzero, which contradicts the claim in your paper. If I have misunderstood anything, I hope someone can correct me. Thanks a lot!
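If the intent is to exclude the x part, one way to restrict the squared error to the y positions is a masked MSE. This is a sketch only, not the repository's code: the function name and shapes are assumptions, based on the printout above where the mask appears to be 0 on the x part and 1 on the y part.

```python
import torch

def masked_mse(target, model_output, input_ids_mask):
    # Hypothetical helper: keep the squared error only where the mask is 1
    # (the y part) and average over those positions alone.
    mask = input_ids_mask.unsqueeze(-1).expand_as(target).float()
    sq_err = (target - model_output) ** 2 * mask
    return sq_err.sum(dim=(1, 2)) / mask.sum(dim=(1, 2)).clamp(min=1.0)

# Toy check: 1 sequence of length 4, embedding size 2; the first half is x.
target = torch.zeros(1, 4, 2)
model_output = torch.ones(1, 4, 2)
mask = torch.tensor([[0, 0, 1, 1]])
print(masked_mse(target, model_output, mask))  # tensor([1.])
```

With this variant, perturbations on the x positions contribute nothing to the loss, which would match a strict reading of "we only compute the loss w.r.t y0".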