与我们大多数从头开始的实施一样, 第 9.5 节旨在深入了解每个组件的工作原理。但是,当您每天使用 RNN 或编写生产代码时,您会希望更多地依赖于减少实现时间(通过为通用模型和函数提供库代码)和计算时间(通过优化这些库实现)。本节将向您展示如何使用深度学习框架提供的高级 API 更有效地实现相同的语言模型。和以前一样,我们首先加载时间机器数据集。
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
9.6.1. 定义模型
我们使用由高级 API 实现的 RNN 定义以下类。
Specifically, to initialize the hidden state, we invoke the member method begin_state
. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.
class RNN(d2l.Module): #@save
"""The RNN model implemented with high-level APIs."""
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = rnn.RNN(num_hiddens)
def forward(self, inputs, H=None):
if H is None:
H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
outputs, (H, ) = self.rnn(inputs, (H, ))
return outputs, H
Flax does not provide an RNNCell for concise implementation of Vanilla RNNs as of today. There are more advanced variants of RNNs like LSTMs and GRUs which are available in the Flax linen
API.
class RNN(d2l.Module): #@save
"""The RNN model implemented with high-level APIs."""
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = tf.keras.layers.SimpleRNN(
num_hiddens, return_sequences=True, return_state=True,
time_major=True)
def forward(self, inputs, H=None):
outputs, H = self.rnn(inputs, H)
return outputs, H
继承自9.5 节RNNLMScratch
中的类 ,下面的类定义了一个完整的基于 RNN 的语言模型。请注意,我们需要创建一个单独的全连接输出层。RNNLM
class RNNLM(d2l.RNNLMScratch): #@save
"""The RNN-based language model implemented with high-level APIs."""
training: bool = True
def setup(self):
self.linear = nn.Dense(self.vocab_size)
def output_layer(self, hiddens):
return self.linear(hiddens).swapaxes(0, 1)
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state, self.training)
return self.output_layer(rnn_outputs)
9.6.2. 训练和预测
在训练模型之前,让我们使用随机权重初始化的模型进行预测。鉴于我们还没有训练网络,它会产生无意义的预测。
'it hasgggggggggggggggggggg'
'it hasxlxlxlxlxlxlxlxlxlxl'
接下来,我们利用高级 API 训练我们的模型。
与第 9.5 节相比,该模型实现了相当的困惑度,但由于实现优化,运行速度更快。和以前一样,我们可以在指定的前缀字符串之后生成预测标记。