

Pytorch Implementation of Common Image Generation Metrics


Installation

pip install pytorch-image-generation-metrics

Quick Start

from pytorch_image_generation_metrics import get_inception_score, get_fid

images = ... # [N, 3, H, W] normalized to [0, 1]
IS, IS_std = get_inception_score(images)        # Inception Score
FID = get_fid(images, 'path/to/fid_ref.npz') # Frechet Inception Distance

The file path/to/fid_ref.npz is compatible with the official FID implementation.
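The official statistics files (for example, those released with TTUR) are NumPy .npz archives that store the feature mean under the key mu and the feature covariance under sigma. That this package reads the same keys follows from the compatibility note but is an assumption here. A minimal sketch with a hypothetical load_fid_ref helper and a tiny stand-in file:

```python
import os
import tempfile

import numpy as np


def load_fid_ref(path):
    """Load FID reference statistics, assuming the keys 'mu' and 'sigma'."""
    # mu: mean of the Inception features, sigma: their covariance matrix
    with np.load(path) as f:
        mu, sigma = f['mu'][:], f['sigma'][:]
    return mu, sigma


# Round-trip a tiny stand-in file (real references use 2048-d Inception features).
path = os.path.join(tempfile.mkdtemp(), 'fid_ref.npz')
np.savez(path, mu=np.zeros(4), sigma=np.eye(4))
mu, sigma = load_fid_ref(path)
print(mu.shape, sigma.shape)  # (4,) (4, 4)
```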

Notes

The FID implementation is inspired by pytorch-fid.

This repository is developed for personal research. If you find this package useful, please feel free to open issues.

Features

  • Currently, this package supports the following metrics: Inception Score (IS) and Frechet Inception Distance (FID).
  • The computation procedures for IS and FID are integrated to avoid multiple forward passes.
  • Supports reading images on the fly to prevent out-of-memory issues, especially for large-scale images.
  • Supports running matrix operations that normally run on the CPU, such as np.cov and scipy.linalg.sqrtm, on the GPU for speed.
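Both metrics can come from a single forward pass because they read the same InceptionV3 outputs: IS uses the softmax class probabilities, and FID uses the 2048-dimensional pool features. Below is a numpy-only sketch of the two formulas, not the package's code. The 10 splits follow common practice. For the matrix square root, the sketch replaces scipy.linalg.sqrtm with an eigenvalue identity: for covariance matrices, sigma1 @ sigma2 has real, non-negative eigenvalues, so the trace of its square root is the sum of their square roots.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """IS from softmax probabilities [N, C]; returns (mean, std) over splits."""
    n = probs.shape[0]
    scores = []
    for k in range(splits):
        part = probs[k * n // splits:(k + 1) * n // splits]
        py = part.mean(axis=0, keepdims=True)  # marginal class distribution p(y) of the split
        # Per-image KL(p(y|x) || p(y)), averaged over the split and exponentiated
        kl = (part * (np.log(part + eps) - np.log(py + eps))).sum(axis=1)
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))


def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))."""
    diff = mu1 - mu2
    # Eigenvalues of sigma1 @ sigma2 are real and >= 0 up to rounding, so the
    # trace of its square root is the sum of their square roots.
    eig = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eig.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


# Random stand-ins for one forward pass: class probabilities and pool features.
rng = np.random.default_rng(0)
logits = rng.normal(size=(1000, 10))
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
feats = rng.normal(size=(1000, 16))
mu, sigma = feats.mean(axis=0), np.cov(feats, rowvar=False)
is_mean, is_std = inception_score(probs)
fid_self = fid_from_stats(mu, sigma, mu, sigma)  # identical statistics give FID close to 0
```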

Reproducing Results of Official Implementations on CIFAR-10

                        Train IS      Test IS       Train(50k) vs Test(10k) FID
Official                11.24±0.20    10.98±0.22    3.1508
ours                    11.26±0.13    10.97±0.19    3.1525
ours (use_torch=True)   11.26±0.15    10.97±0.20    3.1457

The results differ slightly from the official implementations due to the framework differences between PyTorch and TensorFlow.

Documentation

Prepare Statistical Reference for FID

  • Download the pre-calculated reference, or
  • Calculate the statistical reference for your custom dataset using the command-line tool:
    python -m pytorch_image_generation_metrics.fid_ref \
        --path path/to/images \
        --output path/to/fid_ref.npz
    See fid_ref.py for details.
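The reference file holds the mean and covariance of InceptionV3 pool features over the reference images. Assuming those features have already been extracted into an [N, D] array, the final step might look like the minimal sketch below. save_fid_ref is a hypothetical helper, and the mu/sigma keys follow the official statistics files.

```python
import os
import tempfile

import numpy as np


def save_fid_ref(acts, output):
    """Save the mean and covariance of features acts [N, D] to an .npz file."""
    mu = acts.mean(axis=0)
    sigma = np.cov(acts, rowvar=False)  # rows are samples, columns are feature dimensions
    np.savez(output, mu=mu, sigma=sigma)
    return mu, sigma


# Random stand-in features; real ones come from InceptionV3 (2048-d).
rng = np.random.default_rng(0)
acts = rng.normal(size=(500, 8))
output = os.path.join(tempfile.mkdtemp(), 'fid_ref.npz')
mu, sigma = save_fid_ref(acts, output)
```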

Inception Features

  • When computing IS or FID, the InceptionV3 model is loaded onto torch.device('cuda:0') by default.
  • Change the device argument in the get_* functions to set the torch device.

Using torch.Tensor as images

  • Prepare images as torch.float32 tensors with shape [N, 3, H, W], normalized to [0,1].
    from pytorch_image_generation_metrics import (
        get_inception_score,
        get_fid,
        get_inception_score_and_fid
    )
    
    images = ... # [N, 3, H, W]
    assert 0 <= images.min() and images.max() <= 1
    
    # Inception Score
    IS, IS_std = get_inception_score(
        images)
    
    # Frechet Inception Distance
    FID = get_fid(
        images, 'path/to/fid_ref.npz')
    
    # Inception Score & Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, 'path/to/fid_ref.npz')
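Images often start out as uint8 arrays in [N, H, W, 3] layout, which is what many image libraries return, or as generator outputs in [-1, 1] (for example, from a tanh output layer; this is an assumption about your model). Either form must be converted before the calls above. The numpy sketch below uses hypothetical helper names. Calling torch.from_numpy on the result then gives a torch.float32 tensor of shape [N, 3, H, W].

```python
import numpy as np


def uint8_nhwc_to_unit_nchw(images):
    """uint8 [N, H, W, 3] in [0, 255] -> float32 [N, 3, H, W] in [0, 1]."""
    x = images.astype(np.float32) / 255.0
    # Move channels first and make the memory layout contiguous again
    return np.ascontiguousarray(x.transpose(0, 3, 1, 2))


def tanh_range_to_unit(x):
    """Map values in [-1, 1] to [0, 1], clipping small overshoots."""
    return np.clip((x + 1.0) / 2.0, 0.0, 1.0)


batch = np.random.default_rng(0).integers(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
images = uint8_nhwc_to_unit_nchw(batch)
```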

Using PyTorch DataLoader to Provide Images

  1. Use pytorch_image_generation_metrics.ImageDataset to collect images from your storage or use your custom torch.utils.data.Dataset.

    from pytorch_image_generation_metrics import ImageDataset
    from torch.utils.data import DataLoader
    
    dataset = ImageDataset(path_to_dir, exts=['png', 'jpg'])
    loader = DataLoader(dataset, batch_size=50, num_workers=4)

    You can wrap a generative model in a dataset to support generating images on the fly.

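    import torch
    from torch.utils.data import DataLoader, Dataset

    class GeneratorDataset(Dataset):
        def __init__(self, G, noise_dim):
            self.G = G
            self.noise_dim = noise_dim

        def __len__(self):
            # Number of images to generate
            return 50000

        def __getitem__(self, index):
            # Generate one image without building an autograd graph; [0] drops the
            # batch dimension so DataLoader collates batches of shape [B, 3, H, W].
            # If G lives on a GPU, create the noise on that device instead.
            with torch.no_grad():
                return self.G(torch.randn(1, self.noise_dim))[0]

    # num_workers=0 keeps generation in the main process (e.g. when G is on a GPU)
    dataset = GeneratorDataset(G, noise_dim=128)
    loader = DataLoader(dataset, batch_size=50, num_workers=0)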
  2. Calculate metrics

    from pytorch_image_generation_metrics import (
        get_inception_score,
        get_fid,
        get_inception_score_and_fid
    )
    
    # Inception Score
    IS, IS_std = get_inception_score(
        loader)
    
    # Frechet Inception Distance
    FID = get_fid(
        loader, 'path/to/fid_ref.npz')
    
    # Inception Score & Frechet Inception Distance
    (IS, IS_std), FID = get_inception_score_and_fid(
        loader, 'path/to/fid_ref.npz')

Load Images from a Directory

  • Calculate metrics for images in a directory and its subfolders.
    from pytorch_image_generation_metrics import (
        get_inception_score_from_directory,
        get_fid_from_directory,
        get_inception_score_and_fid_from_directory)
    
    IS, IS_std = get_inception_score_from_directory(
        'path/to/images')
    FID = get_fid_from_directory(
        'path/to/images', 'path/to/fid_ref.npz')
    (IS, IS_std), FID = get_inception_score_and_fid_from_directory(
        'path/to/images', 'path/to/fid_ref.npz')
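Before computing metrics on a directory, it can help to confirm which files will be picked up. Below is a standard-library sketch with a hypothetical find_images helper. It recurses into subfolders and matches extensions case-insensitively, which may differ from the package's own ImageDataset rules.

```python
import tempfile
from pathlib import Path


def find_images(root, exts=('png', 'jpg')):
    """Recursively list files under root whose extension is in exts (case-insensitive)."""
    wanted = {e.lower().lstrip('.') for e in exts}
    return sorted(p for p in Path(root).rglob('*')
                  if p.is_file() and p.suffix.lower().lstrip('.') in wanted)


# Example layout: two images (one in a subfolder) and one non-image file.
root = Path(tempfile.mkdtemp())
(root / 'sub').mkdir()
for name in ('a.png', 'sub/b.JPG', 'notes.txt'):
    (root / name).touch()
found = find_images(root)
print([p.relative_to(root).as_posix() for p in found])  # ['a.png', 'sub/b.JPG']
```

An empty list here means no images would be read from the directory.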

Accelerating Matrix Computation with PyTorch

  • Set use_torch=True when calling functions like get_inception_score, get_fid, etc.

  • WARNING: with use_torch=True, the FID may be nan because the PyTorch implementation of the matrix square root is numerically unstable.

Tested Versions

  • python 3.9 + torch 1.13.1 + CUDA 11.7
  • python 3.9 + torch 2.3.0 + CUDA 12.1

License

This implementation is licensed under the Apache License 2.0.

This implementation is derived from pytorch-fid, licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500

The original implementation of FID is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.


Contributors

david20571015, gongxinyuu, hadaev8, w86763777


pytorch-image-generation-metrics's Issues

Exception when computing FID

When computing FID statistics with the following command:

      python -m pytorch_gan_metrics.calc_fid_stats \
             --path ./textual_inversion/cat_statue \
             --output ./textual_inversion/cat_statue/statistics.npz

I got this traceback.

/home/featurize/work/ldm/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3474: RuntimeWarning: Mean of empty slice.
  return _methods._mean(a, axis=axis, dtype=dtype,
/home/featurize/work/ldm/lib/python3.8/site-packages/numpy/core/_methods.py:181: RuntimeWarning: invalid value encountered in true_divide
  ret = um.true_divide(
/home/featurize/work/ldm/lib/python3.8/site-packages/numpy/lib/function_base.py:495: RuntimeWarning: Mean of empty slice.
  avg = a.mean(axis)
/home/featurize/work/ldm/lib/python3.8/site-packages/pytorch_gan_metrics/utils.py:282: RuntimeWarning: Degrees of freedom <= 0 for slice
  sigma = np.cov(acts, rowvar=False)
/home/featurize/work/ldm/lib/python3.8/site-packages/numpy/lib/function_base.py:2680: RuntimeWarning: divide by zero encountered in true_divide
  c *= np.true_divide(1, fact)
/home/featurize/work/ldm/lib/python3.8/site-packages/numpy/lib/function_base.py:2680: RuntimeWarning: invalid value encountered in multiply
  c *= np.true_divide(1, fact)

I still got a statistics.npz, but I am not sure whether it is correct. What's wrong?
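These warnings match what numpy emits when the feature array has zero rows. That suggests no images were read from --path, for example because the files' extensions were not matched. The small reproduction below uses an 8-dimensional stand-in for the 2048-dimensional features. Under this reading, the saved statistics.npz would contain only nan values and should not be used.

```python
import warnings

import numpy as np

# What the statistics step would receive if no images were found: zero feature rows.
acts = np.empty((0, 8))  # stand-in; the real feature dimension is 2048
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    mu = acts.mean(axis=0)               # "Mean of empty slice" -> all nan
    sigma = np.cov(acts, rowvar=False)   # "Degrees of freedom <= 0" -> all nan
print(np.isnan(mu).all(), np.isnan(sigma).all())  # True True
```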

FID_WEIGHTS_URL - not found

Hi,

The GitHub URL used in inception.py no longer seems to exist:

FID_WEIGHTS_URL = ('https://github.com/w86763777/pytorch_image_generation_metrics/releases/' 'download/v0.1.0/pt_inception-2015-12-05-6726825d.pth')

When calling get_inception_score_and_fid(), I get the following error: HTTPError: HTTP Error 404: Not Found

get_inception_score() returns two NaNs

I used a very simple example to test get_inception_score():

    input_image_tensor = torch.zeros([5, 3, 256, 256], dtype=torch.float32)
    IS, IS_std = get_inception_score(input_image_tensor)

Both return values, IS and IS_std, were nan, with the following traceback:

    /ldm/lib/python3.8/site-packages/numpy/core/fromnumeric.py:3474: RuntimeWarning: Mean of empty slice.
        return _methods._mean(a, axis=axis, dtype=dtype,
    /ldm/lib/python3.8/site-packages/numpy/core/_methods.py:181: RuntimeWarning: invalid value encountered in true_divide
        ret = um.true_divide(
    /ldm/lib/python3.8/site-packages/numpy/core/_methods.py:189: RuntimeWarning: invalid value encountered in double_scalars
        ret = ret.dtype.type(ret / rcount)

The environments related are listed as follows:

    - python=3.8.10
    - cudatoolkit=11.3
    - pytorch=1.10.2
    - torchvision=0.11.3
    - numpy=1.22.3

Could you please tell me how to fix this bug?
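These warnings are consistent with empty Inception Score splits. The widely used reference implementation divides the N images into 10 splits with boundaries at k * N // 10 and averages the per-split scores. With N = 5, half of those splits are empty. Whether this package uses exactly this scheme is an assumption. The sketch below uses a hypothetical split_sizes helper.

```python
import warnings

import numpy as np


def split_sizes(n, splits=10):
    """Sizes of IS splits whose boundaries are k * n // splits."""
    bounds = [k * n // splits for k in range(splits + 1)]
    return [b - a for a, b in zip(bounds, bounds[1:])]


print(split_sizes(5))      # [0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
print(split_sizes(50000))  # ten splits of 5000 images each

# An empty split averages nothing, so its marginal p(y) is nan ...
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    py_empty = np.empty((0, 1000)).mean(axis=0)
print(np.isnan(py_empty).all())  # True

# ... and a single nan split score makes both the mean and the std nan.
scores = np.array([np.nan, 1.0, 1.0])
print(np.mean(scores), np.std(scores))  # nan nan
```

Under this reading, the fix is to pass enough images that every split is non-empty. IS is conventionally reported on 50k samples.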
