飞道的博客

Pytorch TextCNN实现中文文本分类(附完整训练代码)

1164人阅读  评论(0)

Pytorch TextCNN实现中文文本分类(附完整训练代码)

目录

Pytorch TextCNN实现中文文本分类(附完整训练代码)

一、项目介绍

二、中文文本数据集

(1)THUCNews文本数据集

(2) 今日头条文本数据集 

(3)自定义文本数据集

三、TextCNN模型结构

(1)TextCNN模型结构

(2)TextCNN实现

四、训练词嵌入word2vec(可选)

五、文本预处理

(1)句子分词处理:jieba中文分词

(2)特殊字符处理

(3)文本数据增强

六、训练过程 

(1)项目框架说明

(2)准备Train和Test文本数据

(3)配置文件:config_textfolder.yaml

(4)开始训练

(5)可视化训练过程

(6)一些优化建议

七. 模型测试效果

八.项目源码下载


一、项目介绍

本篇将分享一个NLP项目实例,利用深度学习框架Pytorch,构建TextCNN模型,实现一个简易的中文文本分类模型;基于该项目训练的TextCNN的文本分类模型在THUCNews数据集上,训练集的Accuracy 99%左右,测试集的Accuracy在88.36%左右。

【尊重原则,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/127846717

二、中文文本数据集

中文文本数据集特别多,这里仅仅介绍2个常用的文本文本分类数据集

(1)THUCNews文本数据集

THUCNews是根据新浪新闻RSS订阅频道2005~2011年间的历史数据筛选过滤生成,包含74万篇新闻文档(2.19 GB),均为UTF-8纯文本格式。我们在原始新浪新闻分类体系的基础上,重新整合划分出14个候选分类类别:财经、彩票、房产、股票、家居、教育、科技、社会、时尚、时政、体育、星座、游戏、娱乐。使用THUCTC工具包在此数据集上进行评测,准确率可以达到88.6%。

  • 官方数据集下载链接: http://thuctc.thunlp.org/message
  • 百度网盘下载链接: https://pan.baidu.com/s/1DT5xY9m2yfu1YGaGxpWiBQ 提取码: bbpe
  • THUCTC: 一个高效的中文文本分类工具包: THUCTC: 一个高效的中文文本分类工具

(2) 今日头条文本数据集 

今日头条文本数据集数据来源于今日头条客户端,约382688条,分布于15个分类中。

数据格式:

6552431613437805063_!_102_!_news_entertainment_!_谢娜为李浩菲澄清网络谣言,之后她的两个行为给自己加分_!_佟丽娅,网络谣言,快乐大本营,李浩菲,谢娜,观众们

每行为一条数据,以_!_分割的个字段,从前往后分别是 新闻ID,分类code(见下文),分类名称(见下文),新闻字符串(仅含标题),新闻关键词;分类code与名称:


  
  1. 100 民生 故事 news_story
  2. 101 文化 文化 news_culture
  3. 102 娱乐 娱乐 news_entertainment
  4. 103 体育 体育 news_sports
  5. 104 财经 财经 news_finance
  6. 106 房产 房产 news_house
  7. 107 汽车 汽车 news_car
  8. 108 教育 教育 news_edu
  9. 109 科技 科技 news_tech
  10. 110 军事 军事 news_military
  11. 112 旅游 旅游 news_travel
  12. 113 国际 国际 news_world
  13. 114 证券 股票 stock
  14. 115 农业 三农 news_agriculture
  15. 116 电竞 游戏 news_game

GitHub - aceimnorstuvwxz/toutiao-text-classfication-dataset: 今日头条中文新闻(文本)分类数据集

(3)自定义文本数据集

如果需要新增类别数据,或者需要自定数据集进行训练,可以如下进行处理:

  • Train和Test数据集:一个样本一个txt文本,要求相同类别的文本,放在同一个文件夹下;且子目录文件夹命名为类别名称,如

  • 类别文件 class_name.txt : (一行一个列表,最后一行,请多回车一行)

  
  1. A
  2. B
  3. C
  4. D
  • 修改配置文件数据路径:config.yaml

  
  1. # 训练数据集,可支持多个数据集
  2. train_data:
  3. - "data/dataset/train"
  4. # 测试数据集
  5. test_data:
  6. - "data/dataset/test"
  7. vocab_file: "./data/dataset/vocabulary.json" # 字典文件(会根据训练数据集自动生成)
  8. # 类别文件
  9. class_name: "data/dataset/class_name.txt"

三、TextCNN模型结构

(1)TextCNN模型结构

TextCNN文本分类的网络结,如下图所示,可以分为4部分:分别为输入层,CNN层,池化层和输出层

 以中文文本情感分类(二分类)作为简单的例子。

  1. 输入层:也称embedding层,TextCNN的输入序列是一个固定长度的句子:图示中是由11个词组成一条句子(context_size=11),每个词用6维词向量表示(embedding_dim=6),即输入通道数in_channels=6。因此输入序列shape=(11,6),加上Batch这个维度,则是shape=(batch_size,context_size,embedding_dim)=(B,11,6)
  2. CNN层,也称卷积层,由一维卷积核(Conv1d)组成,左边的一维卷积核大小为2(kernel_size=2),输出通道数分别设为4;右边的一维卷积核大小为4(kernel_size=4),输出通道数分别设为5;卷积步长stride=1;因此,一维卷积计算后,左边一维卷积输出宽度=11−2+1=10,右边边一维卷积输出宽度11−4+1=8。
  3. 池化层:将CNN层的输出的9个通道经过时序最大池化(max_pool1d),并将池化输出cat连结成一个9维向量。
  4. 分类层:也是输出层,由简单的全连接层组成;对于简单二分类,其输出维度2,即正面情感和负面情感的预测(概率)。

(2)TextCNN实现

