pytorch:实现简单的GAN示例(MNIST数据集)
作者:xckkcxxck 发布时间:2022-01-03 16:42:34
标签:pytorch,GAN,MNIST,数据集
我就废话不多说了,直接上代码吧!
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
import torch
from torch import nn
from torch.autograd import Variable
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images): # 定义画图工具
images = np.reshape(images, [images.shape[0], -1])
sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
fig = plt.figure(figsize=(sqrtn, sqrtn))
gs = gridspec.GridSpec(sqrtn, sqrtn)
gs.update(wspace=0.05, hspace=0.05)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
plt.axis('off')
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([sqrtimg,sqrtimg]))
return
def preprocess_img(x):
x = tfs.ToTensor()(x)
return (x - 0.5) / 0.5
def deprocess_img(x):
return (x + 1.0) / 2.0
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
"""Samples elements sequentially from some offset.
Arguments:
num_samples: # of desired datapoints
start: offset where we should start selecting from
"""
def __init__(self, num_samples, start=0):
self.num_samples = num_samples
self.start = start
def __iter__(self):
return iter(range(self.start, self.start + self.num_samples))
def __len__(self):
return self.num_samples
NUM_TRAIN = 50000
NUM_VAL = 5000
NOISE_DIM = 96
batch_size = 128
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
#判别网络
def discriminator():
net = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1)
)
return net
#生成网络
def generator(noise_dim=NOISE_DIM):
net = nn.Sequential(
nn.Linear(noise_dim, 1024),
nn.ReLU(True),
nn.Linear(1024, 1024),
nn.ReLU(True),
nn.Linear(1024, 784),
nn.Tanh()
)
return net
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
size = logits_real.shape[0]
true_labels = Variable(torch.ones(size, 1)).float()
false_labels = Variable(torch.zeros(size, 1)).float()
loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
return loss
def generator_loss(logits_fake): # 生成器的 loss
size = logits_fake.shape[0]
true_labels = Variable(torch.ones(size, 1)).float()
loss = bce_loss(logits_fake, true_labels)
return loss
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250,
noise_size=96, num_epochs=10):
iter_count = 0
for epoch in range(num_epochs):
for x, _ in train_data:
bs = x.shape[0]
# 判别网络
real_data = Variable(x).view(bs, -1) # 真实数据
logits_real = D_net(real_data) # 判别网络得分
sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
g_fake_seed = Variable(sample_noise)
fake_images = G_net(g_fake_seed) # 生成的假的数据
logits_fake = D_net(fake_images) # 判别网络得分
d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
D_optimizer.zero_grad()
d_total_error.backward()
D_optimizer.step() # 优化判别网络
# 生成网络
g_fake_seed = Variable(sample_noise)
fake_images = G_net(g_fake_seed) # 生成的假的数据
gen_logits_fake = D_net(fake_images)
g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
G_optimizer.zero_grad()
g_error.backward()
G_optimizer.step() # 优化生成网络
if (iter_count % show_every == 0):
print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
show_images(imgs_numpy[0:16])
plt.show()
print()
iter_count += 1
D = discriminator()
G = generator()
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
来源:https://blog.csdn.net/xckkcxxck/article/details/83037025
0
投稿
猜你喜欢
- 比如说点的是图片的左边,还是右边,上边还是下边?点击图片左右显示上下张,我怎么知道?这样就可以做出像QQ空间那样,打开上一个图片和下一个图片
- 本文介绍了python十进制和二进制的转换方法(含浮点数),分享给大家,也给自己留个笔记,具体如下:我终于写完了 , 十进制转二进制的小数部
- system函数 说明:执行外部程序并显示输出资料。 语法:string system(string command, int [retur
- 此文译自Fred Wilson 2010年2月在迈阿密举行的Web未来应用的年会上的演讲谢谢青云推荐了这篇这么好的演说谢谢卓和百忙中抽空帮我
- HTML被直接硬编码在 Python 代码之中。def current_datetime(request): now = dat
- 前一段时间有发过一个简单的JMAIL邮件发邮件的代码,今天就把这个代码做一个具体的注解,并增加了另外两个格式的代码,并举几个简单
- 很多人认为python中的字典是无序的,因为它是按照hash来存储的,但是python中有个模块collections(英文,收集、集合),
- fsockopen函数能够运用,首先要开启php.ini中的allow_url_open=on;fsockopen是对socket客户端代码
- 本文只考虑模板中的字符串,不考虑字符串中带标签的情况。模板中的字符串文字不会自动转义,因为这里默认模板的作者已经正确书写模板的内容。{{ d
- 首先,我要在这里写上一些很官方的概念,意在说明面向对象是很具体化的,很实体的模式,不能让有些人看见“对象&rdq
- 最近在工作中遇到了一个小问题,如果要将字符串型的数据转换成dict类型,我第一时间就想到了使用json函数。但是里面出现了一些问题1、通过j
- 如果直接在命令行中利用input和raw_input读入一个文件来处理,并且想要采用直接将文件拖入命令行来处理的方式,input方法可以直接
- 在我的上篇文章发出之后,我听到对“WEb2.0视觉风格”这个称谓的不认同声音。其实这并不出乎我的意料,因为,我在认真的开始思考“WEb2.0
- import osimport sysimport MySQLdbdef getStatus(conn):  
- python配置文件有.conf,.ini,.txt等多种python集成的 标准库的 ConfigParser 模块提供一套 API 来读
- 一、Ajax简介Ajax被认为是(Asynchronous JavaScript and XML)的缩写,允许浏览器与服务器通信而无需刷新当
- 思考一个问题:怎么实现在第一次检索的基础上进行二次检索?通常,我们的做法是第一次检索时保存检索条件,在第二次行检索时组合两次检索条件对数据库
- 随着网站的内容的增多和用户访问量的增多,无可避免的是网站加载会越来越慢,受限于带宽和服务器同一时间的请求次数的限制,我们往往需要在此时对我们
- 打开php.ini,首先找到file_uploads = on ;是否允许通过HTTP上传文件的开关。默认为ON即是开upload_tmp_
- Python用Pillow(PIL)进行简单的图像操作方法颜色与RGBA值计算机通常将图像表示为RGB值,或者再加上alpha值(通透度,透