
songnet's Introduction

SongNet

SongNet: SongCi + Song (Lyrics) + Sonnet + etc.

@inproceedings{li-etal-2020-rigid,
    title = "Rigid Formats Controlled Text Generation",
    author = "Li, Piji and Zhang, Haisong and Liu, Xiaojiang and Shi, Shuming",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.68",
    doi = "10.18653/v1/2020.acl-main.68",
    pages = "742--751"
}

Run

  • python prepare_data.py
  • ./train.sh

Evaluation

  • In test.py, set m_path to the checkpoint that performs best on the dev set
  • ./test.sh
  • python metrics.py
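
The README only says to point m_path at the best dev model. Taking "best" to mean the lowest dev perplexity, a small helper can pick it out of the training log. This sketch assumes the validation line format shown in the training log further down this page (`epoch-E-acm-B nll= ... ppl= ...`) and the checkpoint naming seen in the `epoch10_batch_1499` example. Both are inferred from this page, not from the repository's code, so adjust them to whatever train.sh actually writes.

```python
import re

# Validation summary lines look like:
# "epoch-10-acm-1499 nll= 0.0843... ppl= 1.0603... count= 667.0"
VALID_LINE = re.compile(r"epoch-(\d+)-acm-(\d+)\s+nll=\s*([\d.]+)\s+ppl=\s*([\d.]+)")


def best_dev_checkpoint(log_lines):
    """Return (checkpoint_name, dev_ppl) for the lowest dev perplexity, or None.

    Assumption: checkpoints are named like "epoch10_batch_1499"; change the
    format string below if train.sh names them differently.
    """
    best = None
    for line in log_lines:
        m = VALID_LINE.search(line)
        if m is None:
            continue  # training-progress lines ("batch_acm ...") are ignored
        epoch, batch, ppl = int(m.group(1)), int(m.group(2)), float(m.group(4))
        if best is None or ppl < best[1]:
            best = (f"epoch{epoch}_batch_{batch}", ppl)
    return best


log = [
    "batch_acm 499, loss 1.807, acc 0.452, nll 2.195, ppl 4.706, x_acm 7969, lr 0.000013",
    "epoch-3-acm-499 nll= 1.846888825275015 ppl= 3.7909837519747205 count= 667.0",
    "epoch-6-acm-999 nll= 0.3973974670427314 ppl= 1.3300093074609851 count= 667.0",
]
print(best_dev_checkpoint(log)[0])  # epoch6_batch_999
```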

Polish

  • ./polish.sh

Download

Reference

songnet's People

Contributors

lipiji


songnet's Issues

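Running ./test fails with "IndexError: The shape of the mask [1] at index 0 does not match "

Hello Professor Li, your current code runs for me without any problems, but when I switch to data I collected myself, an error occurs.

The full error is as follows: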

Traceback (most recent call last):
  File "test.py", line 359, in <module>
    res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)
  File "test.py", line 61, in top_k_inc
    incremental_state)
  File "/content/SongNet/biglm.py", line 91, in work_incremental
    incremental_state=incremental_state)
  File "/content/SongNet/transformer.py", line 73, in work_incremental
    attn_mask=self_attn_mask, incremental_state=incremental_state)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/content/SongNet/transformer.py", line 156, in forward
    prev_key = prev_key[bidx]
IndexError: The shape of the mask [1] at index 0 does not match the shape of the indexed tensor [2, 12, 1, 64] at index 0

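When it happens
train, eval, and polish all work fine, but while running test the error appears on certain examples. In other words, some examples are predicted and printed normally and others are not.

Temporary workaround
I added a try/except in test.py to skip the examples that raise the error.

Hoping the bug can be fixed
I read the code around prev_key in transformer.py but couldn't understand the error. If you ran into a similar problem while debugging, could you give some hints on how to fix it?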
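
The temporary workaround described in this issue (wrapping each test example in try/except) can be sketched as follows. `generate_one` is a hypothetical stand-in for the per-example generation step in test.py (for example the `top_k_inc(...)` call in the traceback). The sketch catches only `IndexError`, so unrelated failures still surface, and it records which examples were skipped. It avoids the crash rather than fixing its cause.

```python
def generate_all(examples, generate_one):
    """Run generate_one on each example, skipping those that raise IndexError.

    generate_one is a hypothetical callable standing in for the per-example
    generation step of test.py; it is not a function defined in the repository.
    Returns (results, skipped), where skipped lists (index, error message).
    """
    results, skipped = [], []
    for i, example in enumerate(examples):
        try:
            results.append(generate_one(example))
        except IndexError as err:
            # Keep going, but remember which input failed so it can be inspected.
            skipped.append((i, str(err)))
    return results, skipped


# Usage with a toy generator that fails on one input:
def toy_generate(x):
    if x == "bad":
        raise IndexError("The shape of the mask [1] at index 0 does not match")
    return x.upper()


results, skipped = generate_all(["a", "bad", "c"], toy_generate)
print(results)  # ['A', 'C']
print(skipped)  # [(1, 'The shape of the mask [1] at index 0 does not match')]
```

Logging the skipped indices makes it possible to collect the failing inputs and compare them with the ones that succeed.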

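Mismatch between ys_tpl and xs_tpl at training and generation time

During training, ys_tpl and xs_tpl (which encode format and rhyme information) contain only c0, c2, and c1, but during generation (polish), the corresponding ys_tpl and xs_tpl also contain the characters that are already given, e.g. C={c0,c0,love,c1,,bends,c0,remove,c1,,}.
1) Can the model really recognize the character information in ys_tpl and xs_tpl at generation time, given that it never saw such inputs during training? If only 20% or even fewer of the characters are missing, does the model really still generate well?
2) In the code, every missing character “” is replaced with c1. Doesn't this ignore the case where the missing character is a rhyme position, i.e. where “” should be replaced with c2?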
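
To make question 2 concrete, here is a sketch of building a polish template from a partially filled line. Several details are assumptions, not the repository's code: blanks are written as "_" (as in the polish_tpl.txt example further down this page), the symbols are the plain strings "c1"/"c2" used in the question, and the punctuation set is illustrative. How the repository actually encodes punctuation in the template is not shown here. By default every blank becomes "c1" (the behaviour the question describes); with `rhyme_blanks=True`, a blank in the last slot before punctuation becomes "c2" (the change the question proposes).

```python
# Punctuation set used for the sketch (assumption: the repo's own list may differ).
PUNCS = set(",.!?;:,。!?;:、")
BLANK = "_"  # hypothetical placeholder for a character the model should fill in


def polish_template(line, rhyme_blanks=False):
    """Map a partially filled line to a format/rhyme symbol sequence.

    Given characters and punctuation are copied through unchanged. Each blank
    becomes "c1". With rhyme_blanks=True, a blank sitting in the last slot
    before punctuation (or at the end of the line) becomes "c2" instead.
    """
    tpl = []
    for i, ch in enumerate(line):
        if ch != BLANK:
            tpl.append(ch)  # fixed text is kept as-is in the template
            continue
        nxt = line[i + 1] if i + 1 < len(line) else None
        clause_end = nxt is None or nxt in PUNCS
        tpl.append("c2" if rhyme_blanks and clause_end else "c1")
    return tpl


print(polish_template("__爱_,"))
# ['c1', 'c1', '爱', 'c1', ',']
print(polish_template("__爱_,", rhyme_blanks=True))
# ['c1', 'c1', '爱', 'c2', ',']
```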

How to generate text containing fixed text information?

Hello, as shown in Table 6 of the paper, you mention that "our model has the ability of refining and polishing given the format C which contains some fixed text information". Could you please explain how to do this specifically? 😊 You could reply in Chinese if you like, thanks a lot!

Error when running ./test.sh:

Hi, after training finished I changed m_path in test.py to the path of the latest checkpoint in the output.

But running ./test.sh raised an error:
Traceback (most recent call last):
File "test.py", line 34, in
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt")
File "test.py", line 28, in init_model
lm_model.load_state_dict(ckpt['model'])
File "/usr/local/anaconda3/envs/GPT/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BIGLM:
size mismatch for tok_embed.weight: copying a param with shape torch.Size([6410, 768]) from checkpoint, the shape in current model is torch.Size([28781, 768]).
size mismatch for out_proj.weight: copying a param with shape torch.Size([6410, 768]) from checkpoint, the shape in current model is torch.Size([28781, 768]).
size mismatch for out_proj.bias: copying a param with shape torch.Size([6410]) from checkpoint, the shape in current model is torch.Size([28781]).

What could cause this? Thanks a lot.
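
In this error the mismatched tensors are `tok_embed.weight`, `out_proj.weight`, and `out_proj.bias`, whose first dimension is the vocabulary size. So the checkpoint was trained with one vocabulary (6410 entries) while the model built from the vocab file passed to `init_model` ("./model/vocab.txt") has another (28781 entries). The likely fix is to pass the same vocab.txt that was used to train the checkpoint. The helper below is a hypothetical diagnostic, not part of the repository, that pulls both sizes out of such an error message:

```python
import re

MISMATCH = re.compile(
    r"size mismatch for (\S+): copying a param with shape "
    r"torch\.Size\(\[([\d, ]+)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[([\d, ]+)\]\)"
)


def _shape(text):
    # "6410, 768" -> (6410, 768)
    return tuple(int(x) for x in text.split(","))


def parse_size_mismatches(message):
    """Map each parameter name to (checkpoint_shape, current_model_shape)."""
    return {
        name: (_shape(ckpt), _shape(model))
        for name, ckpt, model in MISMATCH.findall(message)
    }


msg = (
    "size mismatch for tok_embed.weight: copying a param with shape "
    "torch.Size([6410, 768]) from checkpoint, the shape in current model is "
    "torch.Size([28781, 768]).\n"
    "size mismatch for out_proj.bias: copying a param with shape "
    "torch.Size([6410]) from checkpoint, the shape in current model is "
    "torch.Size([28781])."
)
for name, (ckpt, model) in parse_size_mismatches(msg).items():
    print(f"{name}: checkpoint vocab {ckpt[0]}, model vocab {model[0]}")
# tok_embed.weight: checkpoint vocab 6410, model vocab 28781
# out_proj.bias: checkpoint vocab 6410, model vocab 28781
```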

Is the vocab.txt provided with the pre-trained model wrong?

When I try to continue training from the pre-trained model, I get the following error:
RuntimeError: Error(s) in loading state_dict for BIGLM:
size mismatch for tok_embed.weight: copying a param with shape torch.Size([28781, 768]) from checkpoint, the shape in current model is torch.Size([6410, 768]).
My guess is that the uploaded vocab.txt is the SongCi version. How should I handle this?

Running ./train.sh on Google Colab fails; not sure whether the torch version is the cause


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [136, 16, 2304]], which is output 0 of AddBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

version: torch-1.7.0+cu101, python-3.6.9

Professor Li, could you tell us which torch version you used?

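Question about c2 (rhyme)

Hello, the paper says that c2 marks rhyming words, while the code marks the last character of every sentence (excluding punctuation) as rhyming. Is this based on the characteristics of SongCi, or is there another consideration?
Relevant code in dataset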
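
The rule the question describes (tag the last non-punctuation character of each sentence as rhyming) can be sketched as below. This illustrates the described rule only; it is not the repository's dataset code, and the punctuation set is an assumption. Note that the rule is purely positional: it marks every sentence ending, whether or not a particular form actually requires a rhyme there.

```python
# Illustrative punctuation set (assumption; the repository may use a different list).
PUNCS = set(",.!?;:,。!?;:、")


def rhyme_positions(sentences):
    """For each sentence, return the index of its last non-punctuation character.

    Sentences consisting only of punctuation (or empty) yield None.
    """
    positions = []
    for sent in sentences:
        idx = None
        for i in range(len(sent) - 1, -1, -1):  # scan backwards from the end
            if sent[i] not in PUNCS:
                idx = i
                break
        positions.append(idx)
    return positions


print(rhyme_positions(["春风又绿江南岸,", "明月何时照我还。"]))  # [6, 6]
```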

LICENSE

Thanks for sharing the code and model. What's the LICENSE of SongNet?

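Evaluation criteria

The paper lists many criteria, including rhyme- and tpl-related losses, but in the code only nll/ppl seems to actually be backpropagated.
During pre-training, are these format- and rhyme-related losses computed and backpropagated?

Also, the paper mentions generating with beam search, but in the code beam search only appears in test.py. Did you find that top-k generation works better than beam search? Why not use beam search for polish?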
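
For readers comparing the two decoding strategies: beam search keeps the highest-scoring partial sequences at every step and is deterministic, while top-k sampling draws the next token at random from the k most likely candidates, trading some likelihood for variety. The sketch below shows a single top-k step over a plain list of logits. It is a generic illustration, not the repository's `top_k_inc`, and it does not say why the authors chose top-k for polish.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample one token index from the k highest-scoring logits.

    Scores outside the top k are discarded; the rest are turned into softmax
    weights and sampled. k=1 reduces to greedy decoding.
    """
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    weights = [math.exp(logits[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [0.1, 2.0, 0.5, -1.0]
print(top_k_sample(logits, k=1, rng=rng))  # 1 (greedy)
```

`random.choices` normalises the weights itself, so the softmax denominator is not needed.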

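Output of the model's work function is random Chinese characters

Hello!
I have been studying your code recently. For model testing you use the work_incremental function. I looked at your work function, removed incremental_state, and used it instead, but the output is random Chinese characters and no punctuation is shown, as in the screenshot below. What could be going on?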
[Screenshot not preserved: WeChat image 20210330214711]

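Question about pre-training

The paper says pre-training follows BERT and mainly mentions the cloze (masked-token) task. Was next sentence prediction also used?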
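
For reference, the cloze objective in BERT corrupts the input as follows: about 15% of positions are selected, and each selected token is replaced by a mask token 80% of the time, by a random token 10% of the time, and left unchanged 10% of the time; the loss is computed only at selected positions. The sketch implements that recipe with a hypothetical `<mask>` token. Whether SongNet's pre-training uses these exact ratios, or next sentence prediction (a separate sentence-pair classification objective not shown here), is not stated on this page.

```python
import random

MASK = "<mask>"  # hypothetical mask token name


def bert_mask(tokens, vocab, rng, mask_prob=0.15):
    """BERT-style cloze corruption.

    Returns (inputs, targets): inputs is the corrupted sequence; targets holds
    the original token at selected positions and None elsewhere (no loss there).
    """
    inputs, targets = [], []
    for tok in tokens:
        if rng.random() < mask_prob:
            r = rng.random()
            if r < 0.8:
                inputs.append(MASK)  # 80%: replace with the mask token
            elif r < 0.9:
                inputs.append(rng.choice(vocab))  # 10%: random token
            else:
                inputs.append(tok)  # 10%: keep the original token
            targets.append(tok)
        else:
            inputs.append(tok)
            targets.append(None)
    return inputs, targets


rng = random.Random(0)
tokens = list("春风又绿江南岸")
inputs, targets = bert_mask(tokens, vocab=tokens, rng=rng)
```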

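About global attention

E_c+E_p+E_s appear in both H and F. Could H be changed to H = E_w + E_g, making the information in H simpler and easier to disentangle, while the network still would not lose the information in E_c, E_p, and E_s?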
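
To make the proposal concrete, here are the input sums as the question describes them, using the paper's symbols as we understand them (E_w token, E_c format and rhyme symbol, E_p intra-sentence position, E_s sentence segment, E_g global position). The form of F is inferred from the question's remark that E_c+E_p+E_s appear in both H and F, so it may omit terms the paper includes.

```latex
\begin{aligned}
\text{current:}\quad  H_t^0 &= E_{w_t} + E_{c_t} + E_{p_t} + E_{s_t} + E_{g_t},\\
                      F_t^0 &= E_{c_t} + E_{p_t} + E_{s_t},\\
\text{proposed:}\quad H_t^0 &= E_{w_t} + E_{g_t}.
\end{aligned}
```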

Bad results..

Hi
I trained on a small dataset:

/content/SongNet
9
2300 667 2599
7
2
9
vocab
done
vocab.size = 1215
batch_acm 99, loss 5.277, acc 0.102, nll 6.265, ppl 86.928, x_acm 1584, lr 0.000002
batch_acm 199, loss 3.566, acc 0.262, nll 4.273, ppl 20.652, x_acm 3179, lr 0.000005
batch_acm 299, loss 2.664, acc 0.323, nll 3.223, ppl 9.619, x_acm 4774, lr 0.000008
batch_acm 399, loss 2.123, acc 0.396, nll 2.580, ppl 6.136, x_acm 6374, lr 0.000010
batch_acm 499, loss 1.807, acc 0.452, nll 2.195, ppl 4.706, x_acm 7969, lr 0.000013
validating...
epoch-3-acm-499 nll= 1.846888825275015 ppl= 3.7909837519747205 count= 667.0
batch_acm 599, loss 1.580, acc 0.502, nll 1.921, ppl 3.879, x_acm 9564, lr 0.000015
batch_acm 699, loss 1.370, acc 0.561, nll 1.673, ppl 3.265, x_acm 11164, lr 0.000018
batch_acm 799, loss 1.151, acc 0.631, nll 1.418, ppl 2.733, x_acm 12759, lr 0.000020
batch_acm 899, loss 0.934, acc 0.701, nll 1.167, ppl 2.290, x_acm 14354, lr 0.000023
batch_acm 999, loss 0.695, acc 0.787, nll 0.890, ppl 1.882, x_acm 15954, lr 0.000025
validating...
epoch-6-acm-999 nll= 0.3973974670427314 ppl= 1.3300093074609851 count= 667.0
batch_acm 1099, loss 0.494, acc 0.856, nll 0.652, ppl 1.589, x_acm 17549, lr 0.000028
batch_acm 1199, loss 0.332, acc 0.913, nll 0.460, ppl 1.384, x_acm 19144, lr 0.000030
batch_acm 1299, loss 0.223, acc 0.945, nll 0.330, ppl 1.260, x_acm 20739, lr 0.000033
batch_acm 1399, loss 0.157, acc 0.966, nll 0.252, ppl 1.192, x_acm 22339, lr 0.000035
batch_acm 1499, loss 0.117, acc 0.975, nll 0.208, ppl 1.156, x_acm 23934, lr 0.000038
validating...
epoch-10-acm-1499 nll= 0.08431253172289664 ppl= 1.060383505013393 count= 667.0
training time: 453sec.

and the test result is unreadable text after running polish.sh with my checkpoint epoch10_batch_1499 and my vocab.txt.

PS: my edited polish_tpl.txt:

['Gufd<s1>327711<s2>_____,____ менять.______ _____ сейчас. _________ любимый. ______ _____ много.']
0.7558178901672363

result:

Gufd<s1>327711<s2>_____,____ менять.______ _____ сейчас. _________ любимый. ______ _____ много.
<bos>По-ти, мув менять.шозыха сйшаб с</s>

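How should incremental_state in the transformer module be understood?

Hello, I recently wanted to do some research on AI poetry writing and studied your code. When reading your transformer module I didn't quite understand the implementation of incremental_state, because I haven't seen anything similar in other papers or transformer code (possibly because I haven't read much code). How should incremental_state and its 'bidx' term be understood?
Is this transformer implementation a technique for speeding up the original transformer? Is there a paper or write-up about it? Any guidance would be appreciated!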
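
As background for the question: in incremental decoding, each layer caches the keys and values of already generated positions, so each new step computes only the newest position and attends over the cache instead of re-running the whole prefix. fairseq's decoders use the same `incremental_state` name for this kind of cache. Judging only from the traceback in the IndexError issue on this page, `bidx` is a boolean mask over the batch dimension that keeps a subset of cached rows (`prev_key = prev_key[bidx]`). The toy, pure-Python cache below is a sketch under that assumption; `ToyKVCache` is hypothetical and not the repository's class.

```python
class ToyKVCache:
    """Per-layer cache of past keys/values, one list per active sequence."""

    def __init__(self, batch_size):
        self.keys = [[] for _ in range(batch_size)]
        self.values = [[] for _ in range(batch_size)]

    def append(self, new_keys, new_values):
        """Add this step's key/value for every active sequence."""
        if len(new_keys) != len(self.keys) or len(new_values) != len(self.keys):
            raise IndexError("one key/value per cached row is required")
        for row, (k, v) in enumerate(zip(new_keys, new_values)):
            self.keys[row].append(k)
            self.values[row].append(v)

    def reorder(self, bidx):
        """Keep only the rows where the boolean mask bidx is True."""
        if len(bidx) != len(self.keys):
            # Same failure mode as the traceback: mask length != batch size.
            raise IndexError(
                f"mask of length {len(bidx)} does not match {len(self.keys)} cached rows"
            )
        self.keys = [k for k, keep in zip(self.keys, bidx) if keep]
        self.values = [v for v, keep in zip(self.values, bidx) if keep]


cache = ToyKVCache(batch_size=2)
cache.append(["k0a", "k0b"], ["v0a", "v0b"])  # step 0: both sequences active
cache.append(["k1a", "k1b"], ["v1a", "v1b"])  # step 1
try:
    cache.reorder([True])  # mask built for 1 sequence, cache still holds 2
except IndexError as err:
    print(err)  # mask of length 1 does not match 2 cached rows
cache.reorder([True, False])  # second sequence finished; drop its row
print(cache.keys)  # [['k0a', 'k1a']]
```

The failed `reorder([True])` call mirrors the shape of the reported error (mask [1] against a cache of batch size 2); it does not show why test.py ends up in that state.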

Thanks!

How to train with multiple GPUs?

I tried setting both world_size and gpus in train.sh to 8 and got this error:

label_smoothing.py", line 15, in __init__
self.one_hot = torch.full((1, size), self.smoothing_value).to(device)
RuntimeError: CUDA error: invalid device ordinal

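What should I do?

PS: I also noticed a small issue: no matter how I set CUDA_VISIBLE_DEVICES, single-GPU training always uses the second GPU. I'm still trying to figure it out.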
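
CUDA reports "invalid device ordinal" when code asks for a GPU index that is not among the visible devices, and CUDA_VISIBLE_DEVICES renumbers the devices it exposes starting from 0. A quick pre-flight check might look like the sketch below. It assumes, hypothetically, that each training rank uses the GPU whose index equals its rank; check train.sh for the mapping it actually uses.

```python
import os


def visible_gpu_count(env=None):
    """How many GPUs CUDA_VISIBLE_DEVICES exposes, or None if it is unset.

    When the variable is unset every physical GPU is visible, and the count
    has to come from the driver instead (e.g. torch.cuda.device_count()).
    """
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    return len([d for d in value.split(",") if d.strip()])


def check_world_size(world_size, visible):
    """Fail early instead of hitting "invalid device ordinal" inside CUDA.

    Hypothetical mapping: rank r is assumed to use device index r.
    """
    if visible is not None and world_size > visible:
        raise ValueError(
            f"world_size={world_size} but only {visible} GPU(s) visible; "
            f"ranks {visible}..{world_size - 1} would get an invalid device ordinal"
        )


check_world_size(2, visible_gpu_count({"CUDA_VISIBLE_DEVICES": "0,1"}))  # passes
```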
