Pytorch学习笔记DCGAN极简入门教程
作者:xz1308579340 发布时间:2022-05-28 17:29:02
标签:Pytorch,DCGAN
1.图片分类网络
这是一个二分类网络,可以是alxnet ,vgg,resnet任何一个,负责对图片进行二分类,区分图片是真实图片还是生成的图片
2.图片生成网络
输入是一个随机噪声,输出是一张图片,使用的是反卷积层
相信学过深度学习的都能写出这两个网络,当然如果你写不出来,没关系,有人替你写好了
首先是图片分类网络:
简单来说就是cnn+relu+sogmid,可以换成任何一个分类网络,比如bgg,resnet等
class Discriminator(nn.Module):
def __init__(self, ngpu):
super(Discriminator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)
重点是生成网络
代码如下,其实就是反卷积+bn+relu
class Generator(nn.Module):
def __init__(self, ngpu):
super(Generator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
def forward(self, input):
return self.main(input)
讲道理,以上两个网络都挺简单。
真正的重点到了,怎么训练
每一个step分为三个步骤:
训练二分类网络
1.输入真实图片,经过二分类,希望判定为真实图片,更新二分类网络
2.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为虚假图片,更新二分类网络训练生成网络
3.输入噪声,进过生成网络,生成一张图片,输入二分类网络,希望判定为真实图片,更新生成网络
不多说直接上代码
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
netD.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()
## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch
errD_fake.backward()
D_G_z1 = output.mean().item()
# Add the gradients from the all-real and all-fake batches
errD = errD_real + errD_fake
# Update D
optimizerD.step()
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()
# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
# Save Losses for plotting later
G_losses.append(errG.item())
D_losses.append(errD.item())
# Check how the generator is doing by saving G's output on fixed_noise
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
来源:https://blog.csdn.net/xz1308579340/article/details/105883090


猜你喜欢
- 1.彻底弄懂CSS盒子模式一(DIV布局快速入门) 2.彻底弄懂CSS盒子模式二(导航栏实例) 3.彻底弄懂CSS盒子模式三(浮动的表演和清
- 如果你之前没用过进度条,八成是觉得它会增加不必要的复杂性或者很难维护,其实不然。要加一个进度条其实只需要几行代码。在这几行代码中,我们可以看
- oracle如果存储过程比较复杂,我们要定位到错误就比较困难,那么可以存储过程的调试功能先按简单的存储过程做个例子,就是上次做的存储过程(p
- 导语小伙伴们大家好~如今的游戏可谓是层出不穷,NBA 2K系列啊,FIFA系列啊更是经典中的经典,不过小编发现,赛车游戏也是深受大家欢迎啊,
- Ewebeditor及fckeditork,90%的网站都是采用这两种编辑器作为产品或者内容的说明部分的编辑窗口,近日,一客户的外贸站点基本
- 至于对好广告的评判,不同的人有不同的标准,一些人认为那些打动人、有新意、有共鸣的广告是好广告,另一些人的观点则是:观众喜欢与否,不是广告好与
- 如果你经常使用python开发GUI程序的话,那么就知道,有时你需要很长时间来执行一个任务。当然,如果你使用命令行程序来做的话,你回非常惊讶
- 如果对自然语言分类,有很多中分法,比如英语、法语、汉语等,这种分法是最常见的。在语言学里面,也有对语言的分类方法,比如什么什么语系之类的。我
- 本文实例为大家分享了Python socket实现简单聊天室的具体代码,供大家参考,具体内容如下服务端使用了select模块,实现了对多个s
- 全局变量与局部变量# num1是全局变量num1 = 1# num2是局部变量def func():num2 = 2在函数外(且不在函数里)
- 本文实例讲述了Python编程实现数学运算求一元二次方程的实根算法。分享给大家供大家参考,具体如下:问题:请定义一个函数quadratic(
- 昨天去面试,百度题果然不一样,笔试我就蒙了,现在能记住两道题,笔试:1、title和alt 区别2、三列布局 左边裂固定宽度左对齐,右边列固
- 数据库隔离级别有四种,应用《高性能mysql》一书中的说明:然后说说修改事务隔离级别的方法:1.全局修改,修改mysql.ini配置文件,在
- 一、方法2此方法是两个表构建某一相同字段,然后全连接,在做匹配结果筛选,此方法针对数据量不大的时候,逻辑比较简单,但是内存消耗较大1. 导入
- 前言刚接触golang不久,有些环境无法融会贯通,现在针对开发过程中遇到的问题做个排查记录问题背景开发环境区分不同网段,同一个程序引入到另一
- 在新旧版的torch中的傅里叶变换函数在定义和用法上存在不同,记录一下。1、旧版fft = torch.rfft(input, 2, nor
- 《lnmp一键安装包》中需要获取ip地址,有2种情况:如果服务器只有私网地址没有公网地址,这个时候获取的IP(即私网地址)不能用来判断服务器
- 现将几种主要情况进行小结: 一、如何输入NULL值 如果不输入null值,当时间为空时,会默认写入"1900-01-01"
- 当where子句对某一列使用函数时,除非利用这个简单的技术强制索引,否则Oracle优化器不能在查询中使用索引。通常情况下,如果在WHERE
- form无论是在网站的制作中,还是在网站的重构中,我们都会频繁地“碰面”,当“碰面”的次数多了,反而觉得他更让人迷茫,有种熟悉的“陌生”,越