根据TextCNN网络结构,我们可以使用Pytorch构建一个TextCNN模型


  
  1. # -*-coding: utf-8 -*-
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. class GlobalMaxPool1d(nn.Module):
  6. def __init__( self):
  7. super(GlobalMaxPool1d, self).__init__()
  8. def forward( self, x):
  9. return F.max_pool1d(x, kernel_size=x.shape[ 2]) # shape: (batch_size, channel, 1)
  10. class TextCNN(nn.Module):
  11. def __init__( self, num_classes, num_embeddings=-1, embedding_dim=128, kernel_sizes=[3, 4, 5, 6],
  12. num_channels=[256, 256, 256, 256], embeddings_pretrained=None):
  13. """
  14. :param num_classes: 输出维度(类别数num_classes)
  15. :param num_embeddings: size of the dictionary of embeddings,词典的大小(vocab_size),
  16. 当num_embeddings<0,模型会去除embedding层
  17. :param embedding_dim: the size of each embedding vector,词向量特征长度
  18. :param kernel_sizes: CNN层卷积核大小
  19. :param num_channels: CNN层卷积核通道数
  20. :param embeddings_pretrained: embeddings pretrained参数,默认None
  21. :return:
  22. """
  23. super(TextCNN, self).__init__()
  24. self.num_classes = num_classes
  25. self.num_embeddings = num_embeddings
  26. # embedding层
  27. if self.num_embeddings > 0:
  28. # embedding之后的shape: torch.Size([200, 8, 300])
  29. self.embedding = nn.Embedding(num_embeddings, embedding_dim)
  30. if embeddings_pretrained is not None:
  31. self.embedding = self.embedding.from_pretrained(embeddings_pretrained, freeze= False)
  32. # 卷积层
  33. self.cnn_layers = nn.ModuleList() # 创建多个一维卷积层
  34. for c, k in zip(num_channels, kernel_sizes):
  35. cnn = nn.Sequential(
  36. nn.Conv1d(in_channels=embedding_dim,
  37. out_channels=c,
  38. kernel_size=k),
  39. nn.BatchNorm1d(c),
  40. nn.ReLU(inplace= True),
  41. )
  42. self.cnn_layers.append(cnn)
  43. # 最大池化层
  44. self.pool = GlobalMaxPool1d()
  45. # 输出层
  46. self.classify = nn.Sequential(
  47. nn.Dropout(p= 0.2),
  48. nn.Linear( sum(num_channels), self.num_classes)
  49. )
  50. def forward( self, input):
  51. """
  52. :param input: (batch_size, context_size, embedding_size(in_channels))
  53. :return:
  54. """
  55. if self.num_embeddings > 0:
  56. # 得到词嵌入(b,context_size)-->(b,context_size,embedding_dim)
  57. input = self.embedding( input)
  58. # (batch_size, context_size, channel)->(batch_size, channel, context_size)
  59. input = input.permute( 0, 2, 1)
  60. y = []
  61. for layer in self.cnn_layers:
  62. x = layer( input)
  63. x = self.pool(x).squeeze(- 1)
  64. y.append(x)
  65. y = torch.cat(y, dim= 1)
  66. out = self.classify(y)
  67. return out
  68. if __name__ == "__main__":
  69. device = "cuda:0"
  70. batch_size = 4
  71. num_classes = 2 # 输出类别
  72. context_size = 7 # 句子长度(字词个数)
  73. num_embeddings = 1024 # 词典的大小(vocab_size)
  74. embedding_dim = 6 # 词向量特征长度
  75. kernel_sizes = [ 2, 4] # CNN层卷积核大小
  76. num_channels = [ 4, 5] # CNN层卷积核通道数
  77. input = torch.ones(size=(batch_size, context_size)).long().to(device)
  78. model = TextCNN(num_classes=num_classes,
  79. num_embeddings=num_embeddings,
  80. embedding_dim=embedding_dim,
  81. kernel_sizes=kernel_sizes,
  82. num_channels=num_channels,
  83. )
  84. model = model.to(device)
  85. model. eval()
  86. output = model( input)
  87. print( "-----" * 10)
  88. print(model)
  89. print( "-----" * 10)
  90. print( " input.shape:{}". format( input.shape))
  91. print( "output.shape:{}". format(output.shape))
  92. print( "-----" * 10)

测试模型打印结果:


四、训练词嵌入word2vec(可选)

  • 不管是CNN还是RNN模型,都是无法直接处理字符类别的单词,因此需要通过某种方法把单词变成数字形式的向量才能作为模型的输入。把单词映射到向量空间中的一个向量的做法称为词嵌入(word embedding),对应的向量称为词向量(word vector)
  • 上面的TextCNN模型代码中,定义了一个可学习的embedding层,即词嵌入word2vec,其作用就是将word序号ID转换为vector;当然你也可以通过gensim训练自己的word2vec模型,然后在数据处理中先将文本转换为词向量,这样TextCNN就没有必要添加embedding层了。

