hellonlp / classifier-multi-label

Multi-label text classification, multi-label classification, text classification, multi-label, classifier, BERT, seq2seq, attention, multi-label-classification

Python 100.00%
attention bert classifier-multi-label cnn multi-label multi-label-classification seq2seq tensorflow text-classification text-classifier textcnn

Contributors

hellonlp

classifier-multi-label's Issues

Problem using the TextCNN network with TF2

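I ran into a problem porting the TF1 version to TF2. Without the TextCNN network, both training and prediction work fine. With TextCNN added, loss and accuracy look good during training, but every prediction is wrong. The TF2 TextCNN below is essentially a direct port. I also tried implementations based on tf.keras.layers.Conv2D() and conv1d, but none of them worked. I first suspected hyperparameters such as the number of training epochs, but even with the same parameters as your project (and with dropout applied) the trained model is still wrong, so I would like to ask for your advice.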

def textcnn(x):
    pooled_outputs = []

    filter_sizes = [2, 3, 4, 5, 6, 7]
    inputs_expand = tf.expand_dims(x, -1)
    for filter_size in filter_sizes:
        filter_shape = [filter_size, 312, 1, 128]
        W = tf.Variable(tf.random.truncated_normal(filter_shape, stddev=0.1), dtype=tf.float32, name="W")
        b = tf.Variable(tf.constant(0.1, shape=[128]), dtype=tf.float32, name="b")
        conv = tf.nn.conv2d(
            inputs_expand,
            W,
            strides=[1, 1, 1, 1],
            padding="VALID",
            name="conv")
        # Apply nonlinearity
        h = tf.nn.relu(tf.nn.bias_add(conv, b), name="relu")
        # Maxpooling over the outputs
        pooled = tf.nn.max_pool(
            h,
            ksize=[1, 60 - filter_size + 1, 1, 1],
            strides=[1, 1, 1, 1],
            padding='VALID',
            name="pool")
        pooled_outputs.append(pooled)
    # Combine all the pooled features
    num_filters_total = 128 * len(filter_sizes)
    h_pool = tf.concat(pooled_outputs, 3)
    h_pool_flat = tf.reshape(h_pool, [-1, num_filters_total])

    return h_pool_flat
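
The shape arithmetic in the textcnn function above can be traced without TensorFlow. The sketch below is only an illustration: it assumes the input x has shape [batch, 60, 312], which is inferred from the hard-coded 60 and 312 in the snippet and is not stated in the issue. The helper names valid_output_length and textcnn_shapes are hypothetical.

```python
# Constants copied from the snippet in the issue.
POOL_SEQ_LEN = 60      # the 60 hard-coded in the max_pool ksize expression
HIDDEN = 312           # filter width in filter_shape
NUM_FILTERS = 128
FILTER_SIZES = [2, 3, 4, 5, 6, 7]


def valid_output_length(input_length, window, stride=1):
    """Output length of a VALID (unpadded) convolution or pooling window."""
    return (input_length - window) // stride + 1


def textcnn_shapes(batch_size, seq_len=POOL_SEQ_LEN):
    """Trace tensor shapes through the conv and max-pool steps (assumed NHWC)."""
    shapes = []
    for k in FILTER_SIZES:
        conv_h = valid_output_length(seq_len, k)        # slide over tokens
        conv_w = valid_output_length(HIDDEN, HIDDEN)    # filter spans full width
        # The pooling window is fixed by the hard-coded 60, not by seq_len.
        pooled_h = valid_output_length(conv_h, POOL_SEQ_LEN - k + 1)
        shapes.append({
            "filter_size": k,
            "conv": (batch_size, conv_h, conv_w, NUM_FILTERS),
            "pooled": (batch_size, pooled_h, conv_w, NUM_FILTERS),
        })
    # After concatenating on the channel axis and reshaping to [-1, total]:
    total = NUM_FILTERS * len(FILTER_SIZES)
    rows = batch_size * shapes[0]["pooled"][1] * shapes[0]["pooled"][2]
    return shapes, (rows, total)


shapes, flat = textcnn_shapes(batch_size=4)
print(flat)
```

With a sequence length of 60 every pooled tensor collapses to height 1 and the flattened output has one row per example. If the real sequence length were larger than 60, the pooled height would exceed 1 and the reshape to [-1, 768] would fold the extra positions into the batch dimension without raising an error. Whether that is what happens in the reporter's setup cannot be determined from the issue.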

The passed save_path is not a valid checkpoint

Traceback (most recent call last):
  File "predict.py", line 43, in <module>
    MODEL = ModelAlbertTextCNN()
  File "predict.py", line 26, in __init__
    self.albert, self.sess = self.load_model()
  File "predict.py", line 39, in load_model
    saver.restore(sess, ckpt.model_checkpoint_path)
ValueError: The passed save_path is not a valid checkpoint
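
The traceback alone does not show why the path was rejected. One possible cause, offered here only as a guess, is that the prefix recorded in the checkpoint state does not match any files on disk, for example because the model directory was moved or training never wrote a checkpoint. The sketch below checks for the .index and .data-* files that TensorFlow writes next to a checkpoint prefix. The function name checkpoint_prefix_exists is hypothetical and not part of the repository.

```python
import glob
import os


def checkpoint_prefix_exists(prefix):
    """Return True if both the .index and a .data-* shard exist for prefix."""
    has_index = os.path.isfile(prefix + ".index")
    # glob.escape guards against special characters in the directory name.
    has_data = bool(glob.glob(glob.escape(prefix) + ".data-*"))
    return has_index and has_data
```

Calling it with ckpt.model_checkpoint_path just before saver.restore would give a clearer message than the ValueError when the files are missing.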

Large language models

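Has the author tried using large language models to solve this kind of problem, for example open-source models such as Baichuan and ChatGLM?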

predict.py returns no labels after training finishes

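Many people on Zhihu also report that the labels returned by predict.py are empty. The cause is not bad training data or too few epochs: the logic of the author's get_label function has a small problem. I made a simple change that retrieves the labels successfully. The new get_label function in predict.py is as follows: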

def get_label(sentence):
    """
    Prediction of the sentence's label.
    """
    feature = get_feature_test(sentence)
    fd = {MODEL.albert.input_ids: [feature[0]],
          MODEL.albert.input_masks: [feature[1]],
          MODEL.albert.segment_ids:[feature[2]],
          }
    prediction = MODEL.sess.run(MODEL.albert.predictions, feed_dict=fd)[0]
    print(prediction)
    # Collect the label for every index whose prediction is non-zero.
    r = []
    for i in range(len(prediction)):
        if prediction[i] != 0.0:
            r.append(id2label(i))
    return r
    # Original line:
    # return [id2label(l) for l in np.where(prediction==1)[0] if l!=0]
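
The difference between the commented-out line and the patched loop can be seen on a toy vector. The sketch below is an illustration under assumptions: the vocabulary and the id2label function here are hypothetical stand-ins, and the issue does not say whether index 0 is a real label in the repository's vocabulary.

```python
VOCAB = ["sports", "finance", "tech"]  # hypothetical label vocabulary


def id2label(index):
    """Hypothetical stand-in for the repository's id2label function."""
    return VOCAB[index]


def labels_original(prediction):
    # Mirrors the commented-out line: keep entries equal to 1, skip index 0.
    return [id2label(i) for i, p in enumerate(prediction) if p == 1 and i != 0]


def labels_patched(prediction):
    # Mirrors the patched loop: keep every non-zero entry, including index 0.
    return [id2label(i) for i, p in enumerate(prediction) if p != 0.0]


prediction = [1.0, 0.0, 1.0]
print(labels_original(prediction))
print(labels_patched(prediction))
```

On this input the original logic drops the label at index 0 while the patched logic keeps it. If a sample's only predicted label sits at index 0, the original returns an empty list, which would match the symptom described above.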

PyTorch

Will a PyTorch version be released in the future?
