小言_互联网的博客

情感分析系列(四)——使用BERT进行情感分析

356人阅读  评论(0)

一、系列文章

  1. 情感分析系列(一)——IMDb数据集及其预处理
  2. 情感分析系列(二)——使用BiLSTM进行情感分析
  3. 情感分析系列(三)——使用TextCNN进行情感分析

二、数据预处理

因为BERT接受的输入和BiLSTM以及TextCNN不太一样,所以我们不能沿用之前的数据预处理方式。

事实上,我们只需对原先的序列在其开头加上 <cls> 词元,在其结尾加上 <sep> 词元,然后进行填充(注意是先添加特殊词元再进行填充)。于是 build_dataset() 重写为

def build_dataset(reviews, labels, vocab, max_len=512, bert_preprocess=False):
    if bert_preprocess:
        text_transform = T.Sequential(
            T.VocabTransform(vocab=vocab),
            T.Truncate(max_seq_len=max_len - 2),  # 之所以减2是因为接下来要添加两个特殊词元
            T.AddToken(token=vocab['<cls>'], begin=True),
            T.AddToken(token=vocab['<sep>'], begin=False),
            T.ToTensor(padding_value=vocab['<pad>']),
            T.PadTransform(max_length=max_len, pad_value=vocab['<pad>']),
        )
    else:
        text_transform = T.Sequential(
            T.VocabTransform(vocab=vocab),
            T.Truncate(max_seq_len=max_len),
            T.ToTensor(padding_value=vocab['<pad>']),
            T.PadTransform(max_length=max_len, pad_value=vocab['<pad>']),
        )
    dataset = TensorDataset(text_transform(reviews), torch.tensor(labels))
    return dataset

相应的 load_imdb() 函数也需要做一些改动

def load_imdb(bert_preprocess=False):
    reviews_train, labels_train = read_imdb(is_train=True)
    reviews_test, labels_test = read_imdb(is_train=False)
    vocab = build_vocab_from_iterator(reviews_train, min_freq=3, specials=['<pad>', '<unk>', '<cls>', '<sep>'])
    vocab.set_default_index(vocab['<unk>'])
    train_data = build_dataset(reviews_train, labels_train, vocab, bert_preprocess=bert_preprocess)
    test_data = build_dataset(reviews_test, labels_test, vocab, bert_preprocess=bert_preprocess)
    return train_data, test_data, vocab

三、搭建BERT

3.1 从零开始训练BERT

我们使用Hugging Face中的 BertForSequenceClassification 来进行情感分类

class BERT(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.vocab = vocab
        self.config = BertConfig()
        self.config.vocab_size = len(self.vocab)
        self.bert = BertForSequenceClassification(config=self.config)

    def forward(self, input_ids):
        attention_mask = (input_ids != self.vocab['<pad>']).long().float()
        logits = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
        return logits

上述配置的BERT共有1.1亿个参数,这里采用数据并行(DataParallel)的方式进行训练(4张A40-48GB):

set_seed(42)

BATCH_SIZE = 256
LEARNING_RATE = 1e-4
NUM_EPOCHS = 40

train_data, test_data, vocab = load_imdb(bert_preprocess=True)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

devices = [f'cuda:{
     i}' for i in range(torch.cuda.device_count())]
model = nn.DataParallel(BERT(vocab), device_ids=devices).to(devices[0])
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-9, weight_decay=0.01)
scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=1000, num_training_steps=10000)

for epoch in range(1, NUM_EPOCHS + 1):
    print(f'Epoch {
     epoch}\n' + '-' * 32)
    avg_train_loss = 0
    for batch_idx, (X, y) in enumerate(train_loader):
        X, y = X.to(devices[0]), y.to(devices[0])
        pred = model(X)
        loss = criterion(pred, y)
        avg_train_loss += loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if (batch_idx + 1) % 10 == 0:
            print(f"[{
     (batch_idx + 1) * BATCH_SIZE:>5}/{
     len(train_loader.dataset):>5}] train loss: {
     loss:.4f}")

    print(f"Avg train loss: {
     avg_train_loss/(batch_idx + 1):.4f}\n")

acc = 0
for X, y in test_loader:
    with torch.no_grad():
        X, y = X.to(devices[0]), y.to(devices[0])
        pred = model(X)
        acc += (pred.argmax(1) == y).sum().item()

print(f"Accuracy: {
     acc / len(test_loader.dataset):.4f}")

最终结果:

Accuracy: 0.8532

可以看到虽然我们使用了更大的模型,但最终效果并没有比使用了预训练词向量的BiLSTM和TextCNN好,接下来我们使用预训练的BERT并在IMDb上做微调来观察效果。

3.2 使用预训练的BERT+微调

在使用预训练的BERT进行微调时,我们的词表也需要保持和预训练时的词表一致,因此需要对 load_imdb() 函数重写:

def load_imdb():
    with open('./vocab.txt', 'r') as f:
        freq = []
        for token in f.readlines():
            freq.append((token.strip(), 1))
    v = vocab(OrderedDict(freq))
    v.set_default_index(v['[UNK]'])

    text_transform = T.Sequential(
        T.VocabTransform(vocab=v),
        T.Truncate(max_seq_len=510),
        T.AddToken(token=v['[CLS]'], begin=True),
        T.AddToken(token=v['[SEP]'], begin=False),
        T.ToTensor(padding_value=v['[PAD]']),
        T.PadTransform(max_length=512, pad_value=v['[PAD]']),
    )

    reviews_train, labels_train = read_imdb(is_train=True)
    reviews_test, labels_test = read_imdb(is_train=False)

    train_data = TensorDataset(text_transform(reviews_train), torch.tensor(labels_train))
    test_data = TensorDataset(text_transform(reviews_test), torch.tensor(labels_test))

    return train_data, test_data, v

定义模型:

class BERT(nn.Module):
    def __init__(self, vocab):
        super().__init__()
        self.vocab = vocab
        self.bert = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

    def forward(self, input_ids):
        attention_mask = (input_ids != self.vocab['[PAD]']).long().float()
        logits = self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
        return logits

微调时,我们仅需要对以下几个地方做出改动:

LEARNING_RATE = 5e-5
NUM_EPOCHS = 3
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=NUM_EPOCHS * len(train_loader))

最终结果:

Accuracy: 0.9329

四、一些心得

  • 从零开始训练的时候一定要做warmup,如果使用固定的学习率(1e-4)会导致loss一直下不去(会在0.6-0.7之间反复横跳)。关于超参数的配置可参考这篇文档,其中学习率调度器中的 num_warmup_stepsnum_training_steps 需要根据自己数据集的情况做出相应的调整;
  • 使用预训练的模型时词表也一定要使用预训练时所用的词表,否则会报错 RuntimeError: CUDA error: device-side assert triggered

🧑‍💻 创作不易,如需源码可前往 SA-IMDb 进行查看,下载时还请您随手给一个follow和star,谢谢!


转载:https://blog.csdn.net/raelum/article/details/127617537
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场