if operation == "+": result = num1 + num2 elif operation == "-": result = num1 - num2 elif operation == "*": result = num1 * num2 else: result = round(num1 / num2, 2)
problem_str = f"{num1}{operation}{num2}=\t{result}\t" problem_list = [alphabet.index(x) if x != '\t'else alphabet.index('<sep>') for x in problem_str] return problem_list
# This method is passed as the ``collate_fn`` argument of the DataLoader.
# The original author's guess: without it torch falls back to the default
# collate_fn, which tries to turn the batch list directly into a tensor; the
# samples in a batch have different lengths, so that conversion would fail.
def padding_batch(self, batch):
    """Pad every sample of ``batch`` to the batch maximum and stack into tensors.

    Args:
        batch: list of sample dicts (per the original comment, what the
            dataset's ``__getitem__`` returns). Each dict must provide
            ``"decoder_input"`` / ``"decoder_output"`` (lists of token ids)
            and ``"decoder_input_len"`` / ``"decoder_output_len"`` (their
            lengths).

    Returns:
        ``(decoder_inputs, decoder_outputs)``: two ``torch.long`` tensors with
        one row per sample, each row right-padded with the ``<pad>`` id up to
        the longest sample of this batch.
    """
    # Look the padding id up once instead of once per sample.
    pad_id = alphabet.index('<pad>')

    # Longest input / output sequence in this batch (0 for an empty batch).
    decoder_input_maxlen = max(
        (d["decoder_input_len"] for d in batch), default=0
    )
    decoder_output_maxlen = max(
        (d["decoder_output_len"] for d in batch), default=0
    )

    # Build padded copies rather than extending the sample lists in place:
    # the stored "*_len" values would no longer match a mutated list if the
    # same sample object were ever collated a second time.
    padded_inputs = [
        list(d["decoder_input"])
        + [pad_id] * (decoder_input_maxlen - d["decoder_input_len"])
        for d in batch
    ]
    padded_outputs = [
        list(d["decoder_output"])
        + [pad_id] * (decoder_output_maxlen - d["decoder_output_len"])
        for d in batch
    ]

    decoder_inputs = torch.tensor(padded_inputs, dtype=torch.long)
    decoder_outputs = torch.tensor(padded_outputs, dtype=torch.long)

    return decoder_inputs, decoder_outputs
# Model hyper-parameters. Each assignment sits on its own line so that a
# trailing comment cannot swallow the assignments that follow it.
max_pos = 30    # maximum formula length: 30 positions
d_model = 32    # embedding size
d_ff = 2048     # feed-forward dimension
d_k = d_v = 64  # dimension of K (= Q) and of V
n_layers = 6    # number of encoder/decoder layers
n_heads = 8     # number of heads in multi-head attention
# Load the trained weights and spot-check the model on 20 random arithmetic
# problems, printing the model's completion next to the true result.

# Follow whatever device the model already lives on instead of hard-coding
# CUDA, so the script also runs on CPU-only hosts.
model_device = next(model.parameters()).device
model.load_state_dict(
    torch.load(r'weights/01/MyGPT2-0622.pt', map_location=model_device)
)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()

print('problem \t pre \t\t\t true')
for _ in range(20):
    num1 = random.randint(0, 100)
    num2 = random.randint(1, 100)  # lower bound 1 keeps division well-defined
    operation = random.choice(["+", "-", "*", "/"])

    if operation == "+":
        result = num1 + num2
    elif operation == "-":
        result = num1 - num2
    elif operation == "*":
        result = num1 * num2
    else:
        result = round(num1 / num2, 2)

    # Prompt is "<num1><op><num2>=" followed by a tab; the tab character is
    # encoded as the '<sep>' token, every other character as itself.
    problem_str = f"{num1}{operation}{num2}=\t"
    token_ids = [
        alphabet.index(x) if x != '\t' else alphabet.index('<sep>')
        for x in problem_str
    ]
    # Build the integer tensor directly (no float round-trip) as a batch of 1.
    test_x = torch.tensor(
        token_ids, dtype=torch.long, device=model_device
    ).view(1, -1)

    test_y = greedy_decoder(model, test_x)
    out = [alphabet[int(x)] for x in test_y[0]]
    print(''.join(out).replace('<sep>', '\t\t'), '\t', result)