pytorch GAN生成对抗网络实例
作者:全栈的方向 发布时间:2022-06-30 03:41:27
标签:pytorch,GAN,生成对抗网络
我就废话不多说了,直接上代码吧!
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)
np.random.seed(1)
BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS) for _ in range(BATCH_SIZE)])
def artist_works():
a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]
paintings = a*np.power(PAINT_POINTS,2) + (a-1)
paintings = torch.from_numpy(paintings).float()
return Variable(paintings)
G = nn.Sequential(
nn.Linear(N_IDEAS,128),
nn.ReLU(),
nn.Linear(128,ART_COMPONENTS),
)
D = nn.Sequential(
nn.Linear(ART_COMPONENTS,128),
nn.ReLU(),
nn.Linear(128,1),
nn.Sigmoid(),
)
opt_D = torch.optim.Adam(D.parameters(),lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(),lr=LR_G)
plt.ion()
for step in range(10000):
artist_paintings = artist_works()
G_ideas = Variable(torch.randn(BATCH_SIZE,N_IDEAS))
G_paintings = G(G_ideas)
prob_artist0 = D(artist_paintings)
prob_artist1 = D(G_paintings)
D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1-prob_artist1))
G_loss = torch.mean(torch.log(1 - prob_artist1))
opt_D.zero_grad()
D_loss.backward(retain_variables=True)
opt_D.step()
opt_G.zero_grad()
G_loss.backward()
opt_G.step()
if step % 50 == 0:
plt.cla()
plt.plot(PAINT_POINTS[0],G_paintings.data.numpy()[0],c='#4ad631',lw=3,label='Generated painting',)
plt.plot(PAINT_POINTS[0],2 * np.power(PAINT_POINTS[0], 2) + 1,c='#74BCFF',lw=3,label='upper bound',)
plt.plot(PAINT_POINTS[0],1 * np.power(PAINT_POINTS[0], 2) + 0,c='#FF9359',lw=3,label='lower bound',)
plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size':15})
plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 15})
plt.ylim((0,3))
plt.legend(loc='upper right', fontsize=12)
plt.draw()
plt.pause(0.01)
plt.ioff()
plt.show()
来源:https://blog.csdn.net/pureszgd/article/details/75095332


猜你喜欢
- set oSQLServer =server.createobject("SQLDMO.SQLServer"
- 本文实例讲述了javascript常见数字进制转换的方法。分享给大家供大家参考,具体如下:基本思路是先把其他进制的转化成 十进制,然后再转化
- 糟糕的SQL查询语句可对整个应用程序的运行产生严重的影响,其不仅消耗掉更多的数据库时间,且它将对其他应用组件产生影响。如同其它学科,优化查询
- pandas删除部分数据后重新索引在使用pandas时,由于隔行读取删除了部分数据,导致删除数据后的索引不连续:原数据删除部分数据后在绑定p
- 先看效果,实现一个图片左右摇动,在一般的H5宣传页,商家活动页面我们会看到这样的动画,小程序的动画效果不同于css3动画效果,是通过js来完
- <HTML> <BODY> <
- 模块a.py 想用 b.py中公有数据 cntb的python文件#!/usr/bin/env python# coding:utf8fro
- 最近很多小伙伴在尝鲜chatGPT,使用中遇到网站的1020的错误码,博主也遇到了相似的问题,不同的人运行环境不一样,可能解决方案不一样,接
- 在IE比较简单,大家都知道用setHomePage来设置,懒人写法:<a href="#setHomePage"
- 本篇文章来介绍一道非常常见的面试题,到底有多常见呢?可能很多面试的开场白就是由此开始的。那就是 new 和 make 这两个内置函数的区别。
- 本文实例讲述了Python实现队列的方法。分享给大家供大家参考。具体实现方法如下:#!/usr/bin/env python queue =
- 前言Vscode是是一个强大的跨平台工具,我自己电脑是mac,公司电脑是win而且是内部环境,导致公司安装软件很费劲。好在vscode许多插
- testify在团队里推行单元测试的时候,有一个反对的意见是:写单元测试耗时太多。且不论这个意见对错,单元测试确实不应该太费时间。这时候,一
- 一、scrapy1.1 概述Scrapy,Python开发的一个快速、高层次的屏幕抓取和web抓取框架,用于抓取web站点并从页面中提取结构
- 本文实例为大家分享了js文字列表无缝滚动的具体代码,供大家参考,具体内容如下HTML代码:<div id="rule&quo
- Python作为一种功能强大的编程语言,因其简单易学而受到很多开发者的青睐。那么,Python 的应用领域有哪些呢?概括起来,Python的
- OAuth是一个关于授权(authorization)的开放网络标准,在全世界得到广泛应用,目前的版本是2.0版。本文对OAuth 2.0的
- 本文实例讲述了Python高级变量类型。分享给大家供大家参考,具体如下:目标列表元组字典字符串公共方法变量高级知识点回顾Python 中数据
- 如何解决pycharm配置跨域不提示?正常我们需在在如上中间件内配置跨域,但是2019之前的版本配置中间件可能需要全部自己敲出来,不会有提示
- 不难,代码总共也就25行,大致逻辑如下。总共分为是下面两步在云服务器上部署自定义消息处理服务这里需要我们自定义来处理用户发送过来的消息首先导