- 🤗 :huggingface
- 🌱 :Zhihu
- 👯 :hellonlp
hellonlp / classifier-multi-label
Multi-label text classification, multi-label classification, text classification, multi-label, classifier, text classification, BERT, seq2seq, attention, multi-label-classification
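Hello, could you share the final version of requirements.txt? I read your Zhihu article, which uses tf=1.14.0, but after installing it I still get the error in the title.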
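Issue porting from TF1 to TF2: without the TextCNN network, both training and prediction work fine. After adding TextCNN, the loss and accuracy during training look good, but every prediction is wrong. The TF2 TextCNN implementation below is essentially a direct port. I also tried implementing it with tf.keras.layers.Conv2D() and with conv1d, but neither worked. I first suspected the number of training epochs or other hyperparameters, but even with parameters identical to your project's, the trained model is still broken (dropout is applied), so I would like to ask for your advice.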
import tensorflow as tf


def textcnn(x):
    # x: [batch, seq_len, hidden]; the code hard-codes hidden = 312 and seq_len = 60
    pooled_outputs = []
    filter_sizes = [2, 3, 4, 5, 6, 7]
    # Add a channel axis so conv2d sees [batch, seq_len, hidden, 1]
    inputs_expand = tf.expand_dims(x, -1)
    for filter_size in filter_sizes:
        # Each filter spans filter_size tokens and the full 312-dim hidden size, with 128 output channels
        filter_shape = [filter_size, 312, 1, 128]
        W = tf.Variable(tf.random.truncated_normal(filter_shape, stddev=0.1), dtype=tf.float32, name="W")
        b = tf.Variable(tf.constant(0.1, shape=[128]), dtype=tf.float32, name="b")
        conv = tf.nn.conv2d(
            inputs_expand,
            W,
            strides=[1, 1, 1, 1],
            padding="VALID",
            name="conv")
        # Apply nonlinearity
        h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
        # Maxpooling over the outputs: VALID conv leaves 60 - filter_size + 1 positions
        pooled = tf.nn.max_pool(
            h,
            ksize=[1, 60 - filter_size + 1, 1, 1],
            strides=[1, 1, 1, 1],
            padding='VALID',
            name="pool")
        pooled_outputs.append(pooled)
    # Combine all the pooled features
    num_filters_total = 128 * len(filter_sizes)
    h_pool = tf.concat(pooled_outputs, 3)
    h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])
    return h_pool_flat
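For reference, each conv2d + max_pool pair in the snippet above produces one number per output channel: slide a filter of filter_size tokens over the full embedding width with VALID padding, apply ReLU, then take the maximum over the seq_len - filter_size + 1 positions (hence ksize = 60 - filter_size + 1). The sketch below reproduces that arithmetic in plain Python on a toy input, with one output channel per filter; the function names are hypothetical, and this illustrates the operation rather than reproducing the project's code or fixing the TF2 port. Comparing its output with a TF2 layer loaded with the same small weights is one way to check that a port computes the same thing.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def conv_max_pool(x: Matrix, kernel: Matrix, bias: float) -> float:
    """One TextCNN filter (single output channel) in plain Python.

    x:      [seq_len][emb_dim] token embeddings
    kernel: [filter_size][emb_dim]; it spans the full embedding width,
            like filter_shape = [filter_size, 312, 1, 128] in the TF code.
    """
    seq_len, filter_size = len(x), len(kernel)
    if filter_size > seq_len:
        raise ValueError("filter_size must not exceed seq_len")
    activations = []
    # VALID padding: the window fits at seq_len - filter_size + 1 positions
    for t in range(seq_len - filter_size + 1):
        s = bias
        for i, kernel_row in enumerate(kernel):
            # Dot product of one kernel row with the matching token embedding
            s += sum(a * w for a, w in zip(x[t + i], kernel_row))
        activations.append(max(0.0, s))  # ReLU
    # Max-over-time pooling: one window covering every position
    return max(activations)


def textcnn_features(x: Matrix, filters: Sequence[Tuple[Matrix, float]]) -> List[float]:
    """Apply each (kernel, bias) filter and concatenate the pooled values."""
    return [conv_max_pool(x, kernel, bias) for kernel, bias in filters]
```

With 6 filter sizes and 128 channels per size, the concatenation in the TF code yields 128 * 6 = 768 features, which is the num_filters_total used in the final reshape.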
from tensorflow.contrib import tpu as contrib_tpu
ModuleNotFoundError: No module named 'tensorflow.contrib'
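tensorflow.contrib exists only in TensorFlow 1.x; it was removed in TensorFlow 2.0, so this import fails whenever the interpreter running the script picks up a 2.x installation. A minimal stdlib-only sketch for checking what the current interpreter actually sees; the helper names are hypothetical, and the distribution names checked ("tensorflow", "tensorflow-gpu") are assumptions about how TensorFlow was installed.

```python
import importlib.util
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Installed version of a distribution such as "tensorflow", or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def module_available(module_name: str) -> bool:
    """Whether module_name can be located; for dotted names the parent package is imported."""
    try:
        return importlib.util.find_spec(module_name) is not None
    except ModuleNotFoundError:
        # A missing parent package (e.g. no tensorflow at all) lands here
        return False
```

Running installed_version("tensorflow"), installed_version("tensorflow-gpu") and module_available("tensorflow.contrib") from the same interpreter that runs the failing script shows whether it is really using a 1.x build.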
Traceback (most recent call last):
  File "predict.py", line 43, in <module>
    MODEL = ModelAlbertTextCNN()
  File "predict.py", line 26, in __init__
    self.albert, self.sess = self.load_model()
  File "predict.py", line 39, in load_model
    saver.restore(sess, ckpt.model_checkpoint_path)
ValueError: The passed save_path is not a valid checkpoint
How can I fix this? I have already downgraded to protobuf-3.20.3.
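saver.restore raises this ValueError when the path it is given does not point at a usable checkpoint. Assuming ckpt in the traceback comes from tf.train.get_checkpoint_state (the snippet does not show this), the path is read from the plain-text file named checkpoint in the model directory. By default, TF1's Saver records absolute paths there, so a model directory copied from another machine can point at a location that no longer exists. The sketch below (stdlib only, hypothetical helper names) reads model_checkpoint_path from that file and checks for <prefix>.index. It assumes the V2 checkpoint format and does not unescape text-format strings.

```python
import os
import re
from typing import Optional

# Matches lines such as: model_checkpoint_path: "model.ckpt-1000"
_STATE_LINE = re.compile(r'^model_checkpoint_path:\s*"(.*)"\s*$')


def latest_checkpoint_prefix(model_dir: str) -> Optional[str]:
    """Return model_checkpoint_path from model_dir/checkpoint, or None if not found."""
    state_file = os.path.join(model_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file, encoding="utf-8") as f:
        for line in f:
            m = _STATE_LINE.match(line.strip())
            if m:
                path = m.group(1)
                # Relative paths are resolved against the model directory
                return path if os.path.isabs(path) else os.path.join(model_dir, path)
    return None


def checkpoint_looks_valid(prefix: Optional[str]) -> bool:
    """A V2 checkpoint is a prefix with <prefix>.index and <prefix>.data-* files."""
    return prefix is not None and os.path.isfile(prefix + ".index")
```

If the recorded prefix is an absolute path from the training machine, the prefix files either need to exist at that path or the checkpoint file needs to point at where they now are.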
Have you tried using large language models for similar problems, for example the open-source Baichuan and ChatGLM?
predict.py returns no labels after training
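Many people on Zhihu also report that the labels returned by the predict.py script are empty. The cause is not the training data or too few epochs: the author's get_label function has a small logic issue. I made a simple change and can now get the labels successfully. The new get_label function in predict.py is as follows: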
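def get_label(sentence):
    """
    Prediction of the sentence's label.
    """
    # Build input_ids, input_masks and segment_ids for a single sentence
    feature = get_feature_test(sentence)
    fd = {MODEL.albert.input_ids: [feature[0]],
          MODEL.albert.input_masks: [feature[1]],
          MODEL.albert.segment_ids: [feature[2]],
          }
    # One entry per label index for the single example in the batch
    prediction = MODEL.sess.run(MODEL.albert.predictions, feed_dict=fd)[0]
    print(prediction)
    # Keep every label whose predicted value is non-zero (index 0 included)
    r = []
    for i in range(len(prediction)):
        if prediction[i] != 0.0:
            r.append(id2label(i))
    return r
    # Original version: keeps positions equal to 1 but skips index 0
    # return [id2label(l) for l in np.where(prediction==1)[0] if l!=0]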
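The difference between the two versions can be checked without the model. The original list comprehension excludes index 0 by design, so a sentence whose only active label has id 0 comes back as an empty list. A plain-Python sketch of both decodings, assuming prediction holds one 0/1 value per label index; the function names and the label map are hypothetical (the project uses an id2label function rather than a dict).

```python
from typing import Dict, List, Sequence


def decode_original(prediction: Sequence[float], id2label: Dict[int, str]) -> List[str]:
    """Mirror of the commented-out line: positions equal to 1, excluding index 0."""
    return [id2label[i] for i, v in enumerate(prediction) if v == 1 and i != 0]


def decode_patched(prediction: Sequence[float], id2label: Dict[int, str]) -> List[str]:
    """Mirror of the patched loop: every position whose value is non-zero."""
    return [id2label[i] for i, v in enumerate(prediction) if v != 0.0]


labels = {0: "label_0", 1: "label_1", 2: "label_2"}  # hypothetical label map
print(decode_original([1.0, 0.0, 0.0], labels))  # → []
print(decode_patched([1.0, 0.0, 0.0], labels))   # → ['label_0']
```

The patched check accepts any non-zero value, so it relies on the model emitting already-thresholded 0/1 predictions; if predictions were raw probabilities, it would return every label with non-zero probability.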
Failed to find any matching files for /root/autodl-tmp/classifier_multi_label/albert_small_zh_google/albert_model.ckpt. Hello, do I need to add this albert_model.ckpt file myself?
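TensorFlow reports "Failed to find any matching files for <prefix>" roughly when nothing on disk matches the given checkpoint prefix; albert_model.ckpt is a prefix, not a single file. The stdlib sketch below (hypothetical helper name) lists the files under a prefix and flags the two parts a V2-format checkpoint needs: <prefix>.index and at least one <prefix>.data-* shard. These expected names are an assumption about the V2 format, not something this repository specifies.

```python
import glob
import os
from typing import Dict, List


def inspect_checkpoint_prefix(prefix: str) -> Dict[str, object]:
    """List files starting with a checkpoint prefix and check for the V2 parts."""
    # glob.escape keeps characters such as [ or * in the path from acting as wildcards
    files: List[str] = sorted(glob.glob(glob.escape(prefix) + "*"))
    names = [os.path.basename(f) for f in files]
    base = os.path.basename(prefix)
    return {
        "files": names,
        "has_index": base + ".index" in names,
        "has_data": any(n.startswith(base + ".data-") for n in names),
    }
```

An empty files list for the path in the error message means no checkpoint files with that prefix are present in albert_small_zh_google.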
classifier_multi_label_textcnn is reported as missing. What is causing this?
May I ask whether a PyTorch version will be released in the future?