Python人工智能学习PyTorch实现WGAN示例详解
作者:Swayzzu 发布时间:2022-10-20 18:49:32
标签:pytorch,WGAN,人工智能
1.GAN简述
在GAN中,有两个模型,一个是生成模型,用于生成样本,一个是判别模型,用于判断样本是真还是假。但由于在GAN中,使用的JS散度去计算损失值,很容易导致梯度弥散的情况,从而无法进行梯度下降更新参数,于是在WGAN中,引入了Wasserstein Distance,使得训练变得稳定。本文中我们以服从高斯分布的数据作为样本。
2.生成器模块
这里从2维数据,最终生成2维,主要目的是为了可视化比较方便。也就是说,在生成模型中,我们输入杂乱无章的2维的数据,通过训练之后,可以生成一个赝品,这个赝品在模仿高斯分布。
3.判别器模块
判别器同样输入的是2维的数据。比如我们上面的生成器,生成了一个2维的赝品,输入判别器之后,它能够最终输出一个sigmoid转换后的结果,相当于是一个概率,从而判别,这个赝品到底能不能达到以假乱真的程度。
4.数据生成模块
由于我们使用的是高斯模型,因此,直接生成我们需要的数据即可。我们在这个模块中,生成8个服从高斯分布的数据。
5.判别器训练
由于使用JS散度去计算损失的时候,会很容易出现梯度极小,接近于0的情况,会使得梯度下降无法进行,因此计算损失的时候,使用了Wasserstein Distance,去度量两个分布之间的差异。因此我们假如了梯度惩罚的因子。
其中,梯度惩罚的模块如下:
6.生成器训练
这里的训练是紧接着判别器训练的。也就是说,在一个周期里面,先训练判别器,再训练生成器。
7.结果可视化
通过visdom可视化损失值,通过matplotlib可视化分布的预测结果。
来源:https://blog.csdn.net/Swayzzu/article/details/121192285


猜你喜欢
- 读取mat文件生成h5文件1. Matlab生成 .mat 文件p = rand(1,10);q = ones(10);save('
- 代码# -*- coding:utf-8 -*-import osimport timef
- 客户端调用XMLHTTP的过程很简单,只有5个步骤: 1、创建XMLHTTP对象 2、打开与服务端的连接,同时定义指令发送方式,服务网页(U
- 方法1:1.安装requests_toolbelt依赖库#代码实现def upload(self): login_
- 如下所示:import ospath="/home/test/" #待读取的文件夹path_list=os.listdi
- 在 玉伯 的文章 《一道大题目,嘿嘿》 中有这样一段代码:[] == ![]也许很多同学迷惑:咦,这个如何转换呢?首先,我们了解下逻辑 NO
- 我有个MM在网上面安了家,想做一个关于特效的网站。她虽然懂一点网页制作,但是她的机器配置比较低,有时为了反复试验页面上一些特殊效果,而打开D
- python3的编码问题。打开python开发工具IDLE,新建‘codetest.py'文件,并写代码如下:import sysp
- 详细介绍Scrapy shell的使用Scrapy shell是Scrapy框架提供的一个非常有用的工具,可以帮助开发者快速地测试和调试Sc
- 单表的唯一查询用:distinct多表的唯一查询用:group bydistinct 查询多表时,left join 还有效,全连接无效,在
- /**//// <summary> /// 生成带CDATA的节点 /// </summary> /// <p
- 这是17年的第一篇博文,话说这天又是产品同学跑过来问我说:hi,lenny,你看现在市面上流行各种装逼H5,随便输入点名字啥的就给我生成房产
- 这里需要用到一个Django插件:django-pagination安装打开控制台 输入pip install dj-pagination实
- 适合的读者:有经验的开发员,专业前端人员。 原作者: Dmitry A. Soshnikov 发布时间: 2010-09-02 原文:htt
- MySQL由于它本身的小巧和操作的高效, 在数据库应用中越来越多的被采用.我在开发一个P2P应用的时候曾经使用MySQL来保存P2P节点,由
- 我来讲解属性部分, 这是相当有用的, 可要认真上课.首先,jquery中对html标签属性进行操作的关键词是 attr .没错,就4个字母,
- 我就废话不多说了,直接上代码吧!其实也不难,使用tertools.chain将参数链接起来即可import itertools...self
- 在用户登录windows操作系统的时候,如果触发到了登录表单的密码录入框上,并且此时按下了“大写锁定键(Caps Lock)”,那么界面上会
- 在python3.x中,可以使用pymysql来MySQL数据库的连接,并实现数据库的各种操作,本次博客主要介绍了pymysql的安装和使用
- 说明本例子利用TensorFlow搭建一个全连接神经网络,实现对MNIST手写数字的识别。先上代码from tensorflow.examp