项目仓库中,提供了基于gensim的word2vec训练代码: word2vec.py ,用户只需要修改好数据路径即可开始训练


  
  1. # -*-coding: utf-8 -*-
  2. """
  3. @Author : panjq
  4. @E-mail : 390737991@qq.com
  5. @Date : 2022-09-26 14:50:34
  6. @Brief :
  7. """
  8. import os
  9. import sys
  10. sys.path.insert( 0, os.getcwd())
  11. import random
  12. import numpy as np
  13. from gensim.models import word2vec
  14. from core.utils import jieba_utils, nlp_utils
  15. from pybaseutils import file_utils
  16. class ChineseWord2Vector( object):
  17. """中文word2vec"""
  18. def __init__( self, stop_words=[], vector_size=128, window=5, min_count=5, epochs=10, workers=4):
  19. """
  20. :param stop_words: 停用词,用于ignore的字词
  21. :param vector_size: 是每个词的向量维度embedding_size
  22. :param window: 是词向量训练时的上下文扫描窗口大小,窗口为5就是考虑前5个词和后5个词
  23. :param min_count: 设置最低频数,默认是5,如果一个词语在文档中出现的次数小于5,那么就会丢弃
  24. :param epochs: Number of iterations (epochs) over the corpus. (Formerly: `iter`)
  25. :param workers: 是训练的线程数,默认是当前运行机器的处理器核数
  26. """
  27. self.stop_words = stop_words if stop_words else jieba_utils.get_common_stop_words()
  28. self.vector_size = vector_size
  29. self.epochs = epochs
  30. self.window = window
  31. self.min_count = min_count
  32. self.workers = workers
  33. self.model: word2vec.Word2Vec = None
  34. def init_model( self):
  35. self.index_to_key = self.model.wv.index_to_key
  36. self.key_to_index = self.model.wv.key_to_index
  37. self.embedding = self.model.wv.vectors
  38. self.vector_size = self.model.wv.vector_size
  39. return self.model
  40. def cut_words_files( self, corpus: str, cutwords: str, user_file: str = "data/user_dict.txt", stop_words=[]):
  41. """
  42. :param corpus: 语料文件
  43. :param cutwords: jieba分词后保存的根目录
  44. :param user_file: 用户自定义的文件
  45. :param stop_words: 停用词,用于ignore的字词
  46. :return:
  47. """
  48. jieba_utils.load_userdict(user_file)
  49. print( "corpus root :{}". format(corpus))
  50. print( "output cutwords :{}". format(cutwords))
  51. print( "user_file :{}". format(user_file))
  52. print( "stop_words :{}". format(stop_words))
  53. if not stop_words: stop_words = self.stop_words
  54. self.stop_words = stop_words
  55. nlp_utils.get_files_sentences_cutword(corpus, cutwords, stop_words=stop_words, block_size= 10000)
  56. # 若只有一个文件,使用LineSentence读取文件
  57. # sentences = word2vec.LineSentence(segment_path)
  58. # 若存在多文件,使用PathLineSentences读取文件列表
  59. # sentences = word2vec.PathLineSentences(cutwords)
  60. sentences = word2vec.PathLineSentences(cutwords)
  61. return sentences
  62. def start_train( self, sentences):
  63. """
  64. :param sentences: *.txt文件路径,所有字词需要预处理并被空格分隔
  65. sentences可以是LineSentence或者PathLineSentences读取的文件对象,也可以是
  66. The `sentences` iterable can be simply a list of lists of tokens,
  67. 如lists=[['我','是','中国','人'],['我','的','家乡','在','广东']]
  68. """
  69. self.model = word2vec.Word2Vec(sentences,
  70. vector_size=self.vector_size,
  71. window=self.window,
  72. min_count=self.min_count,
  73. workers=self.workers,
  74. epochs=self.epochs,
  75. seed= 2020,
  76. )
  77. def save_model( self, model_file) -> word2vec.Word2Vec:
  78. file_utils.create_file_path(model_file)
  79. self.model.save(model_file)
  80. self.init_model()
  81. return self.model
  82. def load_model( self, model_file) -> word2vec.Word2Vec:
  83. self.model = word2vec.Word2Vec.load(model_file)
  84. self.init_model()
  85. return self.model
  86. def get_similarity( self, key1, key2):
  87. """Compute cosine similarity between two keys."""
  88. return self.model.wv.similarity(key1, key2)
  89. def get_index( self, key, default=None):
  90. """Return the integer index (slot/position) where the given key's vector is stored in the backing vectors array."""
  91. return self.model.wv.get_index(key, default=default)
  92. def get_vector( self, key, norm=False):
  93. """Get the key's vector, as a 1D numpy array."""
  94. return self.model.wv.get_vector(key, norm=norm)
  95. def get_text_vector( self, text, context_size=-1, pad_token='<pad>'):
  96. """
  97. 将句子中的所有词转为词向量
  98. :param text:
  99. :return: context_size 句子最大长度max_size
  100. :return: pad_token 句子不足时,是否填充0
  101. """
  102. if context_size > 0: text = text[ 0: min( 6 * context_size, len(text))]
  103. words = jieba_utils.cut_content_word(text, stop_words=self.stop_words)
  104. words = jieba_utils.padding_words(words, context_size=context_size, pad_token=pad_token)
  105. vector = self.get_words_vector(words)
  106. return vector
  107. def get_words_vector( self, words):
  108. """
  109. 将word转换为vecror
  110. :param words:
  111. :return:
  112. """
  113. vector = []
  114. for w in words:
  115. try:
  116. v = self.get_vector(w)
  117. except Exception as e:
  118. v = np.zeros(shape=(self.model.vector_size,), dtype=np.float32)
  119. vector.append(v)
  120. vector = np.asarray(vector, dtype=np.float32)
  121. return vector
  122. def get_words_vector_padding( self, words, context_size=256, random_crop=False, padding=True):
  123. vector = []
  124. for w in words:
  125. try:
  126. v = self.get_vector(w)
  127. vector.append(v)
  128. except Exception as e:
  129. pass
  130. if len(vector) == 0: return []
  131. vector = np.asarray(vector, dtype=np.float32)
  132. nums, dims = vector.shape
  133. pad = context_size - nums
  134. if padding and pad > 0:
  135. zeros = np.zeros(shape=(pad, dims), dtype=np.float32)
  136. vector = np.concatenate([vector, zeros], axis= 0)
  137. if random_crop and pad < 0:
  138. start = random.randint( 0, nums - context_size)
  139. vector = vector[start:start + context_size, :]
  140. else:
  141. vector = vector[ 0:context_size, :]
  142. return vector
  143. def train_simple_demo():
  144. source = './data/source' # 文本数据路径
  145. user_file = 'data/user_dict.txt'
  146. cutwords = os.path.join(os.path.dirname(source), "cutwords") # 分词结果
  147. model_file = os.path.join(os.path.dirname(source), "word2vec", "simple_word2vec128.model")
  148. wv_trainer = ChineseWord2Vector(vector_size= 128, window= 10, min_count= 5, epochs= 10)
  149. sentences = wv_trainer.cut_words_files(source, cutwords, user_file=user_file)
  150. wv_trainer.start_train(sentences)
  151. wv_trainer.save_model(model_file)
  152. model = wv_trainer.load_model(model_file)
  153. print( "save word2vec:{}". format(model_file))
  154. # 测试
  155. w1 = '沙瑞金'
  156. w2 = '高育良'
  157. w3 = '车'
  158. vector = wv_trainer.get_vector(w1)
  159. print( "({},{}),similarity={}". format(w1, w2, model.wv.similarity(w1, w2)))
  160. print( "({},{}),similarity={}". format(w1, w3, model.wv.similarity(w1, w3)))
  161. # print("{} shape={},vector= \n{}".format(w1, vector.shape, vector))
  162. vector = wv_trainer.get_text_vector( "我是一名中国人zhongguo")
  163. def train_THUCNews():
  164. source = '/home/dm/nasdata/dataset/csdn/Text/THUCNews' # 文本数据路径
  165. user_file = "./data/user_dict.txt"
  166. cutwords = os.path.join(os.path.dirname(source), "THUCNews-cutwords") # 分词结果
  167. model_file = os.path.join(os.path.dirname(source), "word2vec128.model")
  168. wv_trainer = ChineseWord2Vector(vector_size= 128, window= 10, min_count= 5, epochs= 10)
  169. sentences = wv_trainer.cut_words_files(source, cutwords, user_file=user_file)
  170. wv_trainer.start_train(sentences)
  171. wv_trainer.save_model(model_file)
  172. model = wv_trainer.load_model(model_file)
  173. print( "save word2vec:{}". format(model_file))
  174. # 测试
  175. w1 = '北京'
  176. w2 = '上海'
  177. w3 = '吃饭'
  178. vector = wv_trainer.get_vector(w1)
  179. print( "({},{}),similarity={}". format(w1, w2, model.wv.similarity(w1, w2)))
  180. print( "({},{}),similarity={}". format(w1, w3, model.wv.similarity(w1, w3)))
  181. # print("{} shape={},vector= \n{}".format(w1, vector.shape, vector))
  182. vector = wv_trainer.get_text_vector( "我是一名中国人zhongguo")
  183. if __name__ == '__main__':
  184. # 简单的训练词嵌入模型
  185. train_simple_demo()
  186. # 使用THUCNews数据训练词嵌入模型
  187. # train_THUCNews()

