Pytorch使用MNIST数据集实现基础GAN和DCGAN详解
作者:shiheyingzhe 发布时间:2021-11-17 02:14:33
原始生成对抗网络Generative Adversarial Networks GAN包含生成器Generator和判别器Discriminator,数据有真实数据groundtruth,还有需要网络生成的“fake”数据,目的是网络生成的fake数据可以“骗过”判别器,让判别器认不出来,就是让判别器分不清进入的数据是真实数据还是fake数据。总的来说是:判别器区分真实数据和fake数据的能力越强越好;生成器生成的数据骗过判别器的能力越强越好,这个是矛盾的,所以只能交替训练网络。
需要搭建生成器网络和判别器网络,训练的时候交替训练。
首先训练判别器的参数,固定生成器的参数,让判别器判断生成器生成的数据,让其和0接近,让判别器判断真实数据,让其和1接近;
接着训练生成器的参数,固定判别器的参数,让生成器生成的数据进入判别器,让判断结果和1接近。生成器生成数据需要给定随机初始值
线性版:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec
def showimg(images,count):
images=images.detach().numpy()[0:16,:]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = int(np.sqrt((images.shape[1])))
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
# gs.update(wspace=0, hspace=0)
print('starting...')
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([width,width]),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
print('showing...')
plt.tight_layout()
plt.savefig('./GAN_Image/%d.png'%count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST图片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Linear(784,300),
nn.LeakyReLU(0.2),
nn.Linear(300,150),
nn.LeakyReLU(0.2),
nn.Linear(150,1),
nn.Sigmoid()
)
def forward(self, x):
x=self.dis(x)
return x
class generator(nn.Module):
def __init__(self,input_size):
super(generator,self).__init__()
self.gen=nn.Sequential(
nn.Linear(input_size,150),
nn.ReLU(True),
nn.Linear(150,300),
nn.ReLU(True),
nn.Linear(300,784),
nn.Tanh()
)
def forward(self, x):
x=self.gen(x)
return x
if __name__=="__main__":
criterion=nn.BCELoss()
num_img=100
z_dimension=100
D=discriminator()
G=generator(z_dimension)
trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
'''
交替训练的方式训练网络
先训练判别器网络D再训练生成器网络G
不同网络的训练次数是超参数
也可以两个网络训练相同的次数
这样就可以不用分别训练两个网络
'''
count=0
#鉴别器D的训练,固定G的参数
epoch = 100
gepoch = 1
for i in range(epoch):
for (img, label) in trainloader:
# num_img=img.size()[0]
real_img=img.view(num_img,-1)#展开为28*28=784
real_label=torch.ones(num_img)#真实label为1
fake_label=torch.zeros(num_img)#假的label为0
#compute loss of real_img
real_out=D(real_img) #真实图片送入判别器D输出0~1
d_loss_real=criterion(real_out,real_label)#得到loss
real_scores=real_out#真实图片放入判别器输出越接近1越好
#compute loss of fake_img
z=torch.randn(num_img,z_dimension)#随机生成向量
fake_img=G(z)#将向量放入生成网络G生成一张图片
fake_out=D(fake_img)#判别器判断假的图片
d_loss_fake=criterion(fake_out,fake_label)#假的图片的loss
fake_scores=fake_out#假的图片放入判别器输出越接近0越好
#D bp and optimize
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad() #判别器D的梯度归零
d_loss.backward() #反向传播
d_optimizer.step() #更新判别器D参数
#生成器G的训练compute loss of fake_img
for j in range(gepoch):
fake_label = torch.ones(num_img) # 真实label为1
z = torch.randn(num_img, z_dimension) # 随机生成向量
fake_img = G(z) # 将向量放入生成网络G生成一张图片
output = D(fake_img) # 经过判别器得到结果
g_loss = criterion(output, fake_label)#得到假的图片与真实标签的loss
#bp and optimize
g_optimizer.zero_grad() #生成器G的梯度归零
g_loss.backward() #反向传播
g_optimizer.step()#更新生成器G参数
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format(
i, epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
showimg(fake_img,count)
# plt.show()
count += 1
这里的图分别是 epoch为0、50、100、150、190的运行结果,可以看到图片中的数字并不单一
卷积版 Deep Convolutional Generative Adversarial Networks:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import matplotlib.gridspec as gridspec
import os
def showimg(images,count):
images=images.to('cpu')
images=images.detach().numpy()
images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = images.shape[2]
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
print(images.shape)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
# print('showing...')
plt.tight_layout()
# plt.savefig('./GAN_Imaget/%d.png'%count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST图片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Conv2d(1,32,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2)),
nn.Conv2d(32,64,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2))
)
self.fc=nn.Sequential(
nn.Linear(7 * 7 * 64, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):
x=self.dis(x)
x=x.view(x.size(0),-1)
x=self.fc(x)
return x
class generator(nn.Module):
def __init__(self,input_size,num_feature):
super(generator,self).__init__()
self.fc=nn.Linear(input_size,num_feature) #1*56*56
self.br=nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.gen=nn.Sequential(
nn.Conv2d(1,50,3,stride=1,padding=1),
nn.BatchNorm2d(50),
nn.ReLU(True),
nn.Conv2d(50,25,3,stride=1,padding=1),
nn.BatchNorm2d(25),
nn.ReLU(True),
nn.Conv2d(25,1,2,stride=2),
nn.Tanh()
)
def forward(self, x):
x=self.fc(x)
x=x.view(x.size(0),1,56,56)
x=self.br(x)
x=self.gen(x)
return x
if __name__=="__main__":
criterion=nn.BCELoss()
num_img=100
z_dimension=100
D=discriminator()
G=generator(z_dimension,3136) #1*56*56
trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
D=D.cuda()
G=G.cuda()
d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
'''
交替训练的方式训练网络
先训练判别器网络D再训练生成器网络G
不同网络的训练次数是超参数
也可以两个网络训练相同的次数,
这样就可以不用分别训练两个网络
'''
count=0
#鉴别器D的训练,固定G的参数
epoch = 100
gepoch = 1
for i in range(epoch):
for (img, label) in trainloader:
# num_img=img.size()[0]
img=Variable(img).cuda()
real_label=Variable(torch.ones(num_img)).cuda()#真实label为1
fake_label=Variable(torch.zeros(num_img)).cuda()#假的label为0
#compute loss of real_img
real_out=D(img) #真实图片送入判别器D输出0~1
d_loss_real=criterion(real_out,real_label)#得到loss
real_scores=real_out#真实图片放入判别器输出越接近1越好
#compute loss of fake_img
z=Variable(torch.randn(num_img,z_dimension)).cuda()#随机生成向量
fake_img=G(z)#将向量放入生成网络G生成一张图片
fake_out=D(fake_img)#判别器判断假的图片
d_loss_fake=criterion(fake_out,fake_label)#假的图片的loss
fake_scores=fake_out#假的图片放入判别器输出越接近0越好
#D bp and optimize
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad() #判别器D的梯度归零
d_loss.backward() #反向传播
d_optimizer.step() #更新判别器D参数
#生成器G的训练compute loss of fake_img
for j in range(gepoch):
fake_label = Variable(torch.ones(num_img)).cuda() # 真实label为1
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成向量
fake_img = G(z) # 将向量放入生成网络G生成一张图片
output = D(fake_img) # 经过判别器得到结果
g_loss = criterion(output, fake_label)#得到假的图片与真实标签的loss
#bp and optimize
g_optimizer.zero_grad() #生成器G的梯度归零
g_loss.backward() #反向传播
g_optimizer.step()#更新生成器G参数
# if ((i+1)%1000==0):
# print("[%d/%d] GLoss: %.5f" % (i + 1, gepoch, g_loss.data[0]))
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format(
i, epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
showimg(fake_img,count)
plt.show()
count += 1
这里的gepoch设置为1,运行39次的结果是:
gepoch设置为2,运行0、25、50、75、100次的结果是:
gepoch设置为3,运行25、50、75次的结果是:
gepoch设置为4,运行0、10、20、30、35次的结果是:
gepoch设置为5,运行0、10、20、25、29次的结果是:
gepoch设置为3,z_dimension设置为190,epoch运行0、10、15、20、25、35的结果是:
可以看到生成的数字基本没有太多的规律,可能最终都是同个数字,不能生成指定的数字,CGAN就很好的解决这个问题,可以生成指定的数字 Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式
来源:https://blog.csdn.net/shiheyingzhe/article/details/83098339


猜你喜欢
- 本文主要通过对海康摄像头进行抓包,模拟发送了udp包,并抓取摄像头返回的数据包,解析并提取相关信息。通过抓包发现,海康摄像头发送、接收数据使
- 本文介绍基于Python中ArcPy模块,对大量不同时相的栅格遥感影像按照其成像时间依次执行批量拼接的方法。在前期的文章Python arc
- 今天遇到一个非常基础的问题,结果搞了好久好久.....赶快写一篇博客记录一下:本来两个不一样的字符串,在if 的条件判断中被判定为True,
- PHP fprintf() 函数实例把一些文本写入到名为 "test.txt" 的文本文件:<?php $numb
- 本文实例讲述了mysql索引基数概念与用法。分享给大家供大家参考,具体如下:Cardinality(索引基数)是mysql索引很重要的一个概
- 我就废话不多说了,直接上代码吧!from PIL import Image# 通道转换def change_image_channels(i
- 1、就按单介绍MySQL服务器的安全基础是:用户应该对他们需要的数据具有适当的访问权,既不能多也不能少。换句话说,用户不能对过多的数据具有过
- 1. 在游戏循环中监听事件事件event:就是游戏启动后,用户针对游戏所做的操作例如:点击关闭按钮,点击鼠标,按下键盘监听:在游戏循环中,判
- easy_thumbnails:A powerful, yet easy to implement thumbnailing applica
- 命令首先数据库迁移的两大命令: python manage.py makemigrations & python manage.py
- 如下所示:file->settings->Editor->General->Console里面的console co
- 实现用户登录并且输入错误三次后锁定该用户我的测试环境,win7,python3.5.1提示输入用户名,和密码判断是否被锁定判断用户名和密码是
- MySQL多字段相同数据去重复MySQL多字段去重复实际上是单字段去重复的衍生,原理就是把多字段数据通过子查询合并为单字段的数据表,再通过单
- 1、前言 MySQL 是完全网络化的跨平台关系型数据库系统,同时是具有客户机/服务器体系结构的分布式数据库管理系统。它具有功能强、使用简便、
- 本文实例为大家分享了python OpenCV实现答题卡识别判卷的具体代码,供大家参考,具体内容如下完整代码:#导入工具包import nu
- 此文从以下几个方面来整理关于分区表的概念及操作:1.表空间及分区表的概念2.表分区的具体作用3.表分区的优缺点4.表分区的几种类型及操作方法
- 有个需求需要把markdown转成html模块,查询了一下刚好有这个模块安装 pip install amrkdown安装完成直接转换并保存
- Bit-Packed Data TypesMySQL有一些存储类型使用一个值中的一些单个的比特位来紧凑的存储数据。纯技术上将,不管是底层的存
- 在许多语言中,你可以轻松地将任何数据类型转换为字符串,只需将其与字符串连接,或者使用类型转换表达式即可。但是,如果你在Go中尝试执行似乎很明
- 本文实例讲述了python获取Linux下文件版本信息、公司名和产品名的方法,分享给大家供大家参考。具体如下:区别于前文所述。本例是在lin