Giter VIP home page Giter VIP logo

littlewang1220 / text_classifier Goto Github PK

View Code? Open in Web Editor NEW

This project forked from stanleylsx/text_classifier_tf2

0.0 0.0 0.0 141.42 MB

该项目是使用TextCNN/TextRCNN的文本分类任务,嵌入层可接入Word2Vec,Bert,也可以直接使用词粒度的随机embedding,带有Attention模块,项目基于Tensorflow2.3开发。数据的获取见app_comments_spider爬虫项目。

Python 100.00%

text_classifier's Introduction

Text Classifier

此仓库是基于Tensorflow2.3的文本分类任务,分别支持:

  • 随机初始Word Embedding + TextCNN
  • 随机初始Word Embedding + Attention + TextCNN
  • 随机初始Word Embedding + TextRCNN
  • Word2Vec + TextCNN
  • Word2Vec + Attention + TextCNN
  • Word2Vec + TextRCNN
  • Bert Embedding(没有微调,直接取向量) + TextCNN
  • Bert Embedding(没有微调,直接取向量) + TextRCNN

代码支持二分类和多分类,此项目基于爬取的游戏评论做了个二元的情感分类作为demo。

环境

  • python 3.6.7
  • tensorflow==2.3.0
  • gensim==3.8.3
  • jieba==0.42.1
  • sklearn==0.0

其他环境见requirements.txt

更新历史

日期 版本 描述
2018-12-01 v1.0.0 初始仓库
2020-10-20 v2.0.0 重构项目
2020-10-26 v2.1.0 加入F1、Precise、Recall分类指标,计算方式支持macro、micro、average、binary
2020-11-06 v2.2.0 加入TextRCNN
2020-11-19 v2.3.0 加入Attention
2020-11-26 v2.3.1 加入focal loss用于改善标签分布不平衡的情况
2020-11-19 v2.4.0 增加每个类别的指标,重构指标计算逻辑
2021-03-02 v2.5.0 使用Dataset替换自己写的数据加载器来加载数据
2021-03-15 v3.0.0 支持仅使用TextCNN/TextRCNN进行数据训练(基于词粒度的token,使用随机生成的Embedding层)
2021-03-16 v3.1.0 支持取用Bert的编码后接TextCNN/TextRCNN进行数据训练(此项目Bert不支持预训练);在log中打印配置
2021-03-17 v3.1.1 根据词频过滤一部分频率极低的词,不加入词表

数据集

我的另外一个爬虫项目app_comments_spider中爬取

原理

Word2vec

可以参考我的博客文章01-NLP介绍和词向量02-词向量第二部分和词义
也可看博客刘建平Pinard和文章技术干货 | 漫谈Word2vec之skip-gram模型

TextCNN

textcnn

TextRCNN

textrcnn

使用

配置

在config.py中配置好各个参数,文件中有详细参数说明

训练word2vec

在config.py中的mode中改成train_word2vec并运行

# [train_classifier, interactive_predict, train_word2vec]
mode = 'train_word2vec'

训练分类器

配置好下列参数

classifier_config = {
    # 模型选择
    'classifier': 'textcnn',
    # 训练数据集
    'train_file': 'data/data/train_data.csv',
    # 引入外部的词嵌入,可选word2vec、Bert
    # 此处只使用Bert Embedding,不对其做预训练
    # None:使用随机初始化的Embedding
    'embedding_method': 'Bert',
    # 不外接词向量的时候需要自定义的向量维度
    'embedding_dim': 300,
    # 存放词表的地方
    'token_file': 'data/data/token2id',
    # 验证数据集
    'dev_file': 'data/data/dev_data.csv',
    # 类别和对应的id
    'classes': {'negative': 0, 'positive': 1},
    # 模型保存的文件夹
    'checkpoints_dir': 'model/bert_textcnn',
    # 模型保存的名字
    'checkpoint_name': 'bert_textcnn',
    # 卷集核的个数
    'num_filters': 64,
    # 学习率
    'learning_rate': 0.001,
    # 训练epoch
    'epoch': 30,
    # 最多保存max_to_keep个模型
    'max_to_keep': 1,
    # 每print_per_batch打印
    'print_per_batch': 20,
    # 是否提前结束
    'is_early_stop': True,
    # 是否引入attention
    # 注意:textrcnn不支持
    'use_attention': False,
    # attention大小
    'attention_dim': 300,
    'patient': 8,
    'batch_size': 64,
    'max_sequence_length': 150,
    # 遗忘率
    'droupout_rate': 0.5,
    # 隐藏层维度
    # 使用textrcnn中需要设定
    'hidden_dim': 200,
    # 若为二分类则使用binary
    # 多分类使用micro或macro
    'metrics_average': 'binary',
    # 类别样本比例失衡的时候可以考虑使用
    'use_focal_loss': False
}

配置完参数之后开始训练模型

# [train_classifier, interactive_predict, train_word2vec]
mode = 'train_classifier'
  • textcnn训练结果

train_results_textcnn

  • att-textcnn训练结果

train_results_att-textcnn

  • textrcnn训练结果

train_results_textrcnn

测试

训练好textcnn可以开始测试

# [train_classifier, interactive_predict, train_word2vec]
mode = 'interactive_predict'
  • 交互测试结果

test

参考

text_classifier's People

Contributors

stanleylsx avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.