Table of Contents
模型结构图
模型结构
import torch
import torch.nn as nn
import torch.nn.functional as F
class RNN(nn.Module):
    """Single-layer vanilla RNN that classifies a sequence one step at a time.

    Each call to :meth:`forward` consumes one input vector plus the previous
    hidden state, and returns log-probabilities over ``output_size`` classes
    together with the new hidden state.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, hidden_size)   # input  -> hidden
        self.h2h = nn.Linear(hidden_size, hidden_size)  # hidden -> hidden
        self.h2o = nn.Linear(hidden_size, output_size)  # hidden -> output
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Advance the recurrence by one step.

        Args:
            input: one time step of input, shape ``(batch, input_size)``
                (``(1, n_letters)`` when fed slices of ``lineToTensor``).
            hidden: previous hidden state, shape ``(batch, hidden_size)``.

        Returns:
            ``(output, hidden)``: log-probabilities over classes (LogSoftmax
            along dim 1) and the updated hidden state.
        """
        # torch.tanh instead of the deprecated torch.nn.functional.tanh.
        hidden = torch.tanh(self.i2h(input) + self.h2h(hidden))
        output = self.softmax(self.h2o(hidden))
        return output, hidden

    def initHidden(self):
        """Return an all-zero hidden state of shape ``(1, hidden_size)``.

        The tensor is created on the same device and with the same dtype as
        the module's parameters, so it can be combined with ``h2h`` output
        without device/dtype mismatches (e.g. after ``rnn.to('cuda')``).
        """
        weight = self.h2o.weight
        return torch.zeros(1, self.hidden_size, device=weight.device, dtype=weight.dtype)
# Width of the RNN hidden state.
n_hidden = 128
rnn = RNN(input_size=n_letters, hidden_size=n_hidden, output_size=n_categories)

# Word one-hot encoding
# word to vecimport torch
# all_letters: the 26*2 ASCII letters plus separator characters (defined elsewhere).
def letterToIndex(letter):
    """Return the position of *letter* inside ``all_letters``.

    Uses ``str.find`` semantics: the result is -1 when the letter is not in
    ``all_letters``. NOTE(review): callers that use this as a tensor index
    will then silently mark the *last* column; inputs are presumably
    pre-normalized to ``all_letters`` — confirm.
    """
    position = all_letters.find(letter)
    return position
def letterToTensor(letter):
    """Encode one letter as a one-hot tensor of shape ``(1, n_letters)``."""
    one_hot = torch.zeros(1, n_letters)
    one_hot[0, letterToIndex(letter)] = 1
    return one_hot
def lineToTensor(line):
    """Encode a string as a sequence of one-hot vectors.

    Returns a tensor of shape ``(len(line), 1, n_letters)``: for each
    character position ``t``, entry ``[t, 0, letterToIndex(line[t])]`` is 1
    and all other entries are 0.
    """
    encoded = torch.zeros(len(line), 1, n_letters)
    for step, char in enumerate(line):
        encoded[step, 0, letterToIndex(char)] = 1
    return encoded


print(letterToTensor('J'))
print(lineToTensor('Jones').size())

