×

PyTorch教程3.2之面向对象的设计实现

消耗积分:0 | 格式:pdf | 大小:0.22 MB | 2023-06-05

陈飞

分享资料个

在我们对线性回归的介绍中,我们介绍了各种组件,包括数据、模型、损失函数和优化算法。事实上,线性回归是最简单的机器学习模型之一。然而,训练它使用许多与本书中其他模型所需的组件相同的组件。因此,在深入了解实现细节之前,有必要设计一些贯穿本书的 API。将深度学习中的组件视为对象,我们可以从为这些对象及其交互定义类开始。这种面向对象的实现设计将极大地简化演示,您甚至可能想在您的项目中使用它。

受PyTorch Lightning等开源库的启发,在高层次上我们希望拥有三个类:(i)Module包含模型、损失和优化方法;(ii)DataModule提供用于训练和验证的数据加载器;(iii) 两个类结合使用该类 Trainer,这使我们能够在各种硬件平台上训练模型。本书中的大部分代码都改编自Moduleand DataModuleTrainer只有在讨论 GPU、CPU、并行训练和优化算法时,我们才会涉及该类。

import time
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l
import time
import numpy as np
from mxnet.gluon import nn
from d2l import mxnet as d2l
import time
from dataclasses import field
from typing import Any
import jax
import numpy as np
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
import tensorflow as tf
from d2l import torch as d2l

3.2.1. 公用事业

我们需要一些实用程序来简化 Jupyter 笔记本中的面向对象编程。挑战之一是类定义往往是相当长的代码块。笔记本电脑的可读性需要简短的代码片段,穿插着解释,这种要求与 Python 库常见的编程风格不相容。第一个实用函数允许我们在创建类后将函数注册为类中的方法。事实上,即使我们已经创建了类的实例,我们也可以这样做!它允许我们将一个类的实现拆分成多个代码块。

def add_to_class(Class): #@save
  """Register functions as methods in created class."""
  def wrapper(obj):
    setattr(Class, obj.__name__, obj)
  return wrapper

让我们快速浏览一下如何使用它。我们计划 A用一个方法来实现一个类do我们可以先声明类并创建一个实例,而不是在同一个代码块中A同时 拥有两者的代码doAa

class A:
  def __init__(self):
    self.b = 1

a = A()

do接下来我们像往常一样 定义方法,但不在 classA的范围内。相反,我们add_to_class用类A作为参数来装饰这个方法。这样做时,该方法能够访问 的成员变量,A正如我们所期望的那样,如果它已被定义为 的A定义的一部分。让我们看看当我们为实例调用它时会发生什么a

@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1
@add_to_class(A)
def do(self):
  print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1

第二个是实用程序类,它将类 __init__方法中的所有参数保存为类属性。这使我们无需额外代码即可隐式扩展构造函数调用签名。

class HyperParameters: #@save
  """The base class of hyperparameters."""
  def save_hyperparameters(self, ignore=[]):
    raise NotImplemented

我们将其实施推迟到第 23.7 节HyperParameters要使用它,我们定义继承自该方法并调用 save_hyperparameters该方法的类__init__

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
  def __init__(self, a, b, c):
    self.save_hyperparameters(ignore=['c'])
    print('self.a =', self.a, 'self.b =', self.b)
    print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True

最后一个实用程序允许我们在实验进行时以交互方式绘制实验进度。为了尊重更强大(和复杂)的TensorBoard,我们将其命名为ProgressBoard实现推迟到 第 23.7 节现在,让我们简单地看看它的实际效果。

该方法在图中 draw绘制一个点,并在图例中指定。可选的仅通过显示来平滑线条(x, y)labelevery_n1/n图中的点。他们的价值是从平均n原始图中的邻居点。

class ProgressBoard(d2l.HyperParameters): #@save
  """The board that plots data points in animation."""
  def __init__(self, xlabel=None, ylabel=None, xlim=None,
         ylim=None, xscale='linear', yscale='linear',
         ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
         fig=None, axes=None, figsize=(3.5, 2.5), display=True):
    self.save_hyperparameters()

  def draw(self, x, y, label, every_n=1):
    raise NotImpleme

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

评论(0)
发评论

下载排行榜

全部0条评论

快来发表一下你的评论吧 !

'+ '

'+ '

