做网站阜新百度手机助手安卓版
门控循环单元GRU
对于一个序列,不是每个观察值都是同等重要的,可能会遇到一下几种情况:
-
早期观测值对预测所有未来观测值都具有非常重要的意义。
考虑极端情况,第一个观测值包含一个校验和,目的是在序列的末尾辨别校验和事否正确,我们希望有某些机制在一个记忆元里存储重要的早期信息。如果没有这样的机制,我们将不得不给这个观测值指定一个非常大的梯度。
-
一些词元没有相关的观测值
在对网页内容进行情感分析时,可能一些辅助的HTML代码与网页传达的情绪无关,我们希望有一些机制来跳过隐状态中的此类词元
-
序列的各个部分存在逻辑中断
书的章节之间可能也会有过渡,证券的熊市,牛市之间可能会有过渡。这种情况下, 最好有一种方法来重置我们的内部状态表示
有很多方法来解决这类问题,最早的方法是"长短期记忆"(long-short-term memory,LSTM)。门控循环单元(gated recurrent unit,GRU)是一个稍微简化的变体,通常能提供同等的效果,并且计算速度更快。
1.门控隐状态
门控循环单元与普通的循环神经网络之间的关键区别在于: 前者支持隐状态的门控。 这意味着模型有专门的机制来确定应该何时更新隐状态, 以及应该何时重置隐状态。这些机制是可学习的。
1.1 重置门和更新门
重置门和更新门的输入如图所示。重置门允许我们控制”可能还想记住“的过去状态的数量;更新门将允许我们控制新状态中有多少个是旧状态的副本。
其中输入是由当前时间步的输入和前一时间步的隐状态给出,两个门的输出由使用sigmoid激活函数的两个全连接层给出。
假设输入是一个小批量 X t ∈ R n × d X_t\in \R^{n\times d} Xt∈Rn×d(样本数量 n n n,输入个数 d d d),上一个时间步的隐状态是 H t − 1 ∈ R n × h H_{t-1}\in \R^{n\times h} Ht−1∈Rn×h(隐藏单元个数 h h h)。那么重置门 R t R_t Rt和更新门 Z t Z_t Zt(均为 R n × h \R^{n\times h} Rn×h)的计算如下所示:
R t = σ ( X t W x r + H t − 1 W h r + b r ) Z t = σ ( X t W x z + H t − 1 W h z + b z ) R_t = \sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\\ Z_t = \sigma(X_t W_{xz}+H_{t-1}W_{hz}+b_z) Rt=σ(XtWxr+Ht−1Whr+br)Zt=σ(XtWxz+Ht−1Whz+bz)
其中 W x r , W x z ∈ R d × h W_{xr},W_{xz}\in \R^{d\times h} Wxr,Wxz∈Rd×h和 W h r , W h z ∈ R h × h W_{hr},W_{hz}\in \R^{h\times h} Whr,Whz∈Rh×h是权重参数, b r , b z ∈ R 1 × h b_r,b_z\in \R^{1\times h} br,bz∈R1×h是偏置参数。求和过程中会触发广播机制。 我们使用sigmoid函数将输入值转换到区间¥(0,1)$。
1.2 候选隐状态
将重置门 R t R_t Rt与常规隐状态更新机制集成,得到在时间步 t t t的候选隐状态 H ^ t ∈ R n × h \hat{H}_t\in\R ^{n\times h} H^t∈Rn×h:
H ^ t = t a n h ( X t W x h + ( R t ⊙ H t − 1 ) W h h + b h ) \hat{H}_t = tanh(X_tW_{xh}+(R_t\odot H_{t-1})W_{hh}+b_h) H^t=tanh(XtWxh+(Rt⊙Ht−1)Whh+bh)
其中 W x h ∈ R d × h W_{xh}\in\R^{d\times h} Wxh∈Rd×h和 W h h ∈ R h × h W_{hh}\in \R ^{h\times h} Whh∈Rh×h是权重参数, b h ∈ R 1 × h b_h\in \R^{1\times h} bh∈R1×h是偏置项,符号 ⊙ \odot ⊙是Hadamard积(按元素乘积)运算符,此处使用tanh非线性激活函数确保候选隐状态中的值保持在区间 ( − 1 , 1 ) (-1,1) (−1,1)中。。
R t ⊙ H t − 1 R_t\odot H_{t-1} Rt⊙Ht−1的元素相乘可以减少以往状态的影响,每当重置门 R t R_t Rt中的项接近1时,我们恢复一个普通的循环神经网络,如果 R t R_t Rt全为0,则之前的信息全部遗忘。重置门是可以学习的,通过学习,可以根据目前的输入决定哪些东西需要遗忘。
1.3 隐状态
1.2中得出的是候选隐状态,真正的隐状态需要结合更新门的效果。这一步确定新的隐状态 H t ∈ R n × h H_t\in \R^{n\times h} Ht∈Rn×h在多大程度上来自旧的状态 H t − 1 H_{t-1} Ht−1和新的候选状态 H t ^ \hat{H_t} Ht^。更新门 Z t Z_t Zt仅需要在 H t − 1 H_{t-1} Ht−1和 H ^ t \hat{H}_t H^t之间进行按元素的凸组合就可以实现,于是得出了最终的更新公式:
H t = Z t ⊙ H t − 1 + ( 1 − Z t ) ⊙ H ^ t H_t =Z_t \odot H_{t-1}+(1-Z_t)\odot \hat{H}_t Ht=Zt⊙Ht−1+(1−Zt)⊙H^t
容易看出,更新门 Z t Z_t Zt越趋近1,模型就倾向只保留旧状态,此时来自输入 X t X_t Xt的信息基本上被忽略,从而有效地跳过了依赖链条中的时间步 t t t。相反,当 Z t Z_t Zt接近0时,新的隐状态 H t H_t Ht就会接近候选隐状态 H t ^ \hat {H_t} Ht^
2.代码实现
2.1 从零开始
import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xz, W_hz, b_z = three() # 更新门参数W_xr, W_hr, b_r = three() # 重置门参数W_xh, W_hh, b_h = three() # 候选隐状态参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return paramsdef init_gru_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = H @ W_hq + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
2.2 训练与预测
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
2.3 简洁实现
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)