# Training process
每次训练只输入一个单词（名字），参数更新需要手动完成（未使用优化器）。
def train(category_tensor, line_tensor):
    """Run one manual-SGD training step on a single name.

    Args:
        category_tensor: target for ``criterion`` (presumably a class-index
            tensor of shape ``(1,)`` — confirm against the caller).
        line_tensor: one-hot encoded name as built by ``lineToTensor``,
            shape ``(len(line), 1, n_letters)``.

    Returns:
        ``(output, loss)``: the network output after the last character and
        the loss as a Python float.

    Raises:
        ValueError: if ``line_tensor`` contains no characters.
    """
    if line_tensor.shape[0] == 0:
        raise ValueError("line_tensor must contain at least one character")

    category_tensor = category_tensor.to(device)
    line_tensor = line_tensor.to(device)
    # Keep the hidden state on the same device as the inputs; otherwise
    # h2h(hidden) mixes devices when `device` is not the default one.
    hidden = rnn.initHidden().to(device)
    rnn.zero_grad()

    for i in range(line_tensor.shape[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Plain SGD step, p <- p - learning_rate * grad, performed outside
    # autograd instead of mutating the discouraged `.data` attribute.
    with torch.no_grad():
        for p in rnn.parameters():
            p.add_(p.grad, alpha=-learning_rate)

    return output, loss.item()

# Converting the output
`tensor.topk(k)` 默认沿最后一个维度选出最大的 k 个元素，结果按从大到小排序返回；返回值中 top_n 为取到的数值，top_i 为对应的索引。
def categoryFromOutput(output):
    """Map the network output to its most likely category.

    ``topk(1)`` selects the largest entry along the last dimension; the
    index is read from the first row (the output is presumably
    ``(1, n_categories)`` as produced by ``RNN.forward``).

    Returns:
        ``(category_name, category_index)``.
    """
    _, best_indices = output.topk(1)
    best = best_indices[0].item()
    return all_categories[best], best
# Show the predicted (category_name, category_index) for the last output.
print(categoryFromOutput(output))

# Plotting the confusion matrix
关键在于统计每个类别中被正确分类和被误分类的样本数。
# Confusion matrix: one row and one column per category, filled from
# n_confusion sampled examples.
n_confusion = 10000
confusion = torch.zeros((n_categories, n_categories))
def evaluate(line_tensor):
    """Return the network output after feeding every character of a name.

    Inference only: runs under ``torch.no_grad()`` so no autograd graph is
    built, and moves the input (and the hidden state) to the same ``device``
    global that ``train`` uses.

    Args:
        line_tensor: one-hot encoded name, shape ``(len(line), 1, n_letters)``.

    Returns:
        The RNN output after the last character.

    Raises:
        ValueError: if ``line_tensor`` contains no characters.
    """
    if line_tensor.shape[0] == 0:
        raise ValueError("line_tensor must contain at least one character")

    line_tensor = line_tensor.to(device)
    with torch.no_grad():
        hidden = rnn.initHidden().to(device)
        for i in range(line_tensor.shape[0]):
            output, hidden = rnn(line_tensor[i], hidden)
    return output
# Go through a sample of random examples and record which are correctly guessed.
# Rows of `confusion` index the actual category and columns the guessed one,
# so in the matshow plot the y-axis is the actual class and the x-axis the
# predicted class. no_grad avoids building n_confusion autograd graphs.
with torch.no_grad():
    for _ in range(n_confusion):
        category, line, category_tensor, line_tensor = randomTrainingExample()
        output = evaluate(line_tensor)
        guess, guess_i = categoryFromOutput(output)
        category_i = all_categories.index(category)
        confusion[category_i][guess_i] += 1
# Normalize each row (actual class) by its total count so every row sums to 1.
# A category that never appeared among the sampled examples has a row sum of 0;
# counts are whole numbers, so clamping the divisor to >= 1 leaves non-empty
# rows unchanged and keeps an empty row at 0 instead of filling it with NaN.
row_totals = confusion.sum(dim=1, keepdim=True).clamp(min=1)
confusion.div_(row_totals)

# Set up plot
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(confusion.numpy())
fig.colorbar(cax)
# Set up axes: exactly one tick per category, labelled with its name.
# Fixing the tick positions before assigning labels keeps each label aligned
# with its matrix cell. It also replaces the old "prepend '' and add a
# MultipleLocator afterwards" trick, which makes matplotlib warn about a
# FixedFormatter used without a FixedLocator.
tick_positions = range(len(all_categories))
ax.set_xticks(tick_positions)
ax.set_xticklabels(all_categories, rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(all_categories)

# sphinx_gallery_thumbnail_number = 2
plt.show()