Pytorch使用MNIST数据集实现基础GAN和DCGAN详解
作者:shiheyingzhe 发布时间:2021-11-17 02:14:33
原始生成对抗网络Generative Adversarial Networks GAN包含生成器Generator和判别器Discriminator,数据有真实数据groundtruth,还有需要网络生成的“fake”数据,目的是网络生成的fake数据可以“骗过”判别器,让判别器认不出来,就是让判别器分不清进入的数据是真实数据还是fake数据。总的来说是:判别器区分真实数据和fake数据的能力越强越好;生成器生成的数据骗过判别器的能力越强越好,这个是矛盾的,所以只能交替训练网络。
需要搭建生成器网络和判别器网络,训练的时候交替训练。
首先训练判别器的参数,固定生成器的参数,让判别器判断生成器生成的数据,让其和0接近,让判别器判断真实数据,让其和1接近;
接着训练生成器的参数,固定判别器的参数,让生成器生成的数据进入判别器,让判断结果和1接近。生成器生成数据需要给定随机初始值
线性版:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec
def showimg(images,count):
images=images.detach().numpy()[0:16,:]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = int(np.sqrt((images.shape[1])))
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
# gs.update(wspace=0, hspace=0)
print('starting...')
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape([width,width]),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
print('showing...')
plt.tight_layout()
plt.savefig('./GAN_Image/%d.png'%count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST图片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Linear(784,300),
nn.LeakyReLU(0.2),
nn.Linear(300,150),
nn.LeakyReLU(0.2),
nn.Linear(150,1),
nn.Sigmoid()
)
def forward(self, x):
x=self.dis(x)
return x
class generator(nn.Module):
def __init__(self,input_size):
super(generator,self).__init__()
self.gen=nn.Sequential(
nn.Linear(input_size,150),
nn.ReLU(True),
nn.Linear(150,300),
nn.ReLU(True),
nn.Linear(300,784),
nn.Tanh()
)
def forward(self, x):
x=self.gen(x)
return x
if __name__=="__main__":
criterion=nn.BCELoss()
num_img=100
z_dimension=100
D=discriminator()
G=generator(z_dimension)
trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
'''
交替训练的方式训练网络
先训练判别器网络D再训练生成器网络G
不同网络的训练次数是超参数
也可以两个网络训练相同的次数
这样就可以不用分别训练两个网络
'''
count=0
#鉴别器D的训练,固定G的参数
epoch = 100
gepoch = 1
for i in range(epoch):
for (img, label) in trainloader:
# num_img=img.size()[0]
real_img=img.view(num_img,-1)#展开为28*28=784
real_label=torch.ones(num_img)#真实label为1
fake_label=torch.zeros(num_img)#假的label为0
#compute loss of real_img
real_out=D(real_img) #真实图片送入判别器D输出0~1
d_loss_real=criterion(real_out,real_label)#得到loss
real_scores=real_out#真实图片放入判别器输出越接近1越好
#compute loss of fake_img
z=torch.randn(num_img,z_dimension)#随机生成向量
fake_img=G(z)#将向量放入生成网络G生成一张图片
fake_out=D(fake_img)#判别器判断假的图片
d_loss_fake=criterion(fake_out,fake_label)#假的图片的loss
fake_scores=fake_out#假的图片放入判别器输出越接近0越好
#D bp and optimize
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad() #判别器D的梯度归零
d_loss.backward() #反向传播
d_optimizer.step() #更新判别器D参数
#生成器G的训练compute loss of fake_img
for j in range(gepoch):
fake_label = torch.ones(num_img) # 真实label为1
z = torch.randn(num_img, z_dimension) # 随机生成向量
fake_img = G(z) # 将向量放入生成网络G生成一张图片
output = D(fake_img) # 经过判别器得到结果
g_loss = criterion(output, fake_label)#得到假的图片与真实标签的loss
#bp and optimize
g_optimizer.zero_grad() #生成器G的梯度归零
g_loss.backward() #反向传播
g_optimizer.step()#更新生成器G参数
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format(
i, epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
showimg(fake_img,count)
# plt.show()
count += 1
这里的图分别是 epoch为0、50、100、150、190的运行结果,可以看到图片中的数字并不单一
卷积版 Deep Convolutional Generative Adversarial Networks:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import matplotlib.gridspec as gridspec
import os
def showimg(images,count):
images=images.to('cpu')
images=images.detach().numpy()
images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
images=255*(0.5*images+0.5)
images = images.astype(np.uint8)
grid_length=int(np.ceil(np.sqrt(images.shape[0])))
plt.figure(figsize=(4,4))
width = images.shape[2]
gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
print(images.shape)
for i, img in enumerate(images):
ax = plt.subplot(gs[i])
ax.set_xticklabels([])
ax.set_yticklabels([])
ax.set_aspect('equal')
plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
plt.axis('off')
plt.tight_layout()
# print('showing...')
plt.tight_layout()
# plt.savefig('./GAN_Imaget/%d.png'%count, bbox_inches='tight')
def loadMNIST(batch_size): #MNIST图片的大小是28*28
trans_img=transforms.Compose([transforms.ToTensor()])
trainset=MNIST('./data',train=True,transform=trans_img,download=True)
testset=MNIST('./data',train=False,transform=trans_img,download=True)
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
return trainset,testset,trainloader,testloader
class discriminator(nn.Module):
def __init__(self):
super(discriminator,self).__init__()
self.dis=nn.Sequential(
nn.Conv2d(1,32,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2)),
nn.Conv2d(32,64,5,stride=1,padding=2),
nn.LeakyReLU(0.2,True),
nn.MaxPool2d((2,2))
)
self.fc=nn.Sequential(
nn.Linear(7 * 7 * 64, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):
x=self.dis(x)
x=x.view(x.size(0),-1)
x=self.fc(x)
return x
class generator(nn.Module):
def __init__(self,input_size,num_feature):
super(generator,self).__init__()
self.fc=nn.Linear(input_size,num_feature) #1*56*56
self.br=nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.gen=nn.Sequential(
nn.Conv2d(1,50,3,stride=1,padding=1),
nn.BatchNorm2d(50),
nn.ReLU(True),
nn.Conv2d(50,25,3,stride=1,padding=1),
nn.BatchNorm2d(25),
nn.ReLU(True),
nn.Conv2d(25,1,2,stride=2),
nn.Tanh()
)
def forward(self, x):
x=self.fc(x)
x=x.view(x.size(0),1,56,56)
x=self.br(x)
x=self.gen(x)
return x
if __name__=="__main__":
criterion=nn.BCELoss()
num_img=100
z_dimension=100
D=discriminator()
G=generator(z_dimension,3136) #1*56*56
trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
D=D.cuda()
G=G.cuda()
d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
'''
交替训练的方式训练网络
先训练判别器网络D再训练生成器网络G
不同网络的训练次数是超参数
也可以两个网络训练相同的次数,
这样就可以不用分别训练两个网络
'''
count=0
#鉴别器D的训练,固定G的参数
epoch = 100
gepoch = 1
for i in range(epoch):
for (img, label) in trainloader:
# num_img=img.size()[0]
img=Variable(img).cuda()
real_label=Variable(torch.ones(num_img)).cuda()#真实label为1
fake_label=Variable(torch.zeros(num_img)).cuda()#假的label为0
#compute loss of real_img
real_out=D(img) #真实图片送入判别器D输出0~1
d_loss_real=criterion(real_out,real_label)#得到loss
real_scores=real_out#真实图片放入判别器输出越接近1越好
#compute loss of fake_img
z=Variable(torch.randn(num_img,z_dimension)).cuda()#随机生成向量
fake_img=G(z)#将向量放入生成网络G生成一张图片
fake_out=D(fake_img)#判别器判断假的图片
d_loss_fake=criterion(fake_out,fake_label)#假的图片的loss
fake_scores=fake_out#假的图片放入判别器输出越接近0越好
#D bp and optimize
d_loss=d_loss_real+d_loss_fake
d_optimizer.zero_grad() #判别器D的梯度归零
d_loss.backward() #反向传播
d_optimizer.step() #更新判别器D参数
#生成器G的训练compute loss of fake_img
for j in range(gepoch):
fake_label = Variable(torch.ones(num_img)).cuda() # 真实label为1
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 随机生成向量
fake_img = G(z) # 将向量放入生成网络G生成一张图片
output = D(fake_img) # 经过判别器得到结果
g_loss = criterion(output, fake_label)#得到假的图片与真实标签的loss
#bp and optimize
g_optimizer.zero_grad() #生成器G的梯度归零
g_loss.backward() #反向传播
g_optimizer.step()#更新生成器G参数
# if ((i+1)%1000==0):
# print("[%d/%d] GLoss: %.5f" % (i + 1, gepoch, g_loss.data[0]))
print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
'D real: {:.6f}, D fake: {:.6f}'.format(
i, epoch, d_loss.data[0], g_loss.data[0],
real_scores.data.mean(), fake_scores.data.mean()))
showimg(fake_img,count)
plt.show()
count += 1
这里的gepoch设置为1,运行39次的结果是:
gepoch设置为2,运行0、25、50、75、100次的结果是:
gepoch设置为3,运行25、50、75次的结果是:
gepoch设置为4,运行0、10、20、30、35次的结果是:
gepoch设置为5,运行0、10、20、25、29次的结果是:
gepoch设置为3,z_dimension设置为190,epoch运行0、10、15、20、25、35的结果是:
可以看到生成的数字基本没有太多的规律,可能最终都是同个数字,不能生成指定的数字,CGAN就很好的解决这个问题,可以生成指定的数字 Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式
来源:https://blog.csdn.net/shiheyingzhe/article/details/83098339
猜你喜欢
- 一.概念简介 脚本:script是使用一种特定的描述性语言,依据一定的格式编写的可执行文件,又称作宏或批处理文件。 二.背景 近来在Wind
- 本文实例总结了php处理json格式数据的方法。分享给大家供大家参考,具体如下:1.json简介:何为json?简 单地说,JSON 可以将
- 测试sql: 代码如下:SET STATISTICS IO ON SET STATISTICS TIME ON SELECT COUNT(1
- 在Web上使用菜单可以极大地节约页面的空间,同时也比较的符合用户从Windows上继承下来的UI操作体验。在以往的Web页菜单设计中,我们普
- 在Google上搜一下,可以发现一大堆对ASP不好的评价,什么运行速度慢、异常处理机制不好、缺乏面向对象机制、开发效率低、漏洞多等等。为了让
- 分析古诗文网站下图1展示了古诗文网站—》诗文 栏目的首页数据。该栏目的地址是:https://so.gushiwen.cn/shiwens/
- 我们一般在调试程序的时候,有些操作会莫名地失败,又没有错误消息提示,特别是在执行数据库操作的时候,明明执行过去了,可就是数据库里没有记录变动
- counter 是一种特殊的字典,主要方便用来计数,key 是要计数的 item,value 保存的是个数。from collections
- 对于变量的访问和设置,我们可以使用get、set方法,如下:class student: def __init__(self,n
- HTTP_X_FORWARDED_FOR与REMOTE_ADDR的区别.在Request.ServerVariables中并没有HTTP_X
- 人们很容易忽视图像img标签的alt属性。然而,它的重要性也无法体现出来,它是有利于网页的accessibility and&nb
- 自己有一套模块化的思路,想搜索一下有没有共鸣结果排名靠前的是通过class拼凑页面的想法。模块化是twinsen提出来的,从我接收第一个po
- 1、查找表结构,判断要加入的列是否已存在2、如果不存在,则执行添加 CREATE PROCEDURE `mysql_sp_add_
- 我见到有的网站好像可以把数据库的记录读到表格里去,是这样的吗?如何做到的?可能是这样的,因为我们确实能把数据库里的记录用表格来储存,看看下面
- asp使用fso读取驱动器信息:<%vv=drive()response.write vv funct
- 以前写过一个刷校内网的人气的工具,Java的(以后再也不行Java程序了),里面用到了验证码识别,那段代码不是我自己写的:-) 校内的验证是
- 0x01 安装pyinotify>>> pip install pyinotify>>> import
- 本文主要是基于Python Opencv 实现的图像分割,其中使用到的opencv的函数有:使用 OpenCV 函数 cv::filter2
- 如果是在Oracle10g之前,删除一个表空间中的数据文件后,其文件在数据库数据字典中会仍然存在,除非你删除表空间,否则文件信息不会清除。但
- CSS 文件的大小和所引起的 HTTP 的请求数是 CSS 性能的最关键因素回流(reflow)和渲染时间(非常!)没那么重要副本(dupl