与之前的Decoder-Encoder篇类似,但是这次我们上一点难度。把加减乘除和小数都算进去。理论存在,但是得拿出来溜溜

构建数据集

首先,因为共有0-9,+,-,*,/,.,\t这16个字符,其中\t前面是问题,后面是网络的预测,跟上一个篇章保持一致,然后因为每个算式的长度不一样,我们需要用pad填充,所以,先定义一个长度为17的字母表

1
alphabet = ['<pad>','<sep>','@','#','+', '-', '*', '/', '.', '='] + [str(x) for x in range(10)]

其中<pad><sep>分别是填充符和制表符。

接下来就是随机生成加减法字符串,并根据alphabet转换为id。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# 生成一组加减法的字符串,设定最大长度为20
def generate_random_math_problem():
num1 = random.randint(0, 100)
num2 = random.randint(1, 100)
operation = random.choice(["+", "-", "*", "/"])

if operation == "+":
result = num1 + num2
elif operation == "-":
result = num1 - num2
elif operation == "*":
result = num1 * num2
else:
result = round(num1 / num2, 2)

problem_str = f"{num1}{operation}{num2}=\t{result}\t"
problem_list = [alphabet.index(x) if x != '\t' else alphabet.index('<sep>') for x in problem_str]
return problem_list

然后根据Pythoch的Dataset类,构建数据集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# 定义数据集
class MyDataSet(Data.Dataset):
def __init__(self, len=10000):
self.datas = []
for i in range(len):
self.datas.append(generate_random_math_problem())

def __getitem__(self, item): #从上面的列表中按item索引取出一个数据(一段对话),构造gpt的输入和输出,打包成字典返回
data = self.datas[item] #从上面的列表中按item索引取出一个数据(一段对话)
decoder_input = data[:-1] #输入和输出错开一个位置
decoder_output = data[1:]

decoder_input_len = len(decoder_input) #这个句子的长度,其实输入和输出长度是一样的
decoder_output_len = len(decoder_output)

return {"decoder_input": decoder_input, "decoder_input_len": decoder_input_len,
"decoder_output": decoder_output, "decoder_output_len": decoder_output_len}

def __len__(self):
return len(self.datas)

# 这个方法会作为DataLoader的collate_fn的参数。猜测是因为如果不写这个,torch会调用默认的collate_fn,也就是把这个batch列表的数据转为torch矩阵,但是这里batch内每个数据长度都不一样,无法直接转为矩阵,就会报错。
def padding_batch(self, batch): #接收getitem方法返回的batch,
decoder_input_lens = [d["decoder_input_len"] for d in batch] #取出batch里面每一个输入数据(每一段话)的长度
decoder_output_lens = [d["decoder_output_len"] for d in batch] #取出batch里面每一个输出数据(每一段话)的长度

decoder_input_maxlen = max(decoder_input_lens) #batch里面一段话的最大长度
decoder_output_maxlen = max(decoder_output_lens)

for d in batch: # 对当前batch的每一个decoder_input和decoder_output数据填充"<pad>",填充到和batch里面的有的最大长度为止
d["decoder_input"].extend([alphabet.index('<pad>')] * (decoder_input_maxlen - d["decoder_input_len"]))
d["decoder_output"].extend([alphabet.index('<pad>')] * (decoder_output_maxlen - d["decoder_output_len"]))
decoder_inputs = torch.tensor([d["decoder_input"] for d in batch], dtype=torch.long) #转type
decoder_outputs = torch.tensor([d["decoder_output"] for d in batch], dtype=torch.long)

return decoder_inputs, decoder_outputs #形状[b,decoder_input_maxlen], [b,decoder_output_maxlen] type为torch.long

通过这个方法,DataLoader就可以加载数据集了,输入和输出刚好错开一个字符,至于为什么要错开一个字符呢?大家可以自己思考一下。

训练

这里的训练就按照常规的方式其实就可以了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
max_pos = 30  #公式长度最长30
d_model = 32 # Embedding Size
d_ff = 2048 # FeedForward dimension
d_k = d_v = 64 # dimension of K(=Q), V
n_layers = 6 # number of Encoder of Decoder Layer
n_heads = 8 # number of heads in Multi-Head Attention

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
epochs = 30
dataset = MyDataSet(70000)
data_loader = Data.DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.padding_batch) #这里将collate_fn设置为上面定义的padding_batch方法
model = GPT().to(device)
for epoch in range(epochs):
for i, (dec_inputs, dec_outputs) in enumerate(tqdm(data_loader)): #dec_inputs: [b, tgt_len] , dec_outputs: [b, tgt_len]
optimizer.zero_grad()
dec_inputs, dec_outputs =dec_inputs.to(device), dec_outputs.to(device)
# outputs: [batch_size * tgt_len, tgt_vocab_size] tgt_len<=30

