SkillAgentSearch skills...

Textclf

TextClf :基于Pytorch/Sklearn的文本分类框架,包括逻辑回归、SVM、TextCNN、TextRNN、TextRCNN、DRNN、DPCNN、Bert等多种模型,通过简单配置即可完成数据处理、模型训练、测试等过程。

Install / Use

/learn @luopeixiang/Textclf

README

License

目录:

TextClf简介

前言

TextClf 是一个面向文本分类场景的工具箱,它的目标是可以通过配置文件快速尝试多种分类算法模型、调整参数、搭建baseline,从而让使用者能有更多精力关注于数据本身的特点,做针对性改进优化。

TextClf有以下这些特性:

  • 同时支持机器学习模型如逻辑回归、线性向量机与深度学习模型如TextCNN、TextRNN、TextRCNN、DRNN、DPCNN、Bert等等。
  • 支持多种优化方法,如AdamAdamWAdamaxRMSprop等等
  • 支持多种学习率调整的方式,如ReduceLROnPlateauStepLRMultiStepLR
  • 支持多种损失函数,如CrossEntropyLossCrossEntropyLoss with label smoothingFocalLoss
  • 可以通过和程序交互生成配置,再通过修改配置文件快速调整参数。
  • 在训练深度学习模型时,支持使用对embedding层和classifier层分别使用不同的学习率进行训练
  • 支持从断点(checkpoint)重新训练
  • 具有清晰的代码结构,可以让你很方便的加入自己的模型,使用textclf,你可以不用去关注优化方法、数据加载等方面,可以把更多精力放在模型实现上。

与其他文本分类框架 NeuralClassifier 的比较:

  • NeuralClassifier不支持机器学习模型,也不支持Bert/Xlnet等深度的预训练模型。

  • TextClf会比NeuralClassifier对新手更加友好,清晰的代码结构也会使得你能方便地对它进行拓展。

  • 特别地,对于深度学习模型,TextClf将其看成两个部分,Embedding层和Classifier层。

    Embedding层可以是随机初始化的词向量,也可以是预训练好的静态词向量(word2vec、glove、fasttext),也可以是动态词向量如BertXlnet等等。

    Classifier层可以是MLP,CNN,将来也会支持RCNN,RNN with attention等各种模型。

    通过将embedding层和classifier层分开,在配置深度学习模型时,我们可以选择对embedding层和classifier层进行排列组合,比如Bert embedding + CNNword2vec + RCNN 等等。

    这样,通过比较少的代码实现,textclf就可以涵盖更多的模型组合的可能。

系统设计思路

TextClf将文本分类的流程看成预处理、模型训练、模型测试三个阶段。

预处理阶段做的事情主要是:

  • 读入原始数据,进行分词,构建词典
  • 分析标签分布等数据特点
  • 保存成二进制的形式方便快速读入

数据经过预处理之后,我们就可以在上面训练各种模型、比较模型的效果。

模型训练阶段负责的是:

  • 读入预处理过的数据
  • 根据配置初始化模型、优化器等训练模型必需的因素
  • 训练模型,根据需要最优模型

测试阶段的功能主要是:

  • 加载训练阶段保存的模型进行测试
  • 支持使用文件输入或者终端输入两种方式进行测试

为了方便地对预处理、模型训练、模型测试阶段进行控制,TextClf使用了json文件来对相关的参数(如预处理中指定原始文件的路径、模型训练阶段指定模型参数、优化器参数等等)进行配置。运行的时候,只要指定配置文件,TextClf就会根据文件中的参数完成预处理、训练或者测试等工作,详情可参见 快速开始 部分。

目录结构

textclf源代码目录下有六个子目录和两个文件,每项的作用如下所示:

├── config		# 包括预处理、模型训练、模型测试的各种参数及其默认设置
├── data		# 数据预处理、数据加载的代码
├── models		# 主要包括深度学习模型的实现
├── tester		# 负责加载模型进行测试
├── __init__.py # 模块的初始化文件
├── main.py		# textclf的接口文件,运行textclf会调用该文件中的main函数
├── trainer		# 负责模型的训练
└── utils		# 包含各种工具函数