样例中,使用小说《人民名义》 训练一个word2vec模型,训练完成后,测试单词(沙瑞金,高育良)的相似性similarity=0.8832;而(沙瑞金,车)的相似性similarity=0.4969。


五、文本预处理

Pytorch的提供文本处理工具torchtext;该工具功能非常强大,提供了很多nlp方面的数据集,可以直接加载使用,也提供了不少训练好的词向量之类的;但该工具封装的太高级了,实际使用起来,限制也太多了,灵活性不高,导致这个模块使用起来特别的别扭。所有后面干脆自己写Dataset数据处理方式了;

对于中文文本数据预处理,主要有两部分:句子分词处理(英文文本不需要分词),特殊字符处理

(1)句子分词处理:jieba中文分词

本博客使用jieba工具进行中文分词,工具比较简单,就不单独说明了,安装方法:

pip install jieba 

(2)特殊字符处理

jieba分词后,会出现很多特殊字符,需要进一步做一些的处理

  • 一些换行符,空格等特殊字符,以及一些标点符号(,。!?《》)等,这些特殊的字符称为stop_words,需要剔除
  • 一些英文字母大小需要转换统一为小写
  • 一些繁体字统一转换为简体字等
  • 一些专有名词,比如地名,人名这些,分词时需要整体切词:jieba.load_userdict(file)

(3)文本数据增强

在计算机视觉图像识别任务中,图像数据增强主要有:裁剪、翻转、旋转、⾊彩变换等⽅式,其目的增加数据的多样性,提高模型的泛化能力。但是NLP任务中的数据是离散的,无法像操作图片一样连续的方式操作文字,这导致我们⽆法对输⼊数据进⾏直接简单地转换,换掉⼀个词就有可能改变整个句⼦的含义。

常用的NLP文本数据增强方法主要有:

  • 随机截取: 随机截取文本一个片段
  • 同义词替换(SR: Synonyms Replace):不考虑stopwords,在句⼦中随机抽取n个词,然后从同义词词典中随机抽取同义词,并进⾏替换。
  • 随机插⼊(RI: Randomly Insert):不考虑stopwords,随机抽取⼀个词,然后在该词的同义词集合中随机选择⼀个,插⼊原句⼦中的随机位置。
  • 随机交换(RS: Randomly Swap):句⼦中,随机选择两个词,位置交换。
  • 随机删除(RD: Randomly Delete):句⼦中的每个词,以概率p随机删除

