pytorch:实现简单的GAN示例(MNIST数据集)
作者:xckkcxxck 发布时间:2022-01-03 16:42:34
标签:pytorch,GAN,MNIST,数据集
我就废话不多说了,直接上代码吧!
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images): # 定义画图工具
images = np.reshape(images, [images.shape[0], -1])
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([sqrtimg,sqrtimg]))
return
def preprocess_img(x):
x = tfs.ToTensor()(x)
return (x - 0.5) / 0.5
def deprocess_img(x):
return (x + 1.0) / 2.0
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
"""Samples elements sequentially from some offset.
Arguments:
num_samples: # of desired datapoints
start: offset where we should start selecting from
"""
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start
def __iter__(self):
return iter(range(self.start, self.start + self.num_samples))
def __len__(self):
return self.num_samples
NUM_TRAIN = 50000
NUM_VAL = 5000
NOISE_DIM = 96
batch_size = 128
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
#判别网络
def discriminator():
net = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
return net
#生成网络
def generator(noise_dim=NOISE_DIM):
net = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(True),
nn.Linear(1024, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
return net
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
size = logits_real.shape[0]
true_labels = Variable(torch.ones(size, 1)).float()
false_labels = Variable(torch.zeros(size, 1)).float()
loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
return loss
def generator_loss(logits_fake): # 生成器的 loss
size = logits_fake.shape[0]
true_labels = Variable(torch.ones(size, 1)).float()
loss = bce_loss(logits_fake, true_labels)
return loss
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
noise_size=96, num_epochs=10):
iter_count = 0
for epoch in range(num_epochs):
for x, _ in train_data:
bs = x.shape[0]
# 判别网络
real_data = Variable(x).view(bs, -1) # 真实数据
logits_real = D_net(real_data) # 判别网络得分
sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
g_fake_seed = Variable(sample_noise)
fake_images = G_net(g_fake_seed) # 生成的假的数据
logits_fake = D_net(fake_images) # 判别网络得分
d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
D_optimizer.zero_grad()
d_total_error.backward()
D_optimizer.step() # 优化判别网络
# 生成网络
g_fake_seed = Variable(sample_noise)
fake_images = G_net(g_fake_seed) # 生成的假的数据
gen_logits_fake = D_net(fake_images)
g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
G_optimizer.zero_grad()
g_error.backward()
G_optimizer.step() # 优化生成网络
if (iter_count % show_every == 0):
print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
show_images(imgs_numpy[0:16])
plt.show()
print()
iter_count += 1
D = discriminator()
G = generator()
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
来源:https://blog.csdn.net/xckkcxxck/article/details/83037025


猜你喜欢
- 一、数据插入思路如果一条一条插入普通表的话,效率太低下,但内存表插入速度是很快的,可以先建立一张内存表,插入数据后,在导入到普通表中。1、创
- networkx是Python的一个包,用于构建和操作复杂的图结构,提供分析图的算法。图是由顶点、边和可选的属性构成的数据结构,顶点表示数据
- 使用matplotlib绘图时,在弹出的窗口中默认是有工具栏的,那么这些工具栏是如何定义的呢?工具栏的三种模式matplotlib的基础配置
- 本文实例讲述了PHP获取当前相对于域名目录的方法。分享给大家供大家参考。具体如下:http://127.0.0.1/dev/classd/i
- 太多的小伙伴正在学习Python,就说自己以后要做全栈开发,大家知道这是做什么的吗?我们现在所知道的知识点,哪些是以后你要从事这个全栈所需要
- 处理前文件内容代码处理后的# 读取代码fr = open('three.txt', 'r')dic = {}
- spark编程python实例ValueError: Cannot run multiple SparkContexts at once;
- CreateOrUpdate 是业务开发中很常见的场景,我们支持用户对某个业务实体进行创建/配置。希望实现的 repository 接口要达
- 前言问题:做requests请求时遇到如下报错:{“code”:“500&
- 疫情还没结束,小编只能宅在家里,哪哪也去不了,今天突发奇想给大家分享一篇教程关于Python paramiko 模块浅谈与SSH主要功能模拟
- 在现代的 web 框架里面,基本都有实现了依赖注入的功能,可以让我们很方便地对应用的依赖进行管理,同时免去在各个地方 new 对象的麻烦。比
- 前端实现用ligerUI实现分页,感觉用框架确实简单,闲着无聊,模拟着liger的分页界面实现了一遍(只要是功能,样式什么无视)
- python数组和列表的区别列表和数组的定义列表用于顺序存储结构。它可以方便、高效的的添加删除元素,并且列表中的元素可以是多种类型。数组是一
- 引伸阅读解读absolute与relativeposition:relative/absolute无法冲破的等级定位一直是WEB标准应用中的
- 测试sql: 代码如下:SET STATISTICS IO ON SET STATISTICS TIME ON SELECT COUNT(1
- 时下,个性ico图标却成为一些主流大牌网站提高用户体验(UE)的一个“时髦”玩法,那么,是如何在IE地址栏显示出网站的个性图标的呢?常浏览网
- MSDN上看了一下说是sql server 2005不支持在分布式事务处理中存在指向本地的链接服务器(环回链接服务器)个人尝试了下是由于在双
- 无论是在小得可怜的免费数据库空间或是大型电子商务网站,合理的设计表结构、充分利用空间是十分必要的。这就要求我们对数据库系统的常用数据类型有充
- 一:手写数字模型构建与保存1 加载数据集# 1加载数据digits_data = load_digits()可以先简单查看下 手写数字集,如
- Ajax的流行给用户体验带来了很大程序的提升,而“注册“这项做为互联网最常用到的功能也自然而然的成为Ajax最常光顾的地方,实时判断用户输入