Datawhale AI夏令营- 讯飞机器翻译挑战赛baseline解析
1o0.0o1 2024-07-14 11:31:01 阅读 90
讯飞机器翻译挑战赛题
赛题数据分析NLP前置知识GRUSeq2SeqEncoder编码器Decoder解码器
数据处理思路数据清洗构建数据集类型TranslateDataset
模型搭建和训练基于seq2seq 的模型基于transformer的模型训练代码
基于BLUE4的评估指标BLEU4 的计算步骤:举个例子:1. 分词:2. 计算n-gram匹配:3. 计算精确度:4. 计算加权几何平均:5. 计算惩罚因子BP:6. 计算最终的BLEU4分数:
实现代码
推理
赛题数据分析
在官方提供的数据集中,我们可以了解到:训练集又14w条数据,同时官方也提供了一个测试集给我们用来模型评估。同时还提供了一个术语词典,作为一些特殊词语的翻译对照表。
了解到这些信息后,我们可以下载这些数据看看这些数据怎样的。
我们打开训练集train.txt后发现,英文和中文通过制表符\t来分割,我们后续就可以每一行读取,然后通过\t划分来拿到中文数据和英文数据。这些后续会在数据处理部分说明
NLP前置知识
本次baseline用的是seq2seq的模型, 然后Encode和decode部分使用的是GRU模型。下面我将一一讲解这两种模型
GRU
在讲GRU之前,需要先补充一下RNN是什么。如下图:
RNN模型在每个时间步接收一个字的输入,生成隐藏状态和输出,再将隐藏状态与下一个字输入到模型中,重复此过程。
GRU(门控循环单元)是RNN的变体,能够有效捕捉长序列语义关联,缓解梯度消失或爆炸现象,其核心结构由更新门和重置门两部分组成。
如果看不懂的话,可以直接理解为GRU的输入,输出都和RNN是一致的,但是比RNN更加厉害。
Seq2Seq
结构图如下:
Seq2Seq由两个结构组成,分别是Encoder和Decoder块
Seq2Seq模型由两个主要部分组成:Encoder(编码器)和Decoder(解码器),两者均为GRU网络。以下是该模型的详细介绍:
Encoder编码器
输入序列:模型接收输入序列
x
=
[
x
1
,
x
2
,
x
3
,
x
4
]
x = [x_1, x_2, x_3, x_4]
x=[x1,x2,x3,x4]。GRU网络:输入序列逐步传递给GRU单元,每个输入
x
i
x_i
xi 生成相应的隐藏状态
h
i
h_i
hi。
第一个GRU单元接收
x
1
x_1
x1,生成隐藏状态
h
1
h_1
h1。第二个GRU单元接收
x
2
x_2
x2和
h
1
h_1
h1,生成隐藏状态
h
2
h_2
h2。如此反复,直到最后一个输入
x
4
x_4
x4 生成隐藏状态
h
4
h_4
h4。 上下文向量:最后一个隐藏状态
h
4
h_4
h4 作为上下文向量
c
c
c,用于解码阶段。
Decoder解码器
初始状态:解码器的初始隐藏状态由编码器生成的上下文向量
c
c
c 初始化。GRU网络:解码器逐步生成输出序列
y
=
[
y
1
,
y
2
,
y
3
,
y
4
]
y = [y_1, y_2, y_3, y_4]
y=[y1,y2,y3,y4]。
解码器的第一个GRU单元接收上下文向量
c
c
c 和初始隐藏状态,生成第一个隐藏状态
h
1
′
h'_1
h1′和输出
y
1
y_1
y1。第二个GRU单元接收第一个隐藏状态
h
1
′
h'_1
h1′,生成第二个隐藏状态
h
2
′
h'_2
h2′ 和输出
y
2
y_2
y2。如此反复,直到生成最后一个隐藏状态
h
4
′
h'_4
h4′和输出
y
4
y_4
y4。
过程描述
编码过程:输入序列
x
x
x被编码器处理,生成最终的上下文向量
c
c
c。解码过程:解码器利用上下文向量
c
c
c和初始状态生成目标输出序列
y
y
y。
此模型的目的是将一个序列(如一句话)转换为另一个序列(如另一种语言的翻译),其中编码器将输入序列编码为固定大小的上下文向量,解码器再将该向量解码为目标序列。Seq2Seq模型广泛应用于机器翻译、文本摘要等任务。
数据处理思路
记住我们正常的神经网络是无法直接识别中文或者英文的字符串输入的。所以这一步我们的目标只有一个,那就是将数据变成神经网络可以识别到的数据类型。
因此在这个阶段,我们要做的是:
数据清洗构建数据集类型TranslateDataset分词构建词表将术语词典引入到训练集中
数据清洗
这一步baseline没有,个人自行扩充的,可以参考一下。
我们在获取到数据集后,首先要做的就是一个数据探查任务。我们打开train.txt可以看到
There’s a tight and surprising link between the ocean’s health and ours, says marine biologist Stephen Palumbi. He shows how toxins at the bottom of the ocean food chain find their way into our bodies, with a shocking story of toxic contamination from a Japanese fish market. His work points a way forward for saving the oceans’ health – and humanity’s.
生物学家史蒂芬·帕伦认为,海洋的健康和我们的健康之间有着紧密而神奇的联系。他通过日本一个渔场发生的让人震惊的有毒污染的事件,展示了位于海洋食物链底部的有毒物质是如何进入我们的身体的。他的工作主要是未来拯救海洋健康的方法——同时也包括人类的。
There’s这些,如果我们直接构建词表的话,有可能出现分词为’的情况。所以我们要将这些There’s变成There is。
除此之外,我们要删除一些特殊字符,只保留一些标点符号和数字等。代码如下:
import contractions
def unicodeToAscii(text):
return ''.join(c for c in unicodedata.normalize('NFD', text) if unicodedata.category(c) != 'Mn')
def preprocess_en(text):
text = unicodeToAscii(text.strip())
text = contractions.fix(text)
text = re.sub(r'\([^)]*\)', '', text)
text = re.sub(r"[^a-zA-Z0-9.!?]+", r" ", text) # 保留数字
return text
处理后的数据
There is a tight and surprising link between the ocean s health and ours says marine biologist Stephen Palumbi . He shows how toxins at the bottom of the ocean food chain find their way into our bodies with a shocking story of toxic contamination from a Japanese fish market . His work points a way forward for saving the oceans health and humanity s .
可以看出There’s 已经变成了There is了
接着对中文数据进行处理。在中文数据中,经过探查,竟然发现有(掌声)这种不该出现在翻译文本中的脏数据。比如:
他指着我碗底的三粒米, 然后说"吃干净。" (笑声)
他说,“如果你要回你的车,那么我就要tase(用高压眩晕枪射击)你
Okay. Good. 好,很好!(笑)
But many people see the same thing and think things differently, and one of them is here, Ratan Tata. 看到的是同样的东西, 但很多人的想法却不一样, 其中一个就是,Ratan Tata (Tata集团的现任主席)。
这些脏数据可以使用正则表达式剔除,代码如下:
def preprocess_zh(text):
# 去除(掌声)这些脏数据
text = re.sub(r'\([^)]*\)', '', text)
text = re.sub(r"[^\u4e00-\u9fa5,。!?0-9]", "", text) # 保留数字
return text
这一步操作虽然会删除一些可能真的需要()翻译的内容,但是也是小部分,比如:
Kary Mullis: They might have done it for the teddy bear, yeah. (Kary Mullis回答:)那他们可能也会吧。
构建数据集类型TranslateDataset
代码如下:
class TranslationDataset(Dataset):
def __init__(self, filename, terminology):
self.data = []
with open(filename, 'r', encoding='utf-8') as f:code>
for line in f:
en, zh = line.strip().split('\t')
self.data.append((en, zh))
self.terminology = terminology
# 创建词汇表,注意这里需要确保术语词典中的词也被包含在词汇表中
self.en_tokenizer = get_tokenizer('basic_english')
self.zh_tokenizer = list # 使用字符级分词
en_vocab = Counter(self.terminology.keys()) # 确保术语在词汇表中
zh_vocab = Counter()
for en, zh in self.data:
en_vocab.update(self.en_tokenizer(en))
zh_vocab.update(self.zh_tokenizer(zh))
# 添加术语到词汇表
self.en_vocab = ['<pad>', '<sos>', '<eos>'] + list(self.terminology.keys()) + [word for word, _ in en_vocab.most_common(10000)]
self.zh_vocab = ['<pad>', '<sos>', '<eos>'] + [word for word, _ in zh_vocab.most_common(10000)]
self.en_word2idx = { word: idx for idx, word in enumerate(self.en_vocab)}
self.zh_word2idx = { word: idx for idx, word in enumerate(self.zh_vocab)}
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
en, zh = self.data[idx]
en_tensor = torch.tensor([self.en_word2idx.get(word, self.en_word2idx['<sos>']) for word in self.en_tokenizer(en)] + [self.en_word2idx['<eos>']])
zh_tensor = torch.tensor([self.zh_word2idx.get(word, self.zh_word2idx['<sos>']) for word in self.zh_tokenizer(zh)] + [self.zh_word2idx['<eos>']])
return en_tensor, zh_tensor
这个代码主要完成了以下几件事情:
读入了训练集的路径, 将训练集划分为中文和英文文本对加载了术语词表对中英文句子进行分词构建词表(同时添加特殊符号进去词表)构建继承了Dateset类所需要重写的函数__len__和__getitem__
通过这一步,我们可以构建一个数据集对象。这个对象包含了词表和数据内容。
在我们的认知中,神经网络的输入通常是数字信息,那么如何将字符信息转化为数字信息呢?这时我们就需要用到词表。我们可以将句子进行分词,然后对照词表将每个词转化为对应的序号。例如:
假设有以下句子:“我爱自然语言处理”。我们首先进行分词,得到 [“我”, “爱”, “自然”, “语言”, “处理”]。假设词表如下:
词语 | 序号 |
---|---|
我 | 1 |
爱 | 2 |
自然 | 3 |
语言 | 4 |
处理 | 5 |
通过词表,我们可以将句子中的每个词转化为相应的数字序号,得到 [1, 2, 3, 4, 5]。这样,字符信息就成功地转化为神经网络可以处理的数字信息了。
然后神经网络同样预测的也是数字,然后我们根据词表再将神经网络的输出转化为对应的语言。这就完成了预测过程。
这一步在上面代码的__init__和__getitem__完成了
模型搭建和训练
基于seq2seq 的模型
我们使用的是Seq2Seq结构的模型,结构如图所示
两个模块:
Encoder模块Decoder模块
每个模块中由一个GRU网络组成。
这里的原理是Encoder将输入的英文的信息压缩成一个隐藏语义变量c,这个c大概可以理解为下图红框的内容:
然后将c传入解码器Decoder后,解码器根据这个语义信息和来预测我们的中文。
代码如下:
class Encoder(nn.Module):
def __init__(self, input_dim, emb_dim, hid_dim, n_layers, dropout):
super().__init__()
self.embedding = nn.Embedding(input_dim, emb_dim)
self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
self.dropout = nn.Dropout(dropout)
def forward(self, src):
# src shape: [batch_size, src_len]
embedded = self.dropout(self.embedding(src))
# embedded shape: [batch_size, src_len, emb_dim]
outputs, hidden = self.rnn(embedded)
# outputs shape: [batch_size, src_len, hid_dim]
# hidden shape: [n_layers, batch_size, hid_dim]
return outputs, hidden
class Decoder(nn.Module):
def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
super().__init__()
self.output_dim = output_dim
self.embedding = nn.Embedding(output_dim, emb_dim)
self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
self.fc_out = nn.Linear(hid_dim, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, input, hidden):
# input shape: [batch_size, 1]
# hidden shape: [n_layers, batch_size, hid_dim]
embedded = self.dropout(self.embedding(input))
# embedded shape: [batch_size, 1, emb_dim]
output, hidden = self.rnn(embedded, hidden)
# output shape: [batch_size, 1, hid_dim]
# hidden shape: [n_layers, batch_size, hid_dim]
prediction = self.fc_out(output.squeeze(1))
# prediction shape: [batch_size, output_dim]
return prediction, hidden
然后将Encoder和Decoder拼接在一起,形成Seq2Seq结构
class Seq2Seq(nn.Module):
def __init__(self, encoder, decoder, device):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.device = device
def forward(self, src, trg, teacher_forcing_ratio=0.5):
# src 的形状:[batch_size, src_len]
# trg 的形状:[batch_size, trg_len]
batch_size = src.shape[0] # 获取批次大小
trg_len = trg.shape[1] # 获取目标序列的长度
trg_vocab_size = self.decoder.output_dim # 获取目标词汇表的大小
# 初始化输出张量,形状为 [batch_size, trg_len, trg_vocab_size]
outputs = torch.zeros(batch_size, trg_len, trg_vocab_size).to(self.device)
# 将源序列输入编码器,获得隐藏状态
_, hidden = self.encoder(src)
# 取目标序列的第一个词作为解码器的初始输入,通常是开始标记
input = trg[:, 0].unsqueeze(1) # Start token
# 循环遍历目标序列长度,逐步生成输出
for t in range(1, trg_len):
# 将当前输入和隐藏状态输入解码器,获取输出和新的隐藏状态
output, hidden = self.decoder(input, hidden)
# 将当前时间步的输出保存到 outputs 张量中
outputs[:, t, :] = output
# 确定是否使用教师强制
teacher_force = random.random() < teacher_forcing_ratio
# 获取输出中概率最高的词的索引
top1 = output.argmax(1)
# 如果使用教师强制,下一步的输入为目标序列中的下一个词,否则使用当前时间步输出中概率最高的词
input = trg[:, t].unsqueeze(1) if teacher_force else top1.unsqueeze(1)
return outputs # 返回所有时间步的输出
这样我们的Seq2Seq构建的模型已经完成
由于Seq2Seq是NLP比较基础的一个模型,结构简单,无法并行处理。导致训练速度比较慢,而且效果可能比较智障。测了一下全部训练集跑20轮,大概要几小时。最后成绩只有不到2分。上限比较低。所以升级框架为transform框架
基于transformer的模型
后续补充…
4090全数据训练了10个epoch,大概一个小时。效果能去到13.9分
训练代码
就很经典的结构,流程如下:
通过dataloader拿到这一轮的训练数据输入进去模型,得到预测结果拿预测结果和目标句子做交叉熵损失更新梯度不断循环
def train(model, iterator, optimizer, criterion, clip):
model.train()
epoch_loss = 0
for i, (src, trg) in enumerate(iterator):
src, trg = src.to(device), trg.to(device)
optimizer.zero_grad()
output = model(src, trg)
output_dim = output.shape[-1]
output = output[:, 1:].contiguous().view(-1, output_dim)
trg = trg[:, 1:].contiguous().view(-1)
loss = criterion(output, trg)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()
epoch_loss += loss.item()
return epoch_loss / len(iterator)
main代码
# 主函数
if __name__ == '__main__':
start_time = time.time() # 开始计时
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#terminology = load_terminology_dictionary('../dataset/en-zh.dic')
terminology = load_terminology_dictionary('./data/en-zh.dic')
# 加载数据
dataset = TranslationDataset('./data/train.txt',terminology = terminology)
# 选择数据集的前N个样本进行训练
N = 1000 #int(len(dataset) * 1) # 或者你可以设置为数据集大小的一定比例,如 int(len(dataset) * 0.1)
subset_indices = list(range(N))
subset_dataset = Subset(dataset, subset_indices)
train_loader = DataLoader(subset_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
# 定义模型参数
INPUT_DIM = len(dataset.en_vocab)
OUTPUT_DIM = len(dataset.zh_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
# 初始化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
model = Seq2Seq(enc, dec, device).to(device)
# 定义优化器和损失函数
optimizer = optim.Adam(model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=dataset.zh_word2idx['<pad>'])
# 训练模型
N_EPOCHS = 10
CLIP = 1
for epoch in range(N_EPOCHS):
train_loss = train(model, train_loader, optimizer, criterion, CLIP)
print(f'Epoch: { epoch+1:02} | Train Loss: { train_loss:.3f}')
# 在训练循环结束后保存模型
torch.save(model.state_dict(), './translation_model_GRU.pth')
end_time = time.time() # 结束计时
# 计算并打印运行时间
elapsed_time_minute = (end_time - start_time)/60
print(f"Total running time: { elapsed_time_minute:.2f} minutes")
基于BLUE4的评估指标
BLEU是评估机器翻译和自然语言生成任务的一种标准指标。BLEU4是指采用4-gram的BLEU评分方法。BLEU评分的核心思想是通过计算机器生成的文本与参考文本之间的n-gram匹配程度来衡量生成文本的质量。具体而言,BLEU4会考虑1-gram到4-gram的匹配情况。
BLEU4 的计算步骤:
分词:将待评估的生成文本和参考文本分词。计算n-gram匹配:计算1-gram、2-gram、3-gram和4-gram的匹配情况。计算精确度:对于每个n-gram,计算生成文本中的n-gram与参考文本中的n-gram的匹配数量,然后除以生成文本中n-gram的总数,得到精确度。计算加权几何平均:将1-gram到4-gram的精确度取对数,然后计算它们的加权几何平均。惩罚因子:为了解决短文本可能获得较高分数的问题,引入惩罚因子BP(Brevity Penalty),BP的计算基于生成文本和参考文本的长度比值。计算最终的BLEU4分数:将加权几何平均与惩罚因子相乘得到最终的BLEU4分数。
举个例子:
假设生成句子和参考句子有些不同:
参考句子:"The cat is on the mat"
生成句子:"The cat sat on the mat"
1. 分词:
参考句子分词结果:[“The”, “cat”, “is”, “on”, “the”, “mat”]
生成句子分词结果:[“The”, “cat”, “sat”, “on”, “the”, “mat”]
2. 计算n-gram匹配:
1-gram匹配:5个 (“The”, “cat”, “on”, “the”, “mat”)2-gram匹配:4个 (“The cat”, “on the”, “the mat”)3-gram匹配:2个 (“on the mat”)4-gram匹配:1个 (“on the mat”)
3. 计算精确度:
1-gram精确度:5/6 ≈ 0.8332-gram精确度:4/5 = 0.83-gram精确度:2/4 = 0.54-gram精确度:1/3 ≈ 0.333
4. 计算加权几何平均:
BLEU4
=
exp
(
1
4
(
log
0.833
+
log
0.8
+
log
0.5
+
log
0.333
)
)
≈
0.599
\text{BLEU4} = \exp\left(\frac{1}{4} (\log 0.833 + \log 0.8 + \log 0.5 + \log 0.333)\right) ≈ 0.599
BLEU4=exp(41(log0.833+log0.8+log0.5+log0.333))≈0.599
5. 计算惩罚因子BP:
生成句子和参考句子长度相同,因此BP = 1。
6. 计算最终的BLEU4分数:
最终的BLEU4分数 ≈ 0.599(加权几何平均) * 1(BP) ≈ 0.599
在这个例子中,生成句子和参考句子有一些不同,因此BLEU4得分相对较低,反映了生成句子与参考句子的匹配程度不高。
实现代码
import torch
from sacrebleu.metrics import BLEU
from typing import List
# 假设我们已经定义了TranslationDataset, Encoder, Decoder, Seq2Seq类
def load_sentences(file_path: str) -> List[str]:
with open(file_path, 'r', encoding='utf-8') as f:code>
return [line.strip() for line in f]
# 更新translate_sentence函数以考虑术语词典
def translate_sentence(sentence: str, model: Seq2Seq, dataset: TranslationDataset, terminology, device: torch.device, max_length: int = 50):
model.eval()
tokens = dataset.en_tokenizer(sentence)
tensor = torch.LongTensor([dataset.en_word2idx.get(token, dataset.en_word2idx['<sos>']) for token in tokens]).unsqueeze(0).to(device) # [1, seq_len]
with torch.no_grad():
_, hidden = model.encoder(tensor)
translated_tokens = []
input_token = torch.LongTensor([[dataset.zh_word2idx['<sos>']]]).to(device) # [1, 1]
for _ in range(max_length):
output, hidden = model.decoder(input_token, hidden)
top_token = output.argmax(1)
translated_token = dataset.zh_vocab[top_token.item()]
if translated_token == '<eos>':
break
# 如果翻译的词在术语词典中,则使用术语词典中的词
if translated_token in terminology.values():
for en_term, ch_term in terminology.items():
if translated_token == ch_term:
translated_token = en_term
break
translated_tokens.append(translated_token)
input_token = top_token.unsqueeze(1) # [1, 1]
return ''.join(translated_tokens)
def evaluate_bleu(model: Seq2Seq, dataset: TranslationDataset, src_file: str, ref_file: str, terminology,device: torch.device):
model.eval()
src_sentences = load_sentences(src_file)
ref_sentences = load_sentences(ref_file)
translated_sentences = []
for src in src_sentences:
translated = translate_sentence(src, model, dataset, terminology, device)
translated_sentences.append(translated)
bleu = BLEU()
score = bleu.corpus_score(translated_sentences, [ref_sentences])
return score
# 主函数
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载术语词典
terminology = load_terminology_dictionary('./data/en-zh.dic')
# 创建数据集实例时传递术语词典
dataset = TranslationDataset('./data/train.txt', terminology)
# 定义模型参数
INPUT_DIM = len(dataset.en_vocab)
OUTPUT_DIM = len(dataset.zh_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
# 初始化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
model = Seq2Seq(enc, dec, device).to(device)
# 加载训练好的模型
model.load_state_dict(torch.load('./translation_model_GRU.pth'))
# 评估BLEU分数
bleu_score = evaluate_bleu(model, dataset, './data/dev_en.txt', './data/dev_zh.txt', terminology = terminology,device = device)
print(f'BLEU-4 score: { bleu_score.score:.2f}')
这里是对验证集进行评估,由于Seq2Seq模型训练上限比较低,看上去有点像智障。所以BLEU为0也不是什么奇怪的事情
推理
最后我们训练完这个模型后肯定要对模型进行推理,推理代码如下:
def inference(model: Seq2Seq, dataset: TranslationDataset, src_file: str, save_dir:str, terminology, device: torch.device):
model.eval()
src_sentences = load_sentences(src_file)
translated_sentences = []
for src in src_sentences:
translated = translate_sentence(src, model, dataset, terminology, device)
#print(translated)
translated_sentences.append(translated)
#print(translated_sentences)
# 将列表元素连接成一个字符串,每个元素后换行
text = '\n'.join(translated_sentences)
# 打开一个文件,如果不存在则创建,'w'表示写模式
with open(save_dir, 'w', encoding='utf-8') as f:code>
# 将字符串写入文件
f.write(text)
#return translated_sentences
# 主函数
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载术语词典
terminology = load_terminology_dictionary('./data/en-zh.dic')
# 加载数据集和模型
dataset = TranslationDataset('./data/train.txt',terminology = terminology)
# 定义模型参数
INPUT_DIM = len(dataset.en_vocab)
OUTPUT_DIM = len(dataset.zh_vocab)
ENC_EMB_DIM = 256
DEC_EMB_DIM = 256
HID_DIM = 512
N_LAYERS = 2
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
# 初始化模型
enc = Encoder(INPUT_DIM, ENC_EMB_DIM, HID_DIM, N_LAYERS, ENC_DROPOUT)
dec = Decoder(OUTPUT_DIM, DEC_EMB_DIM, HID_DIM, N_LAYERS, DEC_DROPOUT)
model = Seq2Seq(enc, dec, device).to(device)
# 加载训练好的模型
model.load_state_dict(torch.load('./translation_model_GRU.pth'))
save_dir = './data/submit.txt'
inference(model, dataset, src_file="./data/test_en.txt", save_dir = save_dir, terminology = terminology, device = device)code>
print(f"翻译完成!文件已保存到{ save_dir}")
运行完就会得到一个submit.txt文件,提交这个就可以拿到这个比赛的分数啦。
一提交(被狠狠的打击到了):
痛定思痛下马上写了一个transform框架的,然后狠狠的上了一把分(
关于transformer的baseline我有空在写多一篇文章发一下。
声明
本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。