项目已经实现:随机截取,随机插⼊,随机删除等几种文本数据增强方式:


  
  1. # -*- coding: utf-8 -*-
  2. import math
  3. import random
  4. from typing import List
  5. def random_text_crop( text: List, label, context_size, token="<pad>", p=0.5):
  6. """
  7. 句⼦中的每个词,以概率p随机截取
  8. :param text:
  9. :param label:
  10. :param context_size:
  11. :param token:
  12. :param p:
  13. :return:
  14. """
  15. context_size = int(context_size)
  16. nums = len(text)
  17. pad = context_size - nums
  18. if pad > 0 and token:
  19. text = [token] * pad + text
  20. if random.random() < p and pad < 0:
  21. start = random.randint( 0, nums - context_size)
  22. text = text[start:start + context_size]
  23. elif len(text) > context_size:
  24. text = text[ 0:context_size]
  25. return text, label
  26. def random_text_mask( text: List, label, len_range=(0, 2), token="<pad>", p=0.5):
  27. """
  28. 句⼦中的每个词,以概率p替换成token
  29. :param text:
  30. :param label:
  31. :param len_range:
  32. :param p:
  33. :return:
  34. """
  35. if random.random() < p and len(text) > 2 * len_range[ 1]:
  36. nums = math.ceil(random.uniform(len_range[ 0], len_range[ 1]))
  37. for i in range(nums):
  38. index = int(random.uniform( 0, len(text) - 1))
  39. text[index] = token
  40. return text, label
  41. def random_text_delete( text: List, label, len_min, p=0.5):
  42. """
  43. 句⼦中的每个词,以概率p随机删除
  44. :param text:
  45. :param label:
  46. :param len_min: 句子最小长度,低于该值,不会删除
  47. :param p:
  48. :return:
  49. """
  50. if random.random() < p and len(text) > len_min:
  51. nums = int(random.uniform( 0, len(text) - len_min))
  52. for i in range(nums):
  53. index = int(random.uniform( 0, len(text)))
  54. del text[index]
  55. return text, label
  56. def random_text_insert( text: List, label, len_range=(0, 2), token="<pad>", p=0.5):
  57. """
  58. 句⼦中的每个词,以概率p随机插入
  59. :param text:
  60. :param label:
  61. :param len_range:
  62. :param p:
  63. :return:
  64. """
  65. if random.random() < p and len(text) > 2 * len_range[ 1]:
  66. nums = math.ceil(random.uniform(len_range[ 0], len_range[ 1]))
  67. for i in range(nums):
  68. index = int(random.uniform( 0, len(text) - 1))
  69. text.insert(index, token)
  70. return text, label
  71. if __name__ == '__main__':
  72. label = 1
  73. context_size = 10
  74. pad_token = "<pad>"
  75. p = 10
  76. for i in range( 10):
  77. text = "我是一名中国人,我爱中国,我的家乡在广东"
  78. text = "_".join(text).split( "_")
  79. len_range = ( 0, context_size // 4)
  80. # text, label = random_text_crop(text, label, 1.8 * context_size, token=None, p=0.8)
  81. # text, label = random_text_delete(text, label, len_min=1.5 * context_size)
  82. text, label = random_text_insert(text, label, len_range=len_range, token=pad_token)
  83. # text, label = random_text_mask(text, label, len_range=len_range, token=pad_token)
  84. # text, label = random_text_crop(text, label, context_size, token=pad_token, p=0.8)
  85. print(text, len(text))

六、训练过程 

项目以THUCNews文本分类数据集为作为训练数据,训练一个基于TextCNN的文本分类模型;这里为了简单,没有使用gensim训练word2vec词向量模型,而是在TextCNN模型代码中,定义了一个可学习的embedding层,用于代替word2vec

(1)项目框架说明


  
  1. .
  2. ├── configs # 训练配置文件
  3. ├── core # 模型和训练相关工具
  4. ├── data # 相关数据
  5. ├── modules # 相关依赖包模块
  6. ├── work_space # 训练模型输出文件目录
  7. ├── README.md # 项目工程说明文档
  8. ├── requirements.txt # 相关依赖包版本说明,请用pip安装
  9. ├── word2vec.py # 训练词嵌入模型
  10. ├── classifier.py # 测试文本分类脚本
  11. └── train.py # 训练文件

项目依赖的python包,请使用pip安装对应版本


  
  1. numpy==1.16.3
  2. matplotlib==3.1.0
  3. Pillow==6.0.0
  4. easydict==1.9
  5. opencv-contrib-python==4.5.2.52
  6. opencv-python==4.5.1.48
  7. pandas==1.1.5
  8. PyYAML==5.3.1
  9. scikit-image==0.17.2
  10. scikit-learn==0.24.0
  11. scipy==1.5.4
  12. seaborn==0.11.2
  13. tensorboard==2.5.0
  14. tensorboardX==2.1
  15. torch==1.7.1+cu110
  16. torchvision==0.8.2+cu110
  17. tqdm==4.55.1
  18. xmltodict==0.12.0
  19. basetrainer
  20. pybaseutils==0.6.9
  21. jieba==0.42.1
  22. gensim==4.2.0

(2)准备Train和Test文本数据

下载THUCNews文本数据集,并解压;由于原始数据没有划分训练集和测试集,需要自己手动划分,项目随机抽取每类的100张文本作为测试集,其余的为训练集;

然后根据自己的保存的数据路径,修改配置文件数据路径:config_textfolder.yaml


  
  1. # 训练数据集,可支持多个数据集
  2. train_data:
  3. - "/path/to/dataset/THUCNews/train"
  4. # 测试数据集
  5. test_data:
  6. - "/path/to/dataset/THUCNews/test"
  7. vocab_file: "./data/vocabulary/vocabulary.json" # 字典文件(会根据训练数据集自动生成),或者word2vec文件
  8. # 类别文件
  9. class_name: "path/to/dataset/THUCNews/class_name.txt"

(3)配置文件:config_textfolder.yaml


  
  1. # 训练数据集,可支持多个数据集
  2. train_data:
  3. - "/path/to/dataset/THUCNews/train"
  4. # 测试数据集
  5. test_data:
  6. - "/path/to/dataset/THUCNews/test"
  7. vocab_file: "./data/vocabulary/vocabulary.json" # 字典文件(会根据训练数据集自动生成),或者word2vec文件
  8. # 类别文件
  9. class_name: "path/to/dataset/THUCNews/class_name.txt"
  10. data_type: "textfolder" # 加载数据DataLoader方法:word2vec,textfolder
  11. flag: "" # 输出目录标识
  12. resample: True # 是否进行重采样
  13. work_dir: "work_space" # 保存输出模型的目录
  14. net_type: "TextCNN" # 骨干网络,支持:TextCNN,TextCNNv2,LSTM,BiLSTM等
  15. context_size: 300 # 句子长度
  16. topk: [ 1, ] # 计算topK的准确率
  17. batch_size: 128 # 批训练大小
  18. lr: 0.001 # 初始学习率
  19. optim_type: "Adam" # 选择优化器,SGD,Adam
  20. loss_type: "CELoss" # 选择损失函数:支持CrossEntropyLoss(CELoss)
  21. momentum: 0.9 # SGD momentum
  22. num_epochs: 160 # 训练循环次数
  23. num_workers: 12 # 加载数据工作进程数
  24. weight_decay: 0.00005 # weight_decay,默认5e-4
  25. #weight_decay: 0.0 # weight_decay,默认5e-4
  26. scheduler: "multi-step" # 学习率调整策略
  27. milestones: [ 90,120,140 ] # 下调学习率方式
  28. gpu_id: [ 0,1 ] # GPU ID
  29. log_freq: 10 # LOG打印频率
  30. pretrained: True # 是否使用pretrained模型
  31. finetune: False # 是否进行finetune
  • 目标支持模型主要有:TextCNN,LSTM,BiLSTM等,详见模型等 ,其他模型可以自定义添加
  • 训练参数可以通过config.yaml配置文件
参数 类型 参考值 说明
train_data str, list - 训练数据文件,可支持多个文件
test_data str, list - 测试数据文件,可支持多个文件
vocab_file str -
字典文件(会根据训练数据集自动生成),或者word2vec文件
class_name str - 类别文件
data_type str - 加载数据DataLoader方法
resample bool True 是否进行重采样
work_dir str work_space 训练输出工作空间
net_type str TextCNN 骨干网络,支持:TextCNN,LSTM,BiLSTM等
context_size int 128 句子长度
topk list [1,3,5] 计算topK的准确率
batch_size int 32 批训练大小
lr float 0.1 初始学习率大小
optim_type str SGD 优化器,{SGD,Adam}
loss_type str CELoss 损失函数
scheduler str multi-step 学习率调整策略,{multi-step,cosine}
milestones list [30,80,100] 降低学习率的节点,仅仅scheduler=multi-step有效
momentum float 0.9 SGD动量因子
num_epochs int 120 循环训练的次数
num_workers int 12 DataLoader开启线程数
weight_decay float 5e-4 权重衰减系数
gpu_id list [ 0 ] 指定训练的GPU卡号,可指定多个
log_freq int 20 显示LOG信息的频率
finetune str model.pth finetune的模型

(4)开始训练

整套训练代码非常简单操作,用户只需要将相同类别的数据放在同一个目录下,并填写好对应的数据路径,即可开始训练了。

  • 如果你想验证项目可不可以训练,请运行下面命令开始训练;项目自带了小批量的文本数据,方便测试项目代码;正确情况下,可以获得99%的文本分类准确率
python train.py -c configs/config.yaml 
  • 如果你想正式在THUCNews数据集上,训练TextCNN文本分类模型,请运行:
python train.py -c configs/config_textfolder.yaml

以下是训练代码:


  
  1. # -*-coding: utf-8 -*-
  2. """
  3. @Author : panjq
  4. @E-mail : 390737991@qq.com
  5. @Date : 2022-09-26 14:50:34
  6. @Brief :
  7. """
  8. import os
  9. import torch
  10. import argparse
  11. import torch.nn as nn
  12. import numpy as np
  13. import tensorboardX as tensorboard
  14. from tqdm import tqdm
  15. from torch.utils import data as data_utils
  16. from core.dataloader import build_dataset
  17. from core.models import build_models
  18. from core.criterion.build_criterion import get_criterion
  19. from core.utils import torch_tools, metrics, log
  20. from pybaseutils import file_utils, config_utils
  21. from pybaseutils.metrics import class_report
  22. class Trainer( object):
  23. def __init__( self, cfg):
  24. torch_tools.set_env_random_seed()
  25. # 设置输出路径
  26. time = file_utils.get_time()
  27. flag = [n for n in [cfg.net_type, cfg.loss_type, cfg.flag, time] if n]
  28. cfg.work_dir = os.path.join(cfg.work_dir, "_".join(flag))
  29. cfg.model_root = os.path.join(cfg.work_dir, "model")
  30. cfg.log_root = os.path.join(cfg.work_dir, "log")
  31. file_utils.create_dir(cfg.work_dir)
  32. file_utils.create_dir(cfg.model_root)
  33. file_utils.create_dir(cfg.log_root)
  34. file_utils.copy_file_to_dir(cfg.config_file, cfg.work_dir)
  35. config_utils.save_config(cfg, os.path.join(cfg.work_dir, "setup_config.yaml"))
  36. self.cfg = cfg
  37. self.topk = self.cfg.topk
  38. # 配置GPU/CPU运行设备
  39. self.gpu_id = cfg.gpu_id
  40. self.device = torch.device( "cuda:{}". format(cfg.gpu_id[ 0]) if torch.cuda.is_available() else "cpu")
  41. # 设置Log打印信息
  42. self.logger = log.set_logger(level= "debug", logfile=os.path.join(cfg.log_root, "train.log"))
  43. # 构建训练数据和测试数据
  44. self.train_loader = self.build_train_loader()
  45. self.test_loader = self.build_test_loader()
  46. # 构建模型
  47. self.model = self.build_model()
  48. # 构建损失函数
  49. self.criterion = self.build_criterion()
  50. # 构建优化器
  51. self.optimizer = self.build_optimizer()
  52. # 构建学习率调整策略
  53. self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, cfg.milestones)
  54. # 使用tensorboard记录和可视化Loss
  55. self.writer = tensorboard.SummaryWriter(cfg.log_root)
  56. # 打印信息
  57. self.num_samples = len(self.train_loader.sampler)
  58. self.logger.info( "=" * 60)
  59. self.logger.info( "work_dir :{}". format(cfg.work_dir))
  60. self.logger.info( "config_file :{}". format(cfg.config_file))
  61. self.logger.info( "gpu_id :{}". format(cfg.gpu_id))
  62. self.logger.info( "main device :{}". format(self.device))
  63. self.logger.info( "num_samples(train):{}". format(self.num_samples))
  64. self.logger.info( "num_classes :{}". format(cfg.num_classes))
  65. self.logger.info( "mean_num :{}". format(self.num_samples / cfg.num_classes))
  66. self.logger.info( "=" * 60)
  67. def build_optimizer( self, ):
  68. """build_optimizer"""
  69. if self.cfg.optim_type.lower() == "SGD".lower():
  70. optimizer = torch.optim.SGD(params=self.model.parameters(), lr=self.cfg.lr,
  71. momentum=self.cfg.momentum, weight_decay=self.cfg.weight_decay)
  72. elif self.cfg.optim_type.lower() == "Adam".lower():
  73. optimizer = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr, weight_decay=self.cfg.weight_decay)
  74. else:
  75. optimizer = None
  76. return optimizer
  77. def build_train_loader( self, ) -> data_utils.DataLoader:
  78. """build_train_loader"""
  79. self.logger.info( "build_train_loader,context_size:{}". format(self.cfg.context_size))
  80. dataset = build_dataset.load_dataset(data_type=self.cfg.data_type,
  81. filename=self.cfg.train_data,
  82. vocab_file=self.cfg.vocab_file,
  83. context_size=self.cfg.context_size,
  84. class_name=self.cfg.class_name,
  85. resample=self.cfg.resample,
  86. phase= "train",
  87. shuffle= True)
  88. shuffle = True
  89. sampler = None
  90. self.logger.info( "use resample:{}". format(self.cfg.resample))
  91. # if self.cfg.resample:
  92. # weights = torch.DoubleTensor(dataset.classes_weights)
  93. # sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
  94. # shuffle = False
  95. loader = data_utils.DataLoader(dataset=dataset, batch_size=self.cfg.batch_size, sampler=sampler,
  96. shuffle=shuffle, num_workers=self.cfg.num_workers)
  97. self.cfg.num_classes = dataset.num_classes
  98. self.cfg.num_embeddings = dataset.num_embeddings
  99. self.cfg.class_name = dataset.class_name
  100. file_utils.copy_file_to_dir(self.cfg.vocab_file, cfg.work_dir)
  101. return loader
  102. def build_test_loader( self, ) -> data_utils.DataLoader:
  103. """build_test_loader"""
  104. self.logger.info( "build_test_loader,context_size:{}". format(cfg.context_size))
  105. dataset = build_dataset.load_dataset(data_type=self.cfg.data_type,
  106. filename=self.cfg.test_data,
  107. vocab_file=self.cfg.vocab_file,
  108. context_size=self.cfg.context_size,
  109. class_name=self.cfg.class_name,
  110. phase= "test",
  111. resample= False,
  112. shuffle= False)
  113. loader = data_utils.DataLoader(dataset=dataset, batch_size=self.cfg.batch_size,
  114. shuffle= False, num_workers=self.cfg.num_workers)
  115. self.cfg.num_classes = dataset.num_classes
  116. self.cfg.num_embeddings = dataset.num_embeddings
  117. self.cfg.class_name = dataset.class_name
  118. return loader
  119. def build_model( self, ) -> nn.Module:
  120. """build_model"""
  121. self.logger.info( "build_model,net_type:{}". format(self.cfg.net_type))
  122. model = build_models.get_models(net_type=self.cfg.net_type,
  123. num_classes=self.cfg.num_classes,
  124. num_embeddings=self.cfg.num_embeddings,
  125. embedding_dim= 128,
  126. is_train= True,
  127. )
  128. if self.cfg.finetune:
  129. self.logger.info( "finetune:{}". format(self.cfg.finetune))
  130. state_dict = torch_tools.load_state_dict(self.cfg.finetune)
  131. model.load_state_dict(state_dict)
  132. model = model.to(self.device)
  133. model = nn.DataParallel(model, device_ids=self.gpu_id, output_device=self.device)
  134. return model
  135. def build_criterion( self, ):
  136. """build_criterion"""
  137. self.logger.info(
  138. "build_criterion,loss_type:{}, num_embeddings:{}". format(self.cfg.loss_type, self.cfg.num_embeddings))
  139. criterion = get_criterion(self.cfg.loss_type, self.cfg.num_embeddings, device=self.device)
  140. # criterion = torch.nn.CrossEntropyLoss()
  141. return criterion
  142. def train( self, epoch):
  143. """训练"""
  144. train_losses = metrics.AverageMeter()
  145. train_accuracy = {k: metrics.AverageMeter() for k in self.topk}
  146. self.model.train() # set to training mode
  147. log_step = max( len(self.train_loader) // cfg.log_freq, 1)
  148. for step, data in enumerate(tqdm(self.train_loader)):
  149. inputs, target = data
  150. inputs, target = inputs.to(self.device), target.to(self.device)
  151. outputs = self.model(inputs)
  152. loss = self.criterion(outputs, target)
  153. self.optimizer.zero_grad() # 反馈
  154. loss.backward()
  155. self.optimizer.step() # 更新
  156. train_losses.update(loss.cpu().data.item())
  157. # 计算准确率
  158. target = target.cpu()
  159. outputs = outputs.cpu()
  160. outputs = torch.nn.functional.softmax(outputs, dim= 1)
  161. pred_score, pred_index = torch. max(outputs, dim= 1)
  162. acc = metrics.accuracy(outputs.data, target, topk=self.topk)
  163. for i in range( len(self.topk)):
  164. train_accuracy[self.topk[i]].update(acc[i].data.item(), target.size( 0))
  165. if step % log_step == 0:
  166. lr = self.scheduler.get_last_lr()[ 0] # 获得当前学习率
  167. topk_acc = { "top{}". format(k): v.avg for k, v in train_accuracy.items()}
  168. self.logger.info(
  169. "train {}/epoch:{:0=3d},lr:{:3.4f},loss:{:3.4f},acc:{}". format(step, epoch, lr, train_losses.avg,
  170. topk_acc))
  171. topk_acc = { "top{}". format(k): v.avg for k, v in train_accuracy.items()}
  172. self.writer.add_scalar( "train-loss", train_losses.avg, epoch)
  173. self.writer.add_scalars( "train-accuracy", topk_acc, epoch)
  174. self.logger.info( "train epoch:{:0=3d},loss:{:3.4f},acc:{}". format(epoch, train_losses.avg, topk_acc))
  175. return topk_acc[ "top{}". format(self.topk[ 0])]
  176. def test( self, epoch):
  177. """测试"""
  178. test_losses = metrics.AverageMeter()
  179. test_accuracy = {k: metrics.AverageMeter() for k in self.topk}
  180. true_labels = np.ones( 0)
  181. pred_labels = np.ones( 0)
  182. self.model. eval() # set to evaluates mode
  183. with torch.no_grad():
  184. for step, data in enumerate(tqdm(self.test_loader)):
  185. inputs, target = data
  186. inputs, target = inputs.to(self.device), target.to(self.device)
  187. outputs = self.model(inputs)
  188. loss = self.criterion(outputs, target)
  189. test_losses.update(loss.cpu().data.item())
  190. # 计算准确率
  191. target = target.cpu()
  192. outputs = outputs.cpu()
  193. outputs = torch.nn.functional.softmax(outputs, dim= 1)
  194. pred_score, pred_index = torch. max(outputs, dim= 1)
  195. acc = metrics.accuracy(outputs.data, target, topk=self.topk)
  196. true_labels = np.hstack([true_labels, target.numpy()])
  197. pred_labels = np.hstack([pred_labels, pred_index.numpy()])
  198. for i in range( len(self.topk)):
  199. test_accuracy[self.topk[i]].update(acc[i].data.item(), target.size( 0))
  200. report = class_report.get_classification_report(true_labels, pred_labels, target_names=self.cfg.class_name)
  201. topk_acc = { "top{}". format(k): v.avg for k, v in test_accuracy.items()}
  202. lr = self.scheduler.get_last_lr()[ 0] # 获得当前学习率
  203. self.writer.add_scalar( "test-loss", test_losses.avg, epoch)
  204. self.writer.add_scalars( "test-accuracy", topk_acc, epoch)
  205. self.logger.info( "test epoch:{:0=3d},lr:{:3.4f},loss:{:3.4f},acc:{}". format(epoch, lr, test_losses.avg, topk_acc))
  206. self.logger.info( "{}". format(report))
  207. return topk_acc[ "top{}". format(self.topk[ 0])]
  208. def run( self):
  209. """开始运行"""
  210. self.max_acc = 0.0
  211. for epoch in range(self.cfg.num_epochs):
  212. train_acc = self.train(epoch) # 训练模型
  213. test_acc = self.test(epoch) # 测试模型
  214. self.scheduler.step() # 更新学习率
  215. lr = self.scheduler.get_last_lr()[ 0] # 获得当前学习率
  216. self.writer.add_scalar( "lr", lr, epoch)
  217. self.save_model(self.cfg.model_root, test_acc, epoch)
  218. self.logger.info( "epoch:{}, lr:{}, train acc:{:3.4f}, test acc:{:3.4f}".
  219. format(epoch, lr, train_acc, test_acc))
  220. def save_model( self, model_root, value, epoch):
  221. """保存模型"""
  222. # 保存最优的模型
  223. if value >= self.max_acc:
  224. self.max_acc = value
  225. model_file = os.path.join(model_root, "best_model_{:0=3d}_{:.4f}.pth". format(epoch, value))
  226. file_utils.remove_prefix_files(model_root, "best_model_*")
  227. torch.save(self.model.module.state_dict(), model_file)
  228. self.logger.info( "save best model file:{}". format(model_file))
  229. # 保存最新的模型
  230. name = "model_{:0=3d}_{:.4f}.pth". format(epoch, value)
  231. model_file = os.path.join(model_root, "latest_{}". format(name))
  232. file_utils.remove_prefix_files(model_root, "latest_*")
  233. torch.save(self.model.module.state_dict(), model_file)
  234. self.logger.info( "save latest model file:{}". format(model_file))
  235. self.logger.info( "-------------------------" * 4)
  236. def get_parser():
  237. # cfg_file = "configs/config_textfolder.yaml"
  238. cfg_file = "configs/config.yaml"
  239. parser = argparse.ArgumentParser(description= "Training Pipeline")
  240. parser.add_argument( "-c", "--config_file", help= "configs file", default=cfg_file, type= str)
  241. cfg = config_utils.parser_config(parser.parse_args(), cfg_updata= True)
  242. return cfg
  243. if __name__ == "__main__":
  244. cfg = get_parser()
  245. train = Trainer(cfg)
  246. train.run()

(5)可视化训练过程

训练过程可视化工具是使用Tensorboard,使用方法:

  
  1. # 基本方法
  2. tensorboard --logdir=path/to/log/
  3. # 例如(请修改自己的训练的模型路径)
  4. tensorboard --logdir=work_space/TextCNN_CELoss_20230106152138/log

可视化效果 

​​ ​​
​​ ​​
​​ ​​

(6)一些优化建议

训练完成后,目前,基于TextCNN的文本分类识别在THUCNews数据集上,训练集的Accuracy 99%左右,测试集的Accuracy在88.36%左右;如果想进一步提高准确率,可以尝试:

  1. 数据整合:部分分类之间本身模棱两可,例如体育和娱乐教育和科技本身类别就有很多相似之处,导致模型分类困难;THUCNews数据量虽然庞大,但不是十分干净,有很多脏数据;建议你,训练前,清洗或整合部分数据集,不然会影响模型的识别的准确率。
  2. 增加TextCNN参数量:比如将TextCNN的num_channels设置大一点;当然模型越复杂,越容易过拟合;
  3. 增加pretrained模型:项目构建TextCNN模型,随机初始化了一个可学习的二维矩阵:Embedding,该Embedding模型没有增加pretrained的,若能加入pretrained,其准确率会好很多。
  4. 文本数据增强:如同义词替换,文本随机插入,随机删除等处理,增强模型泛化能力
  5. 样本均衡:数据不均衡,部分类目数据太少; 建议进行样本均衡处理,减少长尾问题的影响
  6. 超参调优: 比如学习率调整策略,优化器(SGD,Adam等)
  7. 损失函数: 目前训练代码已经支持:交叉熵,LabelSmoothing,可以尝试FocalLoss等损失函数

七. 模型测试效果

classifier.py文件用于模型推理和测试脚本,填写好配置文件,模型文件以及测试文本路径即可运行测试了


  
  1. def get_parser():
  2. model_file = "work_space/TextCNN_CELoss_20221226114529/model/latest_model_159_0.8714.pth"
  3. config_file = os.path.join(os.path.dirname(os.path.dirname(model_file)), "config_textfolder.yaml")
  4. vocab_file = os.path.join(os.path.dirname(os.path.dirname(model_file)), "vocabulary.json")
  5. text_dir = "data/test-text"
  6. parser = argparse.ArgumentParser(description= "Inference Argument")
  7. parser.add_argument( "-c", "--config_file", help= "configs file", default=config_file, type= str)
  8. parser.add_argument( "-m", "--model_file", help= "model_file", default=model_file, type= str)
  9. parser.add_argument( "-v", "--vocab_file", help= "vocab_file", default=vocab_file, type= str)
  10. parser.add_argument( "--device", help= "cuda device id", default= "cuda:0", type= str)
  11. parser.add_argument( "--text_dir", help= "text", default=text_dir, type= str)
  12. return parser

在项目根目录终端运行命令(\表示换行符):


  
  1. #!/usr/bin/env bash
  2. # Usage:
  3. # python classifier.py -c "path/to/config.yaml" -m "path/to/model.pth" -v "path/to/vocabulary.json" --text_dir "path/to/text_dir"
  4. python classifier.py \
  5. -c "work_space/TextCNN_CELoss_20221226114529/config_textfolder.yaml" \
  6. -m "work_space/TextCNN_CELoss_20221226114529/model/latest_model_159_0.8714.pth" \
  7. -v "work_space/TextCNN_CELoss_20221226114529/vocabulary.json" \
  8. --text_dir "data/test-text"

运行测试结果: 


八.项目源码下载

整套项目源码下载:Pytorch TextCNN实现中文文本分类(附完整训练代码)

整套项目源码内容包含

  • 提供中文文本数据集:THUCNews
  • 项目支持训练词嵌入模型训练:word2vec.py
  • 项目提供Pytorch版本的中文文本分类模型训练:train.py,支持TextCNN, LSTM, BiLSTM等模型
  • 提供中文文本分类预测:classifier.py
  • 简单配置,一键开启训练自己的中文文本分类模型


转载:https://blog.csdn.net/guyuealian/article/details/127846717
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场