Masked LM 完成对联模型

上次完成了“AI版对穿肠”,使用了UniLM模型,权重采用albert作为初始化权重。相当于把对联模型当成一般的seq2seq的结构,实际上对联的模型的输入和输出是等长的,也可以看成是序列标注问题,只不过标签类别是整个词库。也就是说可以在Bert的输出层接一个与词库维度相同的密集层就可以了,这不正是 Masked Language Model吗。

Masked Language Model是Bert的一个预训练任务,通过随机(15%的概率)mask掉部分词作为输入,然后预测对应位置的词,本质上就是输入与输出相等的seq2seq的模型,用于对联模型肯定行得通。这么做的优势在于输出不需要自回归解码了,预测的时候可以直接取概率最大的词,在模型预测和推理过程中会减少很多计算量。

等长的seq2seq模型除了写诗和对联以外,还有一个任务就是文本纠错,输入错别子的文本,输出正确的文本,也是完全没有问题的,当然这只能适用于错别字,对于多字或者少字,以及语法上的错误是无能为力的。

模型构建

本文采用albert base版,其余版本的模型暂不做尝试。数据集仍然采用冯重朴梨味斋散叶的博客训练的20个epoch,大致挑选了效果还不错的输出:

 - 今日天气多云多美丽
-- 今朝人光有水更和谐

 - 珍藏惟有诗三卷
-- 雅读不无酒一杯

 - 狂笔一挥天地动
-- 高风三卷古今流

 - 推窗问月诗何在
-- 对案观风酒不来

 - 彩屏如画,望秀美崤函,花团锦簇
-- 玉阁似诗,看和谐华苑,水舞莺流

接下来测试一下模型的推理时间,模型都采用albert-base,使用ipython里面的%time和%timeit进行测试。

模型名称 time timeit
MLM CPU times: user 9.42 ms, sys: 29 ms, total: 38.5 ms Wall time: 37.5 ms 7.97 ms ± 403 µs per loop 
UniML CPU times: user 117 ms, sys: 15.1 ms, total: 132 ms Wall time: 107 ms 93.1 ms ± 622 µs per loop

测试的结果在我们的预料之中,Masked LM 没有自回归解码,推理的时间少了11倍多。

总结

Masked LM虽然是Bert的一项预训练任务,其实可以用来完成一些seq2seq的任务的,比如对联和写诗,除此之外还可以文本的纠错(已有论文)。其实Masked LM还有更多的用途,比如可以用于情感分析,在原始文本末尾加入“你认为这样_好”,填空处选择“很或者不”两个字,这样可以完成情感分类的问题。

文章最后比较了UniLM与MLM的推理时间,发现在不用自回归解码的情况下完成对联任务,MLM快了11倍多。不过MLM只能对于输入和输出等长的序列问题有效,而UniLM则是通用的seq2seq的模型。

代码见个人github

标签:

发表评论

邮箱地址不会被公开。