×

PyTorch教程11.4之Bahdanau注意力机制

消耗积分:0 | 格式:pdf | 大小:0.38 MB | 2023-06-05

王彬

分享资料个

当我们在10.7 节遇到机器翻译时,我们设计了一个基于两个 RNN 的序列到序列 (seq2seq) 学习的编码器-解码器架构 ( Sutskever et al. , 2014 )具体来说,RNN 编码器将可变长度序列转换为固定形状的上下文变量。然后,RNN 解码器根据生成的标记和上下文变量逐个标记地生成输出(目标)序列标记。

回想一下我们在下面重印的图 10.7.2 (图 11.4.1)以及一些额外的细节。通常,在 RNN 中,有关源序列的所有相关信息都由编码器转换为某种内部固定维状态表示。正是这种状态被解码器用作生成翻译序列的完整和唯一的信息源。换句话说,seq2seq 机制将中间状态视为可能作为输入的任何字符串的充分统计。

https://file.elecfans.com/web2/M00/A9/C9/poYBAGR9N_qACEJlAAF4rEvQWMo465.svg

图 11.4.1序列到序列模型。编码器生成的状态是编码器和解码器之间唯一共享的信息。

虽然这对于短序列来说是相当合理的,但很明显这对于长序列来说是不可行的,比如一本书的章节,甚至只是一个很长的句子。毕竟,一段时间后,中间表示中将根本没有足够的“空间”来存储源序列中所有重要的内容。因此,解码器将无法翻译又长又复杂的句子。第一个遇到的人是 格雷夫斯 ( 2013 )当他们试图设计一个 RNN 来生成手写文本时。由于源文本具有任意长度,他们设计了一个可区分的注意力模型来将文本字符与更长的笔迹对齐,其中对齐仅在一个方向上移动。这反过来又利用了语音识别中的解码算法,例如隐马尔可夫模型 Rabiner 和 Juang,1993 年

受到学​​习对齐的想法的启发, Bahdanau等人。( 2014 )提出了一种没有单向对齐限制的可区分注意力模型。在预测标记时,如果并非所有输入标记都相关,则模型仅对齐(或关注)输入序列中被认为与当前预测相关的部分。然后,这用于在生成下一个令牌之前更新当前状态。虽然在其描述中相当无伤大雅,但这种Bahdanau 注意力机制可以说已经成为过去十年深度学习中最有影响力的想法之一,并催生了 Transformers Vaswani等人,2017 年以及许多相关的新架构。

import torch
from torch import nn
from d2l import torch as d2l
from mxnet import init, np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()
import jax
from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l
import tensorflow as tf
from d2l import tensorflow as d2l

11.4.1。模型

我们遵循第 10.7 节的 seq2seq 架构引入的符号 ,特别是(10.7.3)关键思想是,而不是保持状态,即上下文变量c将源句子总结为固定的,我们动态更新它,作为原始文本(编码器隐藏状态)的函数ht) 和已经生成的文本(解码器隐藏状态st′−1). 这产生 ct′, 在任何解码时间步后更新 t′. 假设输入序列的长度T. 在这种情况下,上下文变量是注意力池的输出:

(11.4.1)ct′=∑t=1Tα(st′−1,ht)ht.

我们用了st′−1作为查询,和 ht作为键和值。注意 ct′然后用于生成状态 st′并生成一个新令牌(参见 (10.7.3))。特别是注意力权重 α使用由 ( 11.3.7 )定义的附加注意评分函数按照 (11.3.3)计算这种使用注意力的 RNN 编码器-解码器架构如图 11.4.2所示请注意,后来对该模型进行了修改,例如在解码器中包含已经生成的标记作为进一步的上下文(即,注意力总和确实停止在T而是它继续进行t′−1). 例如,参见Chan等人。( 2015 )描述了这种应用于语音识别的策略。

https://file.elecfans.com/web2/M00/AA/44/pYYBAGR9N_2AIf3lAAG83XwjOJ8743.svg

图 11.4.2具有 Bahdanau 注意机制的 RNN 编码器-解码器模型中的层。

11.4.2。用注意力定义解码器

要实现带有注意力的 RNN 编码器-解码器,我们只需要重新定义解码器(从注意力函数中省略生成的符号可以简化设计)。让我们通过定义一个意料之中的命名类来开始具有注意力的解码器的基本接口 AttentionDecoder

class AttentionDecoder(d2l.Decoder): #@save
  """The base attention-based decoder interface."""
  def __init__(self):
    super().__init__()

  @property
  def attention_weights(self):
    raise NotImplementedError
class AttentionDecoder(d2l.Decoder): #@save
  """The base attention-based decoder interface."""
  def __init__(self):
    super().__init__()

  @property
  def attention_weights(self):
    raise NotImplementedError
class AttentionDecoder(d2l.Decoder): #@save
  """The base attention-based decoder interface."""
  def __init__(self):
    super().__init__()

  @property
  def attention_weights(self):
    raise NotImplementedError

我们需要在Seq2SeqAttentionDecoder 类中实现 RNN 解码器。解码器的状态初始化为(i)编码器最后一层在所有时间步的隐藏状态,用作注意力的键和值;(ii) 编码器在最后一步的所有层的隐藏状态。这用于初始化解码器的隐藏状态;(iii) 编码器的有效长度,以排除注意力池中的填充标记。在每个解码时间步,解码器最后一层的隐藏状态,在前一个时间步获得,用作注意机制的查询。注意机制的输出和输入嵌入都被连接起来作为 RNN 解码器的输入。

class Seq2SeqAttentionDecoder(AttentionDecoder):
  def __init__(self, vocab_size, embed_size<

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.lene-v.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程11.4之Bahdanau注意力机制',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.lene-v.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);