
wobert's Introduction

WoBERT

Word-based Chinese BERT: a Chinese BERT that uses words rather than characters as the basic unit.

Details

https://kexue.fm/archives/7758

Training

The currently released WoBERT is the Base version, obtained by continuing pretraining from HIT's open-source RoBERTa-wwm-ext with MLM as the pretraining objective. At initialization, each word is split into characters with BERT's own tokenizer, and the word embedding is initialized as the average of those characters' embeddings. The model was trained for 1,000,000 steps on a single 24 GB RTX card (roughly 10 days), with sequence length 512, learning rate 5e-6, batch size 16 and gradient accumulation over 16 steps, which is equivalent to training with batch size 256 for about 60,000 steps. The training corpus is about 30+ GB of general-domain text.
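
The character-averaging initialization can be sketched as follows. This is a minimal illustration rather than the released training code; char_embeddings and char_ids_per_word are placeholder names. (In bert4keras this is what passing compound_tokens to build_transformer_model triggers, as shown in the issues below.)

import numpy as np

def init_word_embeddings(char_embeddings, char_ids_per_word):
    """Initialize each new word embedding as the mean of its characters' embeddings.

    char_embeddings: (vocab_size, hidden) embedding matrix of the character-level BERT.
    char_ids_per_word: list of lists, the character ids each new word is split into.
    """
    word_vecs = [char_embeddings[ids].mean(axis=0) for ids in char_ids_per_word]
    return np.stack(word_vecs)  # (num_new_words, hidden)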

In addition, we also provide WoNEZHA, which continues pretraining from Huawei's open-source NEZHA; its training details are essentially the same as WoBERT's. NEZHA's architecture is similar to BERT's, except that it uses relative position encoding while BERT uses absolute position encoding, so in principle there is no upper limit on the text length NEZHA can handle. The word-based WoNEZHA is provided simply to give users one more option.

March 3, 2021: Added the WoBERT Plus model. Starting from RoBERTa-wwm-ext, it was pretrained with Chinese MLM using a rebuilt vocabulary (more complete than that of the already released WoBERT), 30+ GB of corpus, maxlen=512, batch_size=256 and lr=1e-5 for 250,000 steps (4 x TITAN RTX with gradient accumulation over 4 steps, i.e. 4 times that of the earlier WoBERT). Each 1,000 steps took about 1,580 s, for a total of 18 days of training; the final training accuracy is about 64% and the training loss about 1.80.

Dependencies

pip install bert4keras==0.8.8
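
A minimal loading sketch, assuming the downloaded WoBERT files sit at the placeholder paths below. WoBERT's tokenizer pre-segments the text with jieba and falls back to character-level tokens for out-of-vocabulary words, as in the issue examples further down.

import jieba
from bert4keras.models import build_transformer_model
from bert4keras.tokenizers import Tokenizer

config_path = 'chinese_wobert_L-12_H-768_A-12/bert_config.json'      # placeholder path
checkpoint_path = 'chinese_wobert_L-12_H-768_A-12/bert_model.ckpt'   # placeholder path
dict_path = 'chinese_wobert_L-12_H-768_A-12/vocab.txt'               # placeholder path

# Word-level tokenizer: segment with jieba first, then match against the WoBERT vocabulary
tokenizer = Tokenizer(
    dict_path,
    do_lower_case=True,
    pre_tokenize=lambda s: jieba.cut(s, HMM=False),
)
model = build_transformer_model(config_path, checkpoint_path)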

Download

Evaluation

Model         IFLYTEK   TNEWS
BERT          60.31     56.94
WoBERT        61.15     57.05
WoBERT Plus   61.92     58.20

Citation

BibTeX:

@techreport{zhuiyiwobert,
  title={WoBERT: Word-based Chinese BERT model - ZhuiyiAI},
  author={Jianlin Su},
  year={2020},
  url="https://github.com/ZhuiyiTechnology/WoBERT",
}

Contact

Email: [email protected]  Zhuiyi Technology: https://zhuiyi.ai


wobert's Issues

CUDA error: device-side assert triggered

Hi, I added two words, "家居" and "时政", that were not in vocab.txt, and then got this error when running the model: CUDA error: device-side assert triggered. I saw in the issues that someone else hit the same problem, and the reply was that in bert4keras new words can be added via compound_tokens. Since I converted the TF version of WoBERT to a PyTorch version before using it, how should I make the change there? Any suggestions would be much appreciated 🌹
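
A hedged sketch of one way to do this on the PyTorch side, assuming the converted model is a Hugging Face BertForMaskedLM (the model path, word list and variable names are illustrative). It mirrors what compound_tokens does in bert4keras by initializing each new word's embedding as the mean of its characters' embeddings; this kind of device-side assert often means the vocabulary was enlarged without resizing the embedding matrix.

import torch
from transformers import BertForMaskedLM, BertTokenizer

model = BertForMaskedLM.from_pretrained('wobert_pytorch')       # placeholder path
tokenizer = BertTokenizer.from_pretrained('wobert_pytorch')     # placeholder path

new_words = ['家居', '时政']
tokenizer.add_tokens(new_words)                 # register the new whole-word tokens
model.resize_token_embeddings(len(tokenizer))   # grow the embedding (and tied output) matrix

emb = model.get_input_embeddings().weight
with torch.no_grad():
    for offset, word in enumerate(new_words):
        char_ids = tokenizer.convert_tokens_to_ids(list(word))   # the word's characters
        new_id = len(tokenizer) - len(new_words) + offset        # id of the added word
        emb[new_id] = emb[char_ids].mean(dim=0)                  # average of character embeddings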

About pretraining

Hi, I see you used about 30 GB of corpus to train WoBERT, while the original BERT used only Wikipedia, didn't it? A comparison on the same corpus would be more convincing.

Dataset

Hi, could you provide the dataset used for pretraining?

Question about tokenizer.tokenize segmentation

tokenizer.tokenize("作为一个品牌 希望你们不要推卸责任 不要把错误推卸顾客身上")
Why are the spaces gone after tokenization?
['[CLS]', '作为', '一个', '品牌', '希望', '你们', '不要', '推', '卸', '责', '任', '不要', '把', '错误', '推', '卸', '顾客', '身上', '[SEP]']

About UniLM text generation

Hi Su, I looked at your automatic title-generation code and have a question about batched prediction. For example, at training time the inputs are sentence1 and sentence2, which after tokenization become [cls s1 sep s2 sep padding]. At prediction time there is no sentence2, so after padding all inputs to the same length the input becomes [cls s1 sep padding], and the position ids of sentence2 are then shifted relative to training by the number of padding tokens. Does this affect the model's predictions? The bert4keras source is fairly long and my understanding of it is shallow, so I would appreciate an explanation. Thanks.

WoBERT+: vocabulary sorted by word frequency

Hi, I noticed that the vocabulary shipped with WoBERT+ is sorted by token length and lexicographic order. Could you provide a version sorted by word frequency?

Predicting candidate words for [MASK] in a sentence with MLM

I tried using Google's chinese-bert model and Zhuiyi's wobert-plus model to predict candidate words for the [MASK] position in a sentence, and found that the results from wobert-plus are all stopwords. Could you tell me what I did wrong?

The predict function:

def TopCandidates(token_ids, i, topn=64):
    """Use the language model to return the top-n candidate tokens for position i.
    """
    token_ids_i = token_ids[i]
    token_ids[i] = tokenizer._token_mask_id      # mask position i
    token_ids = np.array([token_ids])
    probas = model.predict(token_ids)[0, i]      # MLM distribution at position i
    ids = list(probas.argsort()[::-1][:topn])    # top-n ids by probability
    if token_ids_i in ids:
        ids.remove(token_ids_i)
    else:
        ids = ids[:-1]
    return_token_ids = [token_ids_i] + ids
    return_probas = [probas[_i] for _i in return_token_ids]
    return return_token_ids, return_probas   # the original input token comes first, for convenience

This loads Google's chinese-bert model:

tokenizer = Tokenizer(
    dict_path,
    do_lower_case=True,
)  # build the tokenizer

model = build_transformer_model(
    config_path,
    checkpoint_path,
    segment_vocab_size=0,  # drop the segment_ids input
    with_mlm=True,
)

sent = '***总书记是一位有着47年党龄的共产党员。'
token_ids = tokenizer.encode(sent)[0]
print(token_ids)
print(len(token_ids))
words = tokenizer.ids_to_tokens(token_ids)
print(words)

return_token_ids, return_probas = TopCandidates(token_ids, i=4, topn=8)
for tid, tp in zip(return_token_ids, return_probas):
    print(tid, tokenizer.id_to_token(tid), tp)

output:

[101, 739, 6818, 2398, 2600, 741, 6381, 3221, 671, 855, 3300, 4708, 8264, 2399, 1054, 7977, 4638, 1066, 772, 1054, 1447, 511, 102]
23
['[CLS]', '习', '近', '平', '总', '书', '记', '是', '一', '位', '有', '着', '47', '年', '党', '龄', '的', '共', '产', '党', '员', '。', '[SEP]']
2600 总 0.99927753
5439 老 0.00027872992
4638 的 0.00024593325
1398 同 4.721549e-05
5244 總 3.3572527e-05
1199 副 1.1924949e-05
2218 就 8.368754e-06
3295 曾 7.822215e-06

This loads Zhuiyi's wobert-plus model:

tokenizer = Tokenizer(
    dict_path,
    do_lower_case=True,
    pre_tokenize=lambda s: jieba.cut(s, HMM=False),
)  # build the tokenizer

model = build_transformer_model(
    config_path,
    checkpoint_path,
    segment_vocab_size=0,  # drop the segment_ids input
    with_mlm=True,
)

sent = '***总书记是一位有着47年党龄的共产党员。'
token_ids = tokenizer.encode(sent)[0]
print(token_ids)
print(len(token_ids))
words = tokenizer.ids_to_tokens(token_ids)
print(words)

return_token_ids, return_probas = TopCandidates(token_ids, i=4, topn=8)
for tid, tp in zip(return_token_ids, return_probas):
    print(tid, tokenizer.id_to_token(tid), tp)

output:

[101, 36572, 39076, 2274, 6243, 21309, 5735, 1625, 513, 5651, 3399, 44374, 179, 102]
14
['[CLS]', '***', '总书记', '是', '一位', '有着', '47', '年', '党', '龄', '的', '共产党员', '。', '[SEP]']
6243 一位 2.1206936e-09 # very low probability
101 [CLS] 0.8942671
102 [SEP] 0.10569866
179 。 2.877889e-06
5661 , 1.2525259e-06
3399 的 1.0122681e-06
178 、 7.5024326e-07
5663 : 5.766404e-07

Vocabulary question

During continued pretraining, the RoBERTa weights are loaded but the vocabulary is simplified, so tokenizer.encode gives different results for the same characters than the original RoBERTa encoding. Doesn't that increase the time needed to converge? If the original RoBERTa vocabulary were kept and the new words were simply appended at the end of the vocabulary, would convergence be faster?

Error after replacing wobert with wonezha in test/csl.py

 line 678, in get_tensor
    return CheckpointReader_GetTensor(self, compat.as_bytes(tensor_str))
tensorflow.python.framework.errors_impl.NotFoundError: Key bert/embeddings/position_embeddings not found in checkpoint
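
A hedged guess at the cause: NEZHA uses relative position encoding, so the WoNEZHA checkpoint has no bert/embeddings/position_embeddings variable. Loading it in bert4keras therefore needs the NEZHA architecture rather than the default BERT, for example:

from bert4keras.models import build_transformer_model

model = build_transformer_model(
    config_path,        # WoNEZHA's bert_config.json
    checkpoint_path,    # WoNEZHA's bert_model.ckpt
    model='nezha',      # build a NEZHA model, so no absolute position embeddings are expected
)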

When converting to a torch model by first exporting a ckpt, do I need to export vocab.txt myself and modify bert_config.json?

Modify train.py

1. Build the model

Set:

bert = build_transformer_model(
    config_path,
    checkpoint_path,
    with_mlm='linear',
    keep_tokens=keep_tokens,  # keep only the tokens in keep_tokens, simplifying the original vocabulary
    compound_tokens=compound_tokens,  # add new words, initialized as the average of their characters' embeddings
    return_keras_model=False, 
)

model = bert.model

2. Save the model

Add:

def on_epoch_end(self, epoch, logs=None):
    model.save_weights('bert_model.weights')  # save the Keras weights
    bert.save_weights_as_checkpoint(filename='ckpt_model/bert_model.ckpt')  # also export a TF checkpoint

3. Save the vocabulary

Add:

import os, json, jieba  # these (and num_words) are assumed to be defined elsewhere in train.py
from bert4keras.tokenizers import Tokenizer, load_vocab, save_vocab
# Load the top num_words words from the jieba dictionary and drop some tokens from the BERT vocabulary
if os.path.exists('tokenizer_config.json'):
    token_dict, keep_tokens, compound_tokens = json.load(
        open('tokenizer_config.json')
    )
    save_vocab("ckpt_model/vocab.txt", token_dict)
else:
    # load and simplify the vocabulary
    token_dict, keep_tokens = load_vocab(
        dict_path=dict_path,
        simplified=True,
        startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'],
    )
    pure_tokenizer = Tokenizer(token_dict.copy(), do_lower_case=True)
    user_dict = []
    for w, _ in sorted(jieba.dt.FREQ.items(), key=lambda s: -s[1]):
        if w not in token_dict:
            token_dict[w] = len(token_dict)
            user_dict.append(w)
        if len(user_dict) == num_words:
            break
    compound_tokens = [pure_tokenizer.encode(w)[0][1:-1] for w in user_dict]
    json.dump([token_dict, keep_tokens, compound_tokens],
              open('tokenizer_config.json', 'w'))
    save_vocab("ckpt_model/vocab.txt", token_dict)

4. Modify bert_config.json

Count the vocabulary size:

# wc -l vocab.txt
33585 vocab.txt

Change "vocab_size" in bert_config.json and add "model_type":

{
  "attention_probs_dropout_prob": 0.1, 
  "directionality": "bidi", 
  "hidden_act": "gelu", 
  "hidden_dropout_prob": 0.1, 
  "hidden_size": 768, 
  "initializer_range": 0.02, 
  "intermediate_size": 3072, 
  "max_position_embeddings": 512, 
  "num_attention_heads": 12, 
  "num_hidden_layers": 12, 
  "pooler_fc_size": 768, 
  "pooler_num_attention_heads": 12, 
  "pooler_num_fc_layers": 3, 
  "pooler_size_per_head": 128, 
  "pooler_type": "first_token_transform", 
  "type_vocab_size": 2, 
  "vocab_size": 33585,
  "model_type":"bert"
}

5. Convert with the conversion script from WoBERT_pytorch
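
For reference, a hedged alternative to that script, using the standard TF-to-PyTorch conversion helpers in transformers (assuming the modified bert_config.json sits next to the checkpoint exported in step 2; paths are illustrative):

import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

config = BertConfig.from_json_file('ckpt_model/bert_config.json')
model = BertForPreTraining(config)
load_tf_weights_in_bert(model, config, 'ckpt_model/bert_model.ckpt')   # read the TF checkpoint
torch.save(model.state_dict(), 'ckpt_model/pytorch_model.bin')         # write PyTorch weights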

CLS question

Does the CLS vector in the WoBERT model have the same meaning as the CLS vector in an ordinary BERT model? Can the WoBERT CLS vector be used for a text-classification model?
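
For what it's worth, the usual bert4keras fine-tuning pattern applies unchanged to WoBERT: take the vector at the [CLS] position and put a softmax on top. A minimal sketch, where config_path, checkpoint_path and num_classes are placeholders:

from bert4keras.backend import keras
from bert4keras.models import build_transformer_model

bert = build_transformer_model(config_path, checkpoint_path, return_keras_model=False)
cls_vector = keras.layers.Lambda(lambda x: x[:, 0])(bert.model.output)   # vector at the [CLS] position
output = keras.layers.Dense(num_classes, activation='softmax')(cls_vector)
model = keras.models.Model(bert.model.input, output)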

How is the ckpt file produced?

model.save_weights('bert_model.weights')  # save the model

I see that this is the weights file saved by the MLM task. How is bert_model.ckpt.data-00000-of-00001 produced?
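
A hedged note: the .ckpt.data-00000-of-00001 / .ckpt.index pair is what TensorFlow writes when the model is exported as a checkpoint; in the train.py modification above that export is done with:

bert.save_weights_as_checkpoint(filename='ckpt_model/bert_model.ckpt')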

WoBERT+ model fails to load

config_path = 'WoBERT/bert_config.json'
checkpoint_path = 'WoBERT/bert_model.ckpt'
dict_path = 'WoBERT/vocab.txt'
bert = bert4keras.models.build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    with_pool=True,
    hierarchical_position=True,
    return_keras_model=False,
)
When I try to load the WoBERT+ model, it fails with "Key bert/pooler/dense/kernel not found in checkpoint". Could you help me?
Traceback (most recent call last):
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py", line 70, in get_tensor
self, compat.as_bytes(tensor_str))
RuntimeError: Key bert/pooler/dense/kernel not found in checkpoint
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\IPython\core\interactiveshell.py", line 3343, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 6, in
return_keras_model=False,
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\bert4keras\models.py", line 2338, in build_transformer_model
transformer.load_weights_from_checkpoint(checkpoint_path)
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\bert4keras\models.py", line 297, in load_weights_from_checkpoint
raise e
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\bert4keras\models.py", line 291, in load_weights_from_checkpoint
values.append(self.load_variable(checkpoint, v))
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\bert4keras\models.py", line 691, in load_variable
variable = super(BERT, self).load_variable(checkpoint, name)
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\bert4keras\models.py", line 262, in load_variable
return tf.train.load_variable(checkpoint, name)
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\tensorflow\python\training\checkpoint_utils.py", line 85, in load_variable
return reader.get_tensor(name)
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py", line 74, in get_tensor
error_translator(e)
File "C:\Users\14301\miniconda3\envs\gluon\lib\site-packages\tensorflow\python\training\py_checkpoint_reader.py", line 35, in error_translator
raise errors_impl.NotFoundError(None, None, error_message)
tensorflow.python.framework.errors_impl.NotFoundError: Key bert/pooler/dense/kernel not found in checkpoint
However, after converting to the PyTorch version, the model loads and works fine.
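
A hedged workaround: the traceback indicates the WoBERT+ checkpoint simply contains no bert/pooler weights (it was trained for MLM), so building the model without with_pool=True and taking the [CLS] vector yourself (as in the classification sketch under the CLS issue above) should load cleanly:

from bert4keras.models import build_transformer_model

# config_path, checkpoint_path as in the snippet above
bert = build_transformer_model(
    config_path=config_path,
    checkpoint_path=checkpoint_path,
    hierarchical_position=True,
    return_keras_model=False,
)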