安装

依赖环境:python >=3.6

使用pip安装:

pip install textclf

安装成功之后就可以使用textclf了!

快速开始

下面我们看一下如何使用textclf训练模型进行文本分类。

在目录examples/toutiao 下有以下文件:

  3900行 train.csv
   600行 valid.csv
   600行 test.csv
  5100行 total

这些数据来自 今日头条新闻分类数据集, 在这里用作演示。

文件的格式如下:

下周一(5.7日)手上持有这些股的要小心   news_finance
猪伪狂犬苗的免疫方案怎么做?    news_edu
小米7未到!这两款小米手机目前性价比最高,米粉:可惜买不到       news_tech
任何指望技术来解决社会公正、公平的设想,都是幻想        news_tech
诸葛亮能借东风火烧曹营,为什么火烧司马懿却没料到会下雨?        news_culture
福利几款旅行必备神器,便宜实用颜值高!  news_travel
抵押车要怎样年审和购买保险?    news_car
现在一万一平米的房子,十年后大概卖多少钱?      news_house
第一位有中国国籍的外国人,留中国五十多年,死前留下这样的话!    news_world
为什么A股投资者越保护越亏?     stock

文件每一行由两个字段组成,分别是句子和对应的label,句子和label之间使用\t字符隔开。

预处理

第一步是预处理。预处理将会完成读入原始数据,进行分词,构建词典,保存成二进制的形式方便快速读入等工作。要对预处理的参数进行控制,需要相应的配置文件,textclf中的help-config功能可以帮助我们快速生成配置,运行:

textclf help-config

输入0让系统为我们生成默认的PreprocessConfig,接着将它保存成preprocess.json文件:

(textclf) luo@luo-pc:~/projects$ textclf help-config
Config  有以下选择(Default: DLTrainerConfig): 
0. PreprocessConfig     预处理的设置
1. DLTrainerConfig      训练深度学习模型的设置
2. DLTesterConfig       测试深度学习模型的设置
3. MLTrainerConfig      训练机器学习模型的设置
4. MLTesterConfig       测试机器学习模型的设置
输入您选择的ID (q to quit, enter for default):0
Chooce value PreprocessConfig   预处理的设置
输入保存的文件名(Default: config.json): preprocess.json
已经将您的配置写入到 preprocess.json,你可以在该文件中查看、修改参数以便后续使用
Bye!

打开文件preprocess.json,可以看到以下内容:

{
    "__class__": "PreprocessConfig",
    "params": {
        "train_file": "train.csv",
        "valid_file": "valid.csv",
        "test_file": "test.csv",
        "datadir": "dataset",
        "tokenizer": "char",
        "nwords": -1,           
        "min_word_count": 1
    }
}

params中是我们可以进行设置的参数,这些字段的详细含义可以查看文档。 这里我们只需要把datadir字段修改成toutiao目录即可 (最好使用绝对路径,若使用相对路径,要确保当前工作目录正确访问该路径。)

然后,就可以根据配置文件进行预处理了:

textclf --config-file preprocess.json preprocess

如无错误,输出如下:

(textclf) luo@V_PXLUO-NB2:~/textclf/test$ textclf --config-file config.json preprocess
Tokenize text from /home/luo/textclf/textclf_source/examples/toutiao/train.csv...
3900it [00:00, 311624.35it/s]
Tokenize text from /home/luo/textclf/textclf_source/examples/toutiao/valid.csv...
600it [00:00, 299700.18it/s]
Tokenize text from /home/luo/textclf/textclf_source/examples/toutiao/test.csv...
600it [00:00, 289795.30it/s]
Label Prob:
+--------------------+-------------+-------------+------------+
|                    |   train.csv |   valid.csv |   test.csv |
+====================+=============+=============+============+
| news_finance       |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_edu           |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_tech          |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_culture       |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_travel        |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_car           |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_house         |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_world         |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| stock              |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_story         |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_agriculture   |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_entertainment |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_military      |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_sports        |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| news_game          |      0.0667 |      0.0667 |     0.0667 |
+--------------------+-------------+-------------+------------+
| Sum                |   3900.0000 |    600.0000 |   600.0000 |
+--------------------+-------------+-------------+------------+
Dictionary Size: 2981
Saving data to ./textclf.joblib...