'+ ''+ '
'+ ''+ ''+ '
'+ ''+ '' ); $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code ==5){ $(pop_this).attr('href',"/login/index.html"); return false } if(data.code == 2){ //跳转到VIP升级页面 window.location.href="//m.lene-v.com/vip/index?aid=" + webid return false } //是会员 if (data.code > 0) { $('body').append(htmlSetNormalDownload); var getWidth=$("#poplayer").width(); $("#poplayer").css("margin-left","-"+getWidth/2+"px"); $('#tips').html(data.msg) $('.download_confirm').click(function(){ $('#dialog').remove(); }) } else { var down_url = $('#vipdownload').attr('data-url'); isBindAnalysisForm(pop_this, down_url, 1) } }); }); //是否开通VIP $.get('/article/vipdownload/aid/'+webid,function(data){ if(data.code == 2 || data.code ==5){ //跳转到VIP升级页面 $('#vipdownload>span').text("开通VIP 免费下载") return false }else{ // 待续费 if(data.code == 3) { vipExpiredInfo.ifVipExpired = true vipExpiredInfo.vipExpiredDate = data.data.endoftime } $('#vipdownload .icon-vip-tips').remove() $('#vipdownload>span').text("VIP免积分下载") } }); }).on("click",".download_cancel",function(){ $('#dialog').remove(); }) var setWeixinShare={};//定义默认的微信分享信息,页面如果要自定义分享,直接更改此变量即可 if(window.navigator.userAgent.toLowerCase().match(/MicroMessenger/i) == 'micromessenger'){ var d={ title:'PyTorch教程3.2之面向对象的设计实现',//标题 desc:$('[name=description]').attr("content"), //描述 imgUrl:'https://'+location.host+'/static/images/ele-logo.png',// 分享图标,默认是logo link:'',//链接 type:'',// 分享类型,music、video或link,不填默认为link dataUrl:'',//如果type是music或video,则要提供数据链接,默认为空 success:'', // 用户确认分享后执行的回调函数 cancel:''// 用户取消分享后执行的回调函数 } setWeixinShare=$.extend(d,setWeixinShare); $.ajax({ url:"//www.lene-v.com/app/wechat/index.php?s=Home/ShareConfig/index", data:"share_url="+encodeURIComponent(location.href)+"&format=jsonp&domain=m", type:'get', dataType:'jsonp', success:function(res){ if(res.status!="successed"){ return false; } $.getScript('https://res.wx.qq.com/open/js/jweixin-1.0.0.js',function(result,status){ if(status!="success"){ return false; } var getWxCfg=res.data; wx.config({ //debug: true, // 开启调试模式,调用的所有api的返回值会在客户端alert出来,若要查看传入的参数,可以在pc端打开,参数信息会通过log打出,仅在pc端时才会打印。 appId:getWxCfg.appId, // 必填,公众号的唯一标识 timestamp:getWxCfg.timestamp, // 必填,生成签名的时间戳 nonceStr:getWxCfg.nonceStr, // 必填,生成签名的随机串 signature:getWxCfg.signature,// 必填,签名,见附录1 jsApiList:['onMenuShareTimeline','onMenuShareAppMessage','onMenuShareQQ','onMenuShareWeibo','onMenuShareQZone'] // 必填,需要使用的JS接口列表,所有JS接口列表见附录2 }); wx.ready(function(){ //获取“分享到朋友圈”按钮点击状态及自定义分享内容接口 wx.onMenuShareTimeline({ title: setWeixinShare.title, // 分享标题 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享给朋友”按钮点击状态及自定义分享内容接口 wx.onMenuShareAppMessage({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 type: setWeixinShare.type, // 分享类型,music、video或link,不填默认为link dataUrl: setWeixinShare.dataUrl, // 如果type是music或video,则要提供数据链接,默认为空 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ”按钮点击状态及自定义分享内容接口 wx.onMenuShareQQ({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到腾讯微博”按钮点击状态及自定义分享内容接口 wx.onMenuShareWeibo({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); //获取“分享到QQ空间”按钮点击状态及自定义分享内容接口 wx.onMenuShareQZone({ title: setWeixinShare.title, // 分享标题 desc: setWeixinShare.desc, // 分享描述 link: setWeixinShare.link, // 分享链接 imgUrl: setWeixinShare.imgUrl, // 分享图标 success: function () { setWeixinShare.success; // 用户确认分享后执行的回调函数 }, cancel: function () { setWeixinShare.cancel; // 用户取消分享后执行的回调函数 } }); }); }); } }); } function openX_ad(posterid, htmlid, width, height) { if ($(htmlid).length > 0) { var randomnumber = Math.random(); var now_url = encodeURIComponent(window.location.href); var ga = document.createElement('iframe'); ga.src = 'https://www1.elecfans.com/www/delivery/myafr.php?target=_blank&cb=' + randomnumber + '&zoneid=' + posterid+'&prefer='+now_url; ga.width = width; ga.height = height; ga.frameBorder = 0; ga.scrolling = 'no'; var s = $(htmlid).append(ga); } } openX_ad(828, '#berry-300', 300, 250);