PyTorch Study Notes: A Minimal Introduction to DCGAN

Author: xz1308579340  Published: 2022-05-28 17:29:02

Tags: PyTorch, DCGAN

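1. Image classification network

This is a binary classifier. It can be any of AlexNet, VGG, or ResNet, and its job is to decide whether an image is a real image or a generated one.

2. Image generation network

Its input is random noise and its output is an image. It is built from transposed convolution (deconvolution) layers.

Anyone who has studied deep learning can probably write these two networks. If you cannot, no problem: someone has already written them for you.

First, the image classification network:

Put simply, it is CNN + LeakyReLU + sigmoid, with batch normalization in the middle layers. It can be replaced by any classification network, such as VGG or ResNet.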


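import torch
import torch.nn as nn

# Assumed values (not given in the original post); they match the official PyTorch DCGAN tutorial
nc = 3    # number of channels in the images (3 for RGB)
ndf = 64  # base number of feature maps in the discriminator

class Discriminator(nn.Module):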
   def __init__(self, ngpu):
       super(Discriminator, self).__init__()
       self.ngpu = ngpu
       self.main = nn.Sequential(
           # input is (nc) x 64 x 64
           nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf) x 32 x 32
           nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 2),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf*2) x 16 x 16
           nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 4),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf*4) x 8 x 8
           nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ndf * 8),
           nn.LeakyReLU(0.2, inplace=True),
           # state size. (ndf*8) x 4 x 4
           nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
           nn.Sigmoid()
       )
   def forward(self, input):
       return self.main(input)
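
The state sizes noted in the comments can be checked by hand. The sketch below is plain Python (no PyTorch needed) and applies the standard convolution output-size formula, floor((size + 2 * padding - kernel) / stride) + 1, to the five Conv2d layers above. The helper name `conv_out` is made up for this illustration.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution along one spatial dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the five Conv2d layers in the Discriminator
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # the network expects 64 x 64 input images
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)
```

Each of the first four layers halves the height and width. The last layer collapses the 4 x 4 map to a single value per image, which the sigmoid turns into the probability that the image is real.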

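The generator network is the important part.

The code is below. It is essentially transposed convolution + BN + ReLU, with a Tanh on the output layer.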


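# Assumed values (not given in the original post); they match the official PyTorch DCGAN tutorial
nz = 100  # length of the latent noise vector z
ngf = 64  # base number of feature maps in the generator

class Generator(nn.Module):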
   def __init__(self, ngpu):
       super(Generator, self).__init__()
       self.ngpu = ngpu
       self.main = nn.Sequential(
           # input is Z, going into a convolution
           nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
           nn.BatchNorm2d(ngf * 8),
           nn.ReLU(True),
           # state size. (ngf*8) x 4 x 4
           nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ngf * 4),
           nn.ReLU(True),
           # state size. (ngf*4) x 8 x 8
           nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ngf * 2),
           nn.ReLU(True),
           # state size. (ngf*2) x 16 x 16
           nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
           nn.BatchNorm2d(ngf),
           nn.ReLU(True),
           # state size. (ngf) x 32 x 32
           nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
           nn.Tanh()
           # state size. (nc) x 64 x 64
       )
   def forward(self, input):
       return self.main(input)
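
To see how a 1 x 1 noise vector grows into a 64 x 64 image, the sketch below traces the spatial size through the five ConvTranspose2d layers. It is plain Python and uses the transposed-convolution size formula (size - 1) * stride - 2 * padding + kernel, which holds when output_padding is 0 and dilation is 1, as in the code above. The helper name `conv_transpose_out` is hypothetical.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution along one spatial dimension
    (assumes output_padding=0 and dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) of the five ConvTranspose2d layers in the Generator
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise z has shape (nz) x 1 x 1
sizes = [size]
for kernel, stride, padding in layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    sizes.append(size)

print(sizes)
```

The first layer expands the noise to a 4 x 4 map, and each later layer doubles the height and width, mirroring the discriminator in reverse.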

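To be fair, both networks above are quite simple.

Now for the real point: how to train them.

Each training step consists of three sub-steps:

  • Train the binary classifier
       1. Feed in real images and run them through the classifier, which should judge them as real; update the classifier.
       2. Feed noise through the generator to produce an image, then pass that image to the classifier, which should judge it as fake; update the classifier.

  • Train the generator
       3. Feed noise through the generator to produce an image, then pass that image to the classifier, which should now judge it as real; update the generator.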
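
The three sub-steps differ only in which image the classifier sees and which label it is compared against. The sketch below works through them with binary cross-entropy on single numbers, in plain Python. The scores 0.9 and 0.2 are invented for illustration, and `bce` is a hypothetical helper, not a PyTorch function.

```python
import math

def bce(prediction, target):
    """Binary cross-entropy for one prediction in (0, 1) and a target of 0 or 1."""
    return -(target * math.log(prediction) + (1 - target) * math.log(1 - prediction))

REAL, FAKE = 1.0, 0.0

# Hypothetical classifier outputs, chosen only for illustration
d_on_real = 0.9  # D(x): score for a real image
d_on_fake = 0.2  # D(G(z)): score for a generated image

loss_d_real = bce(d_on_real, REAL)  # sub-step 1: real image, target "real"
loss_d_fake = bce(d_on_fake, FAKE)  # sub-step 2: generated image, target "fake"
loss_g = bce(d_on_fake, REAL)       # sub-step 3: generated image, target "real"

print(round(loss_d_real, 4), round(loss_d_fake, 4), round(loss_g, 4))
```

Sub-steps 2 and 3 score the same kind of image against opposite labels: a score of 0.2 is a small loss for the classifier but a large loss for the generator, which is what makes the two networks compete.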

Without further ado, here is the code:


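import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils

# Setup assumed by the loop (not shown in the original post). The values follow the
# official PyTorch DCGAN tutorial and should be treated as starting points to tune.
# `dataloader` must already exist and yield (images, labels) batches of 64 x 64 images.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
netD = Discriminator(ngpu=1).to(device)
netG = Generator(ngpu=1).to(device)
criterion = nn.BCELoss()
real_label, fake_label = 1.0, 0.0
fixed_noise = torch.randn(64, nz, 1, 1, device=device)  # fixed input to track G's progress
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
num_epochs = 5
img_list, G_losses, D_losses, iters = [], [], [], 0

for epoch in range(num_epochs):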
   # For each batch in the dataloader
   for i, data in enumerate(dataloader, 0):
       ############################
       # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
       ###########################
       ## Train with all-real batch
       netD.zero_grad()
       # Format batch
       real_cpu = data[0].to(device)
       b_size = real_cpu.size(0)
       # dtype must be float: BCELoss rejects integer targets
       label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
       # Forward pass real batch through D
       output = netD(real_cpu).view(-1)
       # Calculate loss on all-real batch
       errD_real = criterion(output, label)
       # Calculate gradients for D in backward pass
       errD_real.backward()
       D_x = output.mean().item()
       ## Train with all-fake batch
       # Generate batch of latent vectors
       noise = torch.randn(b_size, nz, 1, 1, device=device)
       # Generate fake image batch with G
       fake = netG(noise)
       label.fill_(fake_label)
       # Classify all fake batch with D
       output = netD(fake.detach()).view(-1)
       # Calculate D's loss on the all-fake batch
       errD_fake = criterion(output, label)
       # Calculate the gradients for this batch
       errD_fake.backward()
       D_G_z1 = output.mean().item()
       # Add the gradients from the all-real and all-fake batches
       errD = errD_real + errD_fake
       # Update D
       optimizerD.step()
       ############################
       # (2) Update G network: maximize log(D(G(z)))
       ###########################
       netG.zero_grad()
       label.fill_(real_label)  # fake labels are real for generator cost
       # Since we just updated D, perform another forward pass of all-fake batch through D
       output = netD(fake).view(-1)
       # Calculate G's loss based on this output
       errG = criterion(output, label)
       # Calculate gradients for G
       errG.backward()
       D_G_z2 = output.mean().item()
       # Update G
       optimizerG.step()
       # Output training stats
       if i % 50 == 0:
           print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                 % (epoch, num_epochs, i, len(dataloader),
                    errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
       # Save Losses for plotting later
       G_losses.append(errG.item())
       D_losses.append(errD.item())
       # Check how the generator is doing by saving G's output on fixed_noise
       if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
           with torch.no_grad():
               fake = netG(fixed_noise).detach().cpu()
           img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
       iters += 1

Source: https://blog.csdn.net/xz1308579340/article/details/105883090