# with torch.cuda.amp.autocast(): # 半精度训练
outputs, dec_self_attns = model(dec_inputs)
loss = criterion(outputs, dec_outputs.view(-1)) #outputs :(b * tgt_len, vocab_size),dec_outputs.view(-1) :(b * tgt_len) tgt_len<=300

print_loss_total += loss.item()
epoch_loss += loss.item()
loss.backward() #梯度反向传播

# 梯度裁剪,防止梯度爆炸。如果loss超过clip,将梯度值缩小为原来的(loss/clip)分之一
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

optimizer.step() #更新模型权重
torch.save(model.state_dict(), r'./MyGPT2-0622.pt') #保存模型权重

简单跑一下就可以了,这里具体过程就不放了。

测试

测试代码

测试的时候,我们需要一个生成器,这里我们使用贪心算法,每次都选择概率最大的那个字符,然后再输入到模型中,直到遇到结束符号或者达到最大长度为止。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def greedy_decoder(model, dec_input): #dec_input :[1,tgt_len]   此时tgt_len就是句子长度
terminal = False
start_dec_len = len(dec_input[0])
# 一直预测下一个单词,直到预测到"<sep>"结束,如果一直不到"<sep>",则根据长度退出循环,并在最后加上”<sep>“字符
while not terminal:
if len(dec_input[0]) - start_dec_len > 100:
next_symbol = alphabet.index('<sep>')
dec_input = torch.cat(
[dec_input.detach(), torch.tensor([[next_symbol]], dtype=dec_input.dtype, device=device)], -1)
break

# forward
dec_outputs, _ = model.decoder(dec_input)
projected = model.projection(dec_outputs) #[1, tgt_len, vocab_size]

prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1] #[1]是索引,我们只要索引就行了。[0]是具体概率数值,不需要 形状[tgt_len]
next_word = prob.data[-1] #最后一个字对应的id
next_symbol = next_word
if next_symbol == alphabet.index('<sep>'): #如果预测到"<sep>"则结束
terminal = True

dec_input = torch.cat(
[dec_input.detach(), torch.tensor([[next_symbol]], dtype=dec_input.dtype, device=device)], -1)

return dec_input # [1,tgt_len+n] 因为多了n个预测的字

然后我们就可以测试一下了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
model.load_state_dict(torch.load(r'weights/01/MyGPT2-0622.pt'))
print('problem \t pre \t\t\t true')
for i in range(20):
num1 = random.randint(0, 100)
num2 = random.randint(1, 100)
operation = random.choice(["+", "-", "*", "/"])

if operation == "+":
result = num1 + num2
elif operation == "-":
result = num1 - num2
elif operation == "*":
result = num1 * num2
else:
result = round(num1 / num2, 2)

problem_str = f"{num1}{operation}{num2}=\t"
test_x = torch.Tensor([alphabet.index(x) if x != '\t' else alphabet.index('<sep>') for x in problem_str])
test_x = test_x.long().view(1, -1).cuda()
test_y = greedy_decoder(model,test_x)
out = [alphabet[int(x)] for x in test_y[0]]
print(''.join(out).replace('<sep>', '\t\t'), '\t', result)

测试结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
problem 	pre 		 true
31-13= 18 18
13+90= 103 103
95*79= 7505 7505
13-28= -15 -15
14/35= 0.4 0.4
71/29= 2.65 2.45
80-57= 23 23
73/92= 0.79 0.79
25-33= -8 -8
32-48= -16 -16
52*8= 416 416
99+54= 153 153
68/27= 2.56 2.52
60/59= 1.02 1.02
74*56= 4144 4144
53/72= 0.73 0.74
4-70= -66 -66
69/22= 3.14 3.14
48-82= -34 -34
50*23= 1150 1150

可以看到,仅仅跑了30个epoch,但是准确率已经非常高了,只有除法可能还存在一些误差,可以通过fine-tune来提高拟合度,这个结果还是非常Amazing的,也可以从侧面看出self-attention的强大之处。