预处理会打印每个数据集标签分布的信息。同时,处理过后的数据被保存到二进制文件./textclf.joblib中了。 (每个类别所含的样本数是相同的。)

预处理中的详细参数说明,请查看文档

训练一个逻辑回归模型

同样的,我们先使用textclf help-config生成train_lr.json配置文件,输入3 选择训练机器学习模型的配置。 根据提示分别选择CountVectorizer(文本向量化的方式)以及模型LR

(textclf) luo@luo-pc:~/projects$ textclf help-config
Config  有以下选择(Default: DLTrainerConfig): 
0. PreprocessConfig     预处理的设置
1. DLTrainerConfig      训练深度学习模型的设置
2. DLTesterConfig       测试深度学习模型的设置
3. MLTrainerConfig      训练机器学习模型的设置
4. MLTesterConfig       测试机器学习模型的设置
输入您选择的ID (q to quit, enter for default):3
Chooce value MLTrainerConfig    训练机器学习模型的设置
正在设置vectorizer
vectorizer 有以下选择(Default: CountVectorizer): 
0. CountVectorizer
1. TfidfVectorizer
输入您选择的ID (q to quit, enter for default):0
Chooce value CountVectorizer
正在设置model
model 有以下选择(Default: LogisticRegression): 
0. LogisticRegression
1. LinearSVM
输入您选择的ID (q to quit, enter for default):0
Chooce value LogisticRegression
输入保存的文件名(Default: config.json): train_lr.json
已经将您的配置写入到 train_lr.json,你可以在该文件中查看、修改参数以便后续使用
Bye!

对于更细粒度的配置,如逻辑回归模型的参数,CountVectorizer的参数,可以在生成的train_lr.json中进行修改。这里使用默认的配置进行训练:

textclf --config-file train_lr.json train

因为数据量比较小,所以应该马上就能看到结果。训练结束后,textclf会在测试集上测试模型效果,同时将模型保存在ckpts目录下。

机器学习模型训练中的详细参数说明,请查看文档

加载训练完毕的模型进行测试分析

首先使用help-config生成MLTesterConfig的默认设置到test_lr.json

(textclf) luo@luo-pc:~/projects$ textclf help-config
Config  有以下选择(Default: DLTrainerConfig): 
0. PreprocessConfig     预处理的设置
1. DLTrainerConfig      训练深度学习模型的设置
2. DLTesterConfig       测试深度学习模型的设置
3. MLTrainerConfig      训练机器学习模型的设置
4. MLTesterConfig       测试机器学习模型的设置
输入您选择的ID (q to quit, enter for default):4
Chooce value MLTesterConfig     测试机器学习模型的设置
输入保存的文件名(Default: config.json): test_lr.json
已经将您的配置写入到 test_lr.json,你可以在该文件中查看、修改参数以便后续使用
Bye!

test_lr.json中的input_file字段修改成query_intent_toy_data/test.csv 的路径,然后进行测试:

textclf --config-file test_lr.json test

测试结束,textclf将会打印出准确率、每个label的f1值:

Writing predicted labels to predict.csv
Acc in test file:66.67%
Report:
                    precision    recall  f1-score   support

  news_agriculture     0.6970    0.5750    0.6301        40
          news_car     0.8056    0.7250    0.7632        40
      news_culture     0.7949    0.7750    0.7848        40
          news_edu     0.8421    0.8000    0.8205        40
news_entertainment     0.6000    0.6000    0.6000        40
      news_finance     0.2037    0.2750    0.2340        40
         news_game     0.7111    0.
View on GitHub
GitHub Stars245
CategoryDevelopment
Updated7h ago
Forks40

Languages

Python

Security Score

100/100

Audited on Apr 1, 2026

No findings