Introduction
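This article shows how to use Keras to build a sequence-to-sequence model that learns to add two integers written as text.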
Data Preparation
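To train the model, we first prepare a set of addition samples (TRAINING_SIZE of them in the code). Each sample pairs a query string with its answer, for example: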
query: 52+758    ans: 810
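Here both operands are restricted to integers with at most DIGITS digits. Every query is padded with spaces to a fixed length of MAXLEN, and every answer to DIGITS + 1 characters, so that all samples have the same shape: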
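import numpy as np

# questions/expected collect the padded query and answer strings;
# TRAINING_SIZE, DIGITS and MAXLEN are defined earlier in the full script.
while len(questions) < TRAINING_SIZE:
    # Random integer with 1 to DIGITS digits
    f = lambda: int("".join(np.random.choice(list("1234567890")) for _ in range(np.random.randint(1, DIGITS + 1))))
    a, b = f(), f()
    ...
    # Pad the query with spaces up to MAXLEN characters
    q = "{}+{}".format(a, b)
    query = q + " " * (MAXLEN - len(q))
    # Pad the answer up to DIGITS + 1 characters (the sum can gain one digit)
    ans = str(a + b)
    ans += " " * (DIGITS + 1 - len(ans))
    questions.append(query)
    expected.append(ans)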
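The loop above relies on globals defined elsewhere in the script. Below is a self-contained sketch of the same idea. The function name `make_samples`, the seeded `numpy.random.default_rng` generator, and MAXLEN = 2 * DIGITS + 1 are all assumptions. It does not reproduce whatever the `...` in the original does. The padding widths follow from simple arithmetic. The widest query is two DIGITS-digit numbers joined by "+", which is 2 * DIGITS + 1 characters. The sum of two such numbers has at most DIGITS + 1 digits (999 + 999 = 1998).

```python
import numpy as np


def make_samples(n, digits, seed=0):
    """Generate n padded (query, answer) pairs for integers with 1..digits digits."""
    rng = np.random.default_rng(seed)
    maxlen = 2 * digits + 1  # assumed MAXLEN: widest query, e.g. "999+999"
    questions, expected = [], []

    def rand_int():
        # Draw 1..digits random characters from "0"-"9"; int() drops leading zeros
        length = int(rng.integers(1, digits + 1))
        return int("".join(rng.choice(list("1234567890"), size=length)))

    while len(questions) < n:
        a, b = rand_int(), rand_int()
        questions.append(f"{a}+{b}".ljust(maxlen))     # pad with spaces to maxlen
        expected.append(str(a + b).ljust(digits + 1))  # sum has at most digits + 1 digits
    return questions, expected


questions, expected = make_samples(5, digits=3)
for q, ans in zip(questions, expected):
    print(repr(q), repr(ans))
```

`str.ljust` gives the same space padding as the manual `" " * (...)` arithmetic in the original loop.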
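Once the samples are created, they must be vectorized. Each character is one-hot encoded with ctable. As a result, x has shape (number of samples, MAXLEN, len(chars)) and y has shape (number of samples, DIGITS + 1, len(chars)):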
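import numpy as np

# One-hot encode every query and answer character by character with ctable
# (ctable, chars, MAXLEN and DIGITS are defined earlier in the full script)
x = np.zeros((len(questions), MAXLEN, len(chars)), dtype=bool)
y = np.zeros((len(questions), DIGITS + 1, len(chars)), dtype=bool)
for i, sentence in enumerate(questions):
    x[i] = ctable.encode(sentence, MAXLEN)
for i, sentence in enumerate(expected):
    y[i] = ctable.encode(sentence, DIGITS + 1)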
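`ctable` is not defined in this excerpt. Below is a minimal sketch of a character table that supports the `ctable.encode(sentence, num_rows)` calls above. The class name `CharacterTable`, its `decode` method, and the alphabet `chars = "0123456789+ "` (digits, plus sign, padding space) are assumptions inferred from how the data looks.

```python
import numpy as np


class CharacterTable:
    """Map characters to one-hot rows and back (sketch of what `ctable` is assumed to be)."""

    def __init__(self, chars):
        self.chars = sorted(set(chars))
        self.char_indices = {c: i for i, c in enumerate(self.chars)}
        self.indices_char = {i: c for i, c in enumerate(self.chars)}

    def encode(self, s, num_rows):
        """One-hot encode string s into a (num_rows, len(chars)) boolean matrix."""
        x = np.zeros((num_rows, len(self.chars)), dtype=bool)
        for i, c in enumerate(s):
            x[i, self.char_indices[c]] = True
        return x

    def decode(self, x, calc_argmax=True):
        """Turn one-hot rows (or per-row probabilities) back into a string."""
        if calc_argmax:
            x = x.argmax(axis=-1)
        return "".join(self.indices_char[int(i)] for i in x)


chars = "0123456789+ "  # assumed alphabet: digits, plus sign, padding space
ctable = CharacterTable(chars)
encoded = ctable.encode("52+758 ", 7)
print(encoded.shape)           # (7, 12)
print(ctable.decode(encoded))  # prints "52+758 " (with a trailing space)
```

With `calc_argmax=True`, `decode` also accepts the model's softmax output for a single sample. It then picks the most likely character at each position, which is one plausible way to turn predictions back into text.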
Building the Model
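The model is simple and uses only LSTM, RepeatVector and Dense layers:

- The first LSTM encodes the query into a single 128-dimensional vector.
- RepeatVector copies that vector DIGITS + 1 times.
- The second LSTM decodes it into one output per answer position.
- The Dense softmax layer predicts a character at each position.

When compiling, the loss is set to categorical_crossentropy, the optimizer to adam, and the metric to accuracy: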
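from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential()
# Encoder: read the one-hot query and compress it into a 128-dim vector
model.add(layers.LSTM(128, input_shape=(MAXLEN, len(chars))))
# Repeat that vector once per output character (the answer has DIGITS + 1 characters)
model.add(layers.RepeatVector(DIGITS + 1))
# Decoder: produce one 128-dim output per answer position
model.add(layers.LSTM(128, return_sequences=True))
# Predict a probability distribution over chars at each position
model.add(layers.Dense(len(chars), activation="softmax"))
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])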
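To see what RepeatVector and the final Dense layer do to tensor shapes, here is a NumPy-only sketch. It makes several simplifying assumptions:

- The first LSTM's output is replaced by random vectors.
- The second LSTM is skipped entirely.
- DIGITS = 3 and a 12-character alphabet are assumed.
- The weights and targets are arbitrary values for illustration.

It also computes categorical cross-entropy the way the compiled loss is roughly defined: the negative log-probability of the true character, averaged over positions and samples.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, hidden, out_len, n_chars = 2, 128, 4, 12  # out_len = DIGITS + 1, assuming DIGITS = 3

# Stand-in for the first LSTM's output: one 128-dim summary vector per query
encoded = rng.normal(size=(batch, hidden))

# RepeatVector(out_len): copy that vector once per output position
repeated = np.repeat(encoded[:, np.newaxis, :], out_len, axis=1)  # (batch, out_len, hidden)

# (The second LSTM would turn `repeated` into a different vector per position; skipped here.)
# Dense(n_chars, softmax) on a 3-D input acts on the last axis at every position
W = rng.normal(scale=0.1, size=(hidden, n_chars))
b = np.zeros(n_chars)
logits = repeated @ W + b                                     # (batch, out_len, n_chars)
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))  # numerically stable softmax
probs /= probs.sum(axis=-1, keepdims=True)

# Categorical cross-entropy against arbitrary one-hot targets (true char = index 0 everywhere)
targets = np.zeros_like(probs)
targets[..., 0] = 1.0
loss = -(targets * np.log(probs)).sum(axis=-1).mean()

print(repeated.shape, probs.shape)  # (2, 4, 128) (2, 4, 12)
```

Because the second LSTM is skipped, every position in this sketch receives the same distribution. In the real model, the decoder LSTM's recurrent state is what lets each answer position predict a different character.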
Model Training
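Each iteration trains for one epoch and then evaluates on validation data. The log below shows the first two and the last two iterations; the rest are omitted. Validation accuracy climbs from about 0.41 after iteration 1 to about 0.99 by iteration 28, then dips slightly to about 0.98 at iteration 29.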
Iter 1
1407/1407 [==============================] - 11s 6ms/step - loss: 1.7796 - accuracy: 0.3499 - val_loss: 1.5788 - val_accuracy: 0.4065
Iter 2
1407/1407 [==============================] - 9s 6ms/step - loss: 1.3928 - accuracy: 0.4762 - val_loss: 1.2489 - val_accuracy: 0.5346
...
Iter 28
1407/1407 [==============================] - 9s 6ms/step - loss: 0.0205 - accuracy: 0.9944 - val_loss: 0.0257 - val_accuracy: 0.9917
Iter 29
1407/1407 [==============================] - 9s 6ms/step - loss: 0.0256 - accuracy: 0.9926 - val_loss: 0.0747 - val_accuracy: 0.9827
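The log reports val_loss and val_accuracy, so some samples are held out for validation. However, the excerpt does not show how the split is made. Below is a sketch of one common approach: shuffle the paired arrays and hold out a fraction. The helper name `train_val_split` and the 10% ratio are hypothetical.

```python
import numpy as np


def train_val_split(x, y, val_fraction=0.1, seed=0):
    """Shuffle paired arrays together and hold out the last val_fraction for validation."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(x))
    x, y = x[indices], y[indices]  # same permutation keeps each x aligned with its y
    split_at = len(x) - int(len(x) * val_fraction)
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


x_demo = np.arange(100).reshape(100, 1)
y_demo = x_demo * 2
(x_train, y_train), (x_val, y_val) = train_val_split(x_demo, y_demo)
print(len(x_train), len(x_val))  # 90 10
```

With Keras, one iteration of the loop could then call `model.fit(x_train, y_train, batch_size=..., epochs=1, validation_data=(x_val, y_val))`. The batch size is left unspecified here because the excerpt does not state it.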
Results
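Below are the predictions for 10 samples. Q is the query, A is the expected answer, and the last value is the model's prediction. Nine of the 10 are correct. The only error is 96+11, which is predicted as 907 instead of 107.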
Q 537+65    A 602    ☑ 602
Q 0+998     A 998    ☑ 998
Q 50+691    A 741    ☑ 741
Q 104+773   A 877    ☑ 877
Q 21+84     A 105    ☑ 105
Q 318+882   A 1200   ☑ 1200
Q 850+90    A 940    ☑ 940
Q 96+11     A 107    ☒ 907
Q 1+144     A 145    ☑ 145
Q 809+4     A 813    ☑ 813
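As far as I understand, the "accuracy" in the training log is Keras's categorical accuracy. For this 3-D output, that metric is averaged over every output character position. This is more lenient than requiring the whole answer to be right. The sketch below recomputes both measures on the 10 samples above. It assumes answers are padded to 4 characters (DIGITS + 1 with DIGITS = 3, which the results suggest but the excerpt does not state).

```python
# (query, expected answer, prediction) copied from the 10 results above
samples = [
    ("537+65", "602", "602"),
    ("0+998", "998", "998"),
    ("50+691", "741", "741"),
    ("104+773", "877", "877"),
    ("21+84", "105", "105"),
    ("318+882", "1200", "1200"),
    ("850+90", "940", "940"),
    ("96+11", "107", "907"),
    ("1+144", "145", "145"),
    ("809+4", "813", "813"),
]
WIDTH = 4  # assumed answer width: DIGITS + 1 with DIGITS = 3

# Exact match: the whole answer string must be right
exact = sum(ans == pred for _, ans, pred in samples)
# Per-character match: each padded position is scored on its own
char_hits = sum(
    a == p
    for _, ans, pred in samples
    for a, p in zip(ans.ljust(WIDTH), pred.ljust(WIDTH))
)
print(exact / len(samples))                # 0.9
print(char_hits / (WIDTH * len(samples)))  # 0.975
```

The single wrong answer, 907 for 107, still gets 3 of its 4 characters right. This is why a per-character score near 0.99 can coexist with a noticeably lower rate of fully correct sums.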
References
https://github.com/wangdayaya/DP_2023/blob/main/NLP%20%E6%96%87%E7%AB%A0/Sequence%20to%20sequence%20learning%20for%20performing%20number%20addition.py