chenliu-1996 / gan-evaluator

A pip-installable evaluator for GANs (IS and FID). Accepts either dataloaders or individual batches. Supports on-the-fly evaluation during training. A working DCGAN SVHN demo script provided.

License: MIT License

Python 100.00%
fid gan inception-score deep-learning evaluator frechet-inception-distance generative-adversarial-network metrics dcgan fid-score generation pytorch pytorch-gan pytorch-implmention generative-model torchvision evaluation-framework evaluation-metrics svhn-dataset gan-evaluation

gan-evaluator's Introduction

GAN Evaluator for Inception Score (IS) and Frechet Inception Distance (FID) in PyTorch


Please kindly star this repo for better reach if you find it useful. Let's help out the community!

Main Contributions

  1. We created a GAN evaluator for IS and FID that
    • is easy to use,
    • accepts data as either dataloaders or individual batches, and
    • supports on-the-fly evaluation during training.
  2. We provided a simple demo script to demonstrate one common use case.

NEWS

[Feb 18, 2023]

Now available on PyPI! You can pip install it into your desired environment via:

pip install gan-evaluator

And in your Python project, wherever you need the GAN_Evaluator, you can import via:

from gan_evaluator import GAN_Evaluator

NOTE 1: You no longer need to copy any code from this repo in order to use GAN_Evaluator! At this point, the primary purpose of this repo is description and demonstration. That said, you are welcome to clone this repo and try out the demo script. You may also find it easier to copy and modify the code if you want slightly different behaviors.

NOTE 2: During pip install gan-evaluator, the dependencies of GAN_Evaluator (but not of the demo script) are also installed.

Demo Script: Use DCGAN to generate SVHN digits

The script can be found in src/train_dcgan_svhn.py

  • Usage from the demo script, to give you a taste. A consolidated sketch of the full flow follows after this list.

    Declaration

    evaluator = GAN_Evaluator(device=device,
                              num_images_real=len(train_loader.dataset),
                              num_images_fake=len(train_loader.dataset))
    

    Before the training loop

    evaluator.load_all_real_imgs(real_loader=train_loader, idx_in_loader=0)
    

    Inside the training loop

    if shall_plot:
        IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(fake_batch=x_fake)
    else:
        evaluator.fill_fake_img_batch(fake_batch=x_fake, return_results=False)
    

    After each epoch of training

    evaluator.clear_fake_imgs()
    
  • Some visualizations of the demo script:

    • Real (top) and Generated (bottom) images.
    • IS and FID curves.
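
Putting the fragments above together, here is a consolidated sketch of the full flow. The evaluator calls are taken verbatim from the demo; `netG`, `nz`, `num_epochs`, and `shall_plot` are placeholder names for your generator, its latent dimension, and your training/logging schedule:

evaluator = GAN_Evaluator(device=device,
                          num_images_real=len(train_loader.dataset),
                          num_images_fake=len(train_loader.dataset))

# Pre-load all real images once, before the training loop.
evaluator.load_all_real_imgs(real_loader=train_loader, idx_in_loader=0)

for epoch in range(num_epochs):
    for x_real, _ in train_loader:
        # ... your usual D/G updates; the generator produces `x_fake` ...
        x_fake = netG(torch.randn(x_real.shape[0], nz, 1, 1, device=device))
        if shall_plot:
            IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(fake_batch=x_fake)
        else:
            evaluator.fill_fake_img_batch(fake_batch=x_fake, return_results=False)
    # The fake-image pool is per-epoch; clear it before the next epoch.
    evaluator.clear_fake_imgs()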

Details: The Evaluator for IS and FID

Introduction to the Evaluator

More details can be found in src/utils/gan_evaluator.py/GAN_Evaluator.

This evaluator computes the following metrics:
    - Inception Score (IS)
    - Frechet Inception Distance (FID)

This evaluator will take in the real images and the fake/generated images.
Then it will compute the activations from the real and fake images as well as the
predictions from the fake images.
The (fake) predictions will be used to compute IS, while
the (real, fake) activations will be used to compute FID.
If input image resolution < 75 x 75, we will upsample the image to accommodate Inception v3.

The real and fake images can be provided to this evaluator in either of the following formats:
1. dataloader
    `load_all_real_imgs`
    `load_all_fake_imgs`
2. per-batch
    `fill_real_img_batch`
    `fill_fake_img_batch`

!!! Please note: the latest IS and FID will be returned upon completion of either of the following:
    `load_all_fake_imgs`
    `fill_fake_img_batch`
Return format:
    (IS mean, IS std, FID)
*So please make sure you load real images before the fake images.*
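
For example, when feeding both sides per-batch (a minimal sketch; `real_batch` and `fake_batch` stand for image tensors of shape [N, C, H, W]):

# Real batches must be filled before fake batches.
evaluator.fill_real_img_batch(real_batch)
# Filling a fake batch returns the latest metrics as (IS mean, IS std, FID).
IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(fake_batch=fake_batch)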

Common Use Cases:
1. For the purpose of on-the-fly evaluation during GAN training:
    We recommend pre-loading the real images using the dataloader format, and
    populating the fake images using the per-batch format as training goes on.
    - At the end of each epoch, you can clean the fake images using:
        `clear_fake_imgs`
    - In *unusual* cases where your real images change (such as in progressive growing GANs),
    you may want to clear the real images. You can do so via:
        `clear_real_imgs`
2. For the purpose of offline evaluation of a saved dataset:
    We recommend pre-loading the real images and fake images.
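
For use case 2, an offline-evaluation sketch could look like the following (assuming `load_all_fake_imgs` mirrors the signature of `load_all_real_imgs`; check src/utils/gan_evaluator.py for the exact arguments):

evaluator = GAN_Evaluator(device=device,
                          num_images_real=len(real_loader.dataset),
                          num_images_fake=len(fake_loader.dataset))

# Real images first, then fake images; loading the fake images returns the metrics.
evaluator.load_all_real_imgs(real_loader=real_loader, idx_in_loader=0)
IS_mean, IS_std, FID = evaluator.load_all_fake_imgs(fake_loader=fake_loader, idx_in_loader=0)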

Repository Hierarchy

GAN-evaluator
    ├── config
    |   └── `dcgan_svhn.yaml`
    ├── data (*)
    ├── debug_plot (*)
    ├── logs (*)
    └── src
        ├── utils
        |   ├── `gan_evaluator.py`: THIS CONTAINS OUR `GAN_Evaluator`.
        |   └── other utility files.
        └── `train_dcgan_svhn.py`: our demo script.

Folders marked with (*), if they do not already exist, will be created automatically when you run train_dcgan_svhn.py.

Usage

  • To integrate our evaluator into your existing project, you can simply install via pip!

  • You can refer to the demo script to see how to interface with the evaluator. Briefly, you declare the evaluator, feed in the real images, and feed in the fake images. In most cases, you would feed in the real images in one pass, whereas the fake images are fed in on the fly as the model trains and generates them. Each time you feed in fake images, you can choose whether or not to compute the metrics. Lastly, don't forget to clear the fake images at the end of each epoch.

  • To run our demo script, do the following after activating the proper environment.

git clone git@github.com:ChenLiu-1996/GAN-evaluator.git
cd GAN-evaluator/src
python train_dcgan_svhn.py --config ../config/dcgan_svhn.yaml

Citation

To be added.

Environment Setup

Packages Needed

The GAN_Evaluator module itself only uses numpy, scipy, torch, torchvision, and (for aesthetics) tqdm.

To run the example script, it additionally requires matplotlib and pyyaml (argparse ships with the Python standard library).

On our Yale Vision Lab server
  • There is a virtualenv ready to use, located at /media/home/chliu/.virtualenv/mondi-image-gen/.

  • Alternatively, you can start from an existing environment "torch191-py38env", and install the following packages:

python3 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
python3 -m pip install wget gdown numpy matplotlib pyyaml click scipy yacs scikit-learn

If you see error messages such as "Failed to build CUDA kernels for bias_act.", you can fix it with:

python3 -m pip install ninja

Acknowledgements

  1. The code for the GAN_Evaluator (specifically, the computation of IS and FID) is inspired by:
  2. The code for the demo script (specifically, architecture and training of DCGAN) is inspired by:

gan-evaluator's People

Contributors

chenliu-1996

gan-evaluator's Issues

Error

Hello. I tried using the GAN evaluator on the CIFAR10 dataset and encountered some errors.
1%|          | 3/391 [00:16<35:33,  5.50s/it]
Traceback (most recent call last):
  File "D:\python_parctice\pt\FID\FID_MODLE.py", line 204, in <module>
    evaluator.load_all_real_imgs(real_loader=dataloader, idx_in_loader=0)
  File "D:\Anaconda\install\envs\pt\lib\site-packages\gan_evaluator.py", line 176, in load_all_real_imgs
    self.fill_real_img_batch(real_batch)
  File "D:\Anaconda\install\envs\pt\lib\site-packages\gan_evaluator.py", line 136, in fill_real_img_batch
    self.activation_vec_real[self.vec_real_pointer:self.vec_real_pointer +
ValueError: could not broadcast input array from shape (128,2048) into shape (7,2048)
Here is my code:

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision.models as models
import matplotlib.animation as animation
from IPython.display import HTML

from attribute_hashmap import AttributeHashmap
from gan_evaluator import GAN_Evaluator
from log_utils import log
from seed import seed_everything

from scipy import linalg
from torch.nn.functional import adaptive_avg_pool2d

from PIL import Image

import matplotlib.pyplot as plt
import sys
import numpy as np
import os

print(os.listdir("../input"))

import time

SEED=42
random.seed(SEED)
torch.manual_seed(SEED)

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 70

# Different learning rate for each optimizer
g_lr = 0.0001
d_lr = 0.0004

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
ngpu = 1

# Normalizing input between -1 and 1
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0, 0, 0), (1, 1, 1)),
])

dataset = dset.CIFAR10(root="./input/", train=True,
                       download=True, transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=2)

# Decide which device we want to run on

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images

real_batch = next(iter(dataloader))

plt.figure(figsize=(8,8))

plt.axis("off")

plt.title("Training Images")

plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

# Custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Create the generator

netG = Generator(ngpu).to(device)
netG.apply(weights_init)
print(netG)

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Apply the weights_init function to randomly initialize all weights
# to mean=0, stdev=0.02.
netD.apply(weights_init)

# Print the model
print(netD)

# Initialize BCELoss function
criterion = nn.BCELoss()

# Establish convention for real and fake labels during training
real_label = 1
fake_label = 0

"""adding label smoothing"""
real_label = 0.9
fake_label = 0.1

# Setup Adam optimizers for both G and D

optimizerD = optim.Adam(netD.parameters(), lr=d_lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=g_lr, betas=(beta1, 0.999))

img_list = []
G_losses = []
D_losses = []
iters = 0
epoch_list, IS_list, FID_list = [], [], []
print("Starting Training Loop...")
if __name__ == '__main__':
    # Our GAN Evaluator.
    evaluator = GAN_Evaluator(device=device,
                              num_images_real=len(dataloader),
                              num_images_fake=len(dataloader))

    # We can pre-load the real images in the format of a dataloader.
    # Of course you can do that in individual batches, but this way is neater.
    # Because in CIFAR10, each batch contains a (image, label) pair, we set `idx_in_loader` = 0.
    # If we only have images in the dataloader, we can set `idx_in_loader` = None.
    evaluator.load_all_real_imgs(real_loader=dataloader, idx_in_loader=0)
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        #         # add some noise to the input to discriminator

        real_cpu = 0.9 * real_cpu + 0.1 * torch.randn((real_cpu.size()), device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        with torch.no_grad():
           fake = netG(noise)
           IS_mean, IS_std, FID = evaluator.fill_fake_img_batch(
               fake_batch=fake)
           epoch_list.append(epoch + i / len(dataloader))
           IS_list.append(IS_mean)
           FID_list.append(FID)
        label.fill_(fake_label)

        fake = 0.9 * fake + 0.1 * torch.randn((fake.size()), device=device)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        D_G_z2 = output.mean().item()

        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()
        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fixed_noise = torch.randn(ngf, nz, 1, 1, device=device)
                fake_display = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake_display, padding=2, normalize=True))

        iters += 1
    # Update the IS and FID curves every epoch.
    fig = plt.figure(figsize=(10, 4))
    ax = fig.add_subplot(1, 2, 1)
    ax.scatter(epoch_list, IS_list, color='firebrick')
    ax.plot(epoch_list, IS_list, color='firebrick')
    ax.set_ylabel('Inception Score (IS)')
    ax.set_xlabel('Epoch')
    ax.spines[['right', 'top']].set_visible(False)
    ax = fig.add_subplot(1, 2, 2)
    ax.scatter(epoch_list, FID_list, color='firebrick')
    ax.plot(epoch_list, FID_list, color='firebrick')
    ax.set_ylabel('Frechet Inception Distance (FID)')
    ax.set_xlabel('Epoch')
    ax.spines[['right', 'top']].set_visible(False)
    plt.tight_layout()
    plt.close(fig=fig)

    # Need to clear up the fake images every epoch.
    evaluator.clear_fake_imgs()

    G_losses.append(errG.item())
    D_losses.append(errD.item())
