Sinusoidal Representation Networks (SIREN)

Unofficial PyTorch implementation of Sinusoidal Representation Networks (SIREN) from the paper Implicit Neural Representations with Periodic Activation Functions. This repository is a PyTorch port of an excellent TF 2.0 implementation of the same model.

If you are using this codebase in your research, please use the following citation:

@software{aman_dalmia_2020_3902941,
  author       = {Aman Dalmia},
  title        = {dalmia/siren},
  month        = jun,
  year         = 2020,
  publisher    = {Zenodo},
  version      = {v1.1},
  doi          = {10.5281/zenodo.3902941},
  url          = {https://doi.org/10.5281/zenodo.3902941}
}

Setup

  • Install using pip
$ pip install siren-torch

Usage

Sine activation

You can use the Sine activation like any other activation:

import torch
from siren import Sine

x = torch.rand(10)
y = Sine(w0=1)(x)  # elementwise sine activation with frequency w0 = 1
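Assuming Sine computes sin(w0 · x) elementwise, as in the paper's formulation, a stand-alone module with that behavior could look like the sketch below. SineSketch is a hypothetical name used for illustration; it is not the siren-torch source code.

```python
import torch
import torch.nn as nn


class SineSketch(nn.Module):
    """Hypothetical sine activation: y = sin(w0 * x), applied elementwise.

    Illustrative sketch only, not the siren package's implementation.
    """

    def __init__(self, w0: float = 1.0):
        super().__init__()
        self.w0 = w0  # frequency multiplier applied before the sine

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * x)


# Used like any other activation module.
x = torch.rand(10)
y = SineSketch(w0=30.0)(x)
print(y.shape)  # torch.Size([10])
```

Keeping w0 as a plain attribute (rather than a learnable parameter) matches the README's usage, where w0 is passed as a fixed constructor argument.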

Initialization

The paper's authors propose a principled way of initializing the layers of a SIREN model. The initialization function can be used like any other initializer in torch.nn.init:

import torch
from siren.init import siren_uniform_

w = torch.empty(3, 5)
siren_uniform_(w, mode='fan_in', c=6)  # fills w in place, like the torch.nn.init functions ending in _
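The paper's scheme draws weights uniformly from (−√(c/n), √(c/n)), with c = 6 and n the number of inputs to the layer. Assuming siren_uniform_ follows that rule, with the fan computed from the weight's shape as torch.nn.init does, a sketch could look like this. siren_uniform_sketch_ is a hypothetical name, and the package's exact fan and bound computation may differ.

```python
import math

import torch


def siren_uniform_sketch_(tensor: torch.Tensor, mode: str = "fan_in", c: float = 6.0) -> torch.Tensor:
    """Hypothetical sketch: fill `tensor` in place with U(-sqrt(c / fan), sqrt(c / fan))."""
    if tensor.dim() < 2:
        raise ValueError("expected a weight tensor with at least 2 dimensions")
    if mode not in ("fan_in", "fan_out"):
        raise ValueError("mode must be 'fan_in' or 'fan_out'")
    # For a Linear weight of shape (out_features, in_features), fan_in is the
    # number of inputs and fan_out the number of outputs. Extra trailing
    # dimensions (e.g. conv kernels) multiply both, as in torch.nn.init.
    receptive_field = tensor[0][0].numel()  # 1 for a 2-D weight
    fan_in = tensor.size(1) * receptive_field
    fan_out = tensor.size(0) * receptive_field
    fan = fan_in if mode == "fan_in" else fan_out
    bound = math.sqrt(c / fan)
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)


w = torch.empty(3, 5)
siren_uniform_sketch_(w, mode="fan_in", c=6)
# Here fan_in = 5, so every entry lies within sqrt(6 / 5) of zero.
```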

SIREN model

The SIREN model used in the paper, with sine activations and the custom initialization, can be created directly as follows:

import torch
from siren import SIREN

# defining the model
layers = [256, 256, 256, 256, 256]
in_features = 2
out_features = 3
initializer = 'siren'
w0 = 1.0
w0_initial = 30.0
c = 6
model = SIREN(
    layers, in_features, out_features, w0, w0_initial,
    initializer=initializer, c=c)

# defining the input
x = torch.rand(10, 2)

# forward pass
y = model(x)
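To show what such a model computes, here is a minimal sketch of a SIREN-style MLP. It assumes that each hidden layer computes sin(w0 · (Wx + b)), that the first layer uses w0_initial and later layers use w0, that every sine layer uses the √(c/n) uniform bound, and that the network ends in a plain linear layer. The package's handling of the first and last layers may differ, and the class names are hypothetical.

```python
import math

import torch
import torch.nn as nn


class SirenLayerSketch(nn.Module):
    """Hypothetical sine layer: y = sin(w0 * (W x + b))."""

    def __init__(self, in_features: int, out_features: int, w0: float, c: float = 6.0):
        super().__init__()
        self.w0 = w0
        self.linear = nn.Linear(in_features, out_features)
        # Assumed initialization: W ~ U(-sqrt(c / n), sqrt(c / n)) with n = in_features.
        bound = math.sqrt(c / in_features)
        nn.init.uniform_(self.linear.weight, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * self.linear(x))


class SirenSketch(nn.Module):
    """Hypothetical SIREN: sine layers of the given widths, then a linear output layer."""

    def __init__(self, layers, in_features, out_features, w0=1.0, w0_initial=30.0, c=6.0):
        super().__init__()
        widths = [in_features] + list(layers)
        self.body = nn.Sequential(*[
            # The first layer uses w0_initial; all later sine layers use w0.
            SirenLayerSketch(widths[i], widths[i + 1], w0_initial if i == 0 else w0, c)
            for i in range(len(layers))
        ])
        self.head = nn.Linear(widths[-1], out_features)  # assumed linear output layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.body(x))


model = SirenSketch([256, 256, 256, 256, 256], in_features=2, out_features=3)
y = model(torch.rand(10, 2))
print(y.shape)  # torch.Size([10, 3])
```

Using a larger w0_initial only in the first layer is what lets low-dimensional coordinate inputs map to high-frequency features, while the smaller w0 keeps the hidden layers well conditioned.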

Results on Image Inpainting task

A partial implementation of the image inpainting task is available as the train_inpainting_siren.py and eval_inpainting_siren.py scripts.

To run training:

$ python scripts/train_inpainting_siren.py

To run evaluation:

$ python scripts/eval_inpainting_siren.py

Weight files are available in the repository under the checkpoints directory. The inpainting result was produced after 5000 epochs of training with a batch size of 8192, using only 10% of the image's pixels during training.
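The internals of the two scripts are not described in this README. As a rough illustration of the setup, the sketch below builds (coordinate, color) training pairs from a random 10% of an image's pixels and splits them into batches of 8192. The [-1, 1] coordinate scaling, the random subset selection, and the helper names make_coordinate_grid and sample_training_pixels are assumptions, not taken from the scripts.

```python
import torch


def make_coordinate_grid(height: int, width: int) -> torch.Tensor:
    """Return (height * width, 2) pixel coordinates scaled to [-1, 1], in row-major order."""
    ys = torch.linspace(-1.0, 1.0, height)
    xs = torch.linspace(-1.0, 1.0, width)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([grid_y, grid_x], dim=-1).reshape(-1, 2)


def sample_training_pixels(image: torch.Tensor, fraction: float = 0.1, seed: int = 0):
    """Keep a random `fraction` of the pixels of an (H, W, C) image as (coords, colors)."""
    height, width, channels = image.shape
    coords = make_coordinate_grid(height, width)
    colors = image.reshape(-1, channels)  # same row-major order as coords
    generator = torch.Generator().manual_seed(seed)
    n_keep = int(fraction * coords.shape[0])
    idx = torch.randperm(coords.shape[0], generator=generator)[:n_keep]
    return coords[idx], colors[idx]


image = torch.rand(256, 256, 3)  # stand-in for a real RGB image with values in [0, 1]
coords, colors = sample_training_pixels(image, fraction=0.1)
# Index batches of at most 8192 training pixels each.
batches = torch.split(torch.arange(coords.shape[0]), 8192)
```

For a 256×256 image, 10% of the pixels is 6553 points, which fits in a single batch of 8192. At evaluation time the model would be queried on the full coordinate grid to fill in the unseen 90%.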

Tests

Tests are written using unittest. You can run any script under the tests folder.
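For a sense of what a unittest-based check of the activation might look like, here is a small, self-contained test case. It tests a hypothetical stand-in function `sine` rather than the package's Sine class, so it runs without siren-torch installed; the test names are illustrative and not taken from the tests folder.

```python
import unittest

import torch


def sine(x: torch.Tensor, w0: float = 1.0) -> torch.Tensor:
    # Hypothetical stand-in for the activation under test: sin(w0 * x).
    return torch.sin(w0 * x)


class TestSineSketch(unittest.TestCase):
    def test_zero_maps_to_zero(self):
        # sin(w0 * 0) is exactly 0 for any w0.
        self.assertTrue(torch.equal(sine(torch.zeros(4), w0=30.0), torch.zeros(4)))

    def test_output_is_bounded(self):
        # A sine activation never leaves [-1, 1], whatever the input scale.
        y = sine(torch.randn(1000) * 10, w0=30.0)
        self.assertLessEqual(y.abs().max().item(), 1.0)
```

Saved as, for example, tests/test_sine_sketch.py, it could be run with `python -m unittest discover tests`.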

Contributing

As mentioned at the beginning, this codebase is a PyTorch port of the TF 2.0 implementation mentioned above, so I might have missed a few details from the original paper. Assuming that the implementation in the linked repository is correct, one can safely trust this implementation as well. The only major difference from the reference repository is that the reference also includes w0 in the initialization. I did not see that in the paper and hence did not include it here. I have not read the paper in depth; this codebase is simply meant to serve as a starting point for anyone looking for an implementation. Please feel free to open a PR or create an issue if you find a bug or want to improve any other aspect of the codebase.
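To make the difference concrete, assume "w0 as part of the initialization" means the reference divides the uniform bound by w0, i.e. U(−√(c/n)/w0, √(c/n)/w0), while this port uses U(−√(c/n), √(c/n)). Since the activation computes sin(w0 · (Wx + b)), the sine effectively sees weights of scale w0 · bound. The hypothetical helper below compares the two conventions.

```python
import math


def siren_bound(fan_in: int, c: float = 6.0, w0: float = 1.0, divide_by_w0: bool = False) -> float:
    """Half-width of the uniform init range U(-bound, bound) for a sine layer.

    divide_by_w0=False: bound = sqrt(c / fan_in)       (this port, per the text above)
    divide_by_w0=True:  bound = sqrt(c / fan_in) / w0  (assumed reference convention)
    """
    bound = math.sqrt(c / fan_in)
    return bound / w0 if divide_by_w0 else bound


for w0 in (1.0, 30.0):
    plain = siren_bound(256, w0=w0)
    scaled = siren_bound(256, w0=w0, divide_by_w0=True)
    # The sine is applied to w0 * (W x + b), so compare the effective half-widths w0 * bound.
    print(f"w0={w0:>4}: effective bound {w0 * plain:.4f} (no division) vs {w0 * scaled:.4f} (divided)")
```

With w0 = 1.0, the value the README example uses for the hidden layers, the two conventions coincide. They only diverge when a layer uses a larger w0, where the undivided bound makes the effective weights w0 times wider.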

Contributors

dalmia, dalmiaman

siren's Issues

Implementation discussion

Hi,
I also implemented my own version of SIREN following the paper and have run several tests with it. My implementation is here: SIREN-2d.

I haven't run your code yet, but I have two questions.

  1. The weight initialization in your code applies w0 to the result of the linear layer's output (link). In the paper, it seems that w0 is applied only to weight*input, without the bias. I'm not sure whether my understanding is correct.

  2. In the paper, SIREN is also used to fit image gradients. It's not clear whether the SIREN gradient is computed analytically. Since the authors explicitly give its analytical gradient, I think it should be computed analytically. However, I cannot get it to converge in this setting. Could you also provide such tests, or try my code to check whether my understanding is correct?
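The distinction in question 1 can be checked numerically. The sketch below, using arbitrary random weights, compares sin(w0 · (Wx + b)) with sin(w0 · Wx + b): for the same b they differ, but they coincide if the second form uses the rescaled bias b' = w0 · b. So the choice only changes how the bias is parameterized (and therefore its effective initialization range), not which functions the layer can represent.

```python
import torch

torch.manual_seed(0)
w0 = 30.0
W = torch.empty(4, 2).uniform_(-0.5, 0.5)  # arbitrary example weights
b = torch.empty(4).uniform_(-0.5, 0.5)     # arbitrary example bias
x = torch.rand(8, 2)

# Variant 1: w0 scales the whole linear output, sin(w0 * (W x + b)).
y_full = torch.sin(w0 * (x @ W.T + b))

# Variant 2: w0 scales only W x, sin(w0 * W x + b).
y_weight_only = torch.sin(w0 * (x @ W.T) + b)

# Variant 2 with a rescaled bias b' = w0 * b reproduces Variant 1.
y_rescaled = torch.sin(w0 * (x @ W.T) + w0 * b)

print(torch.allclose(y_full, y_rescaled, atol=1e-4))
print(torch.allclose(y_full, y_weight_only, atol=1e-4))
```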

I will give more tests when I have time.
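On question 2, the paper notes that the derivative of a SIREN is itself a SIREN, since the derivative of sin is a phase-shifted sine. For a tiny one-hidden-layer network f(x) = v · sin(w0 · (Wx + b)), the chain rule gives ∇f(x) = Wᵀ (v ⊙ w0 · cos(w0 · (Wx + b))). The sketch below, with arbitrary random parameters, checks that closed form against PyTorch autograd; it is an illustration, not code from either repository.

```python
import torch

torch.manual_seed(0)
w0 = 30.0
W = torch.empty(16, 2).uniform_(-0.5, 0.5)  # hidden-layer weights
b = torch.empty(16).uniform_(-0.5, 0.5)     # hidden-layer bias
v = torch.randn(16)                         # linear read-out weights


def f(x: torch.Tensor) -> torch.Tensor:
    """Tiny scalar SIREN for one point x of shape (2,): v . sin(w0 * (W x + b))."""
    return v @ torch.sin(w0 * (W @ x + b))


def grad_f_analytic(x: torch.Tensor) -> torch.Tensor:
    """Chain rule: grad f = W^T (v * w0 * cos(w0 * (W x + b)))."""
    return W.T @ (v * w0 * torch.cos(w0 * (W @ x + b)))


x = torch.rand(2, requires_grad=True)
(grad_autograd,) = torch.autograd.grad(f(x), x)
grad_closed_form = grad_f_analytic(x.detach())
print(torch.allclose(grad_autograd, grad_closed_form, rtol=1e-4, atol=1e-4))
```

When fitting image gradients, the loss is computed on this gradient, so training needs to backpropagate through it; with autograd this is done by passing `create_graph=True` to `torch.autograd.grad`.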

How to use with ResNet

Thank you for this project. I would like to know how we can use the SIREN activation with other deep models such as ResNet, DenseNet, etc.

Currently, I guess the implementation is for MLP-style networks only.

Could you also provide a Python script in the scripts folder for, say, inpainting with ResNet-18 with ReLU replaced by SIREN?

Thank you
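One generic approach, not something this repository provides, is to walk a model's submodules and swap every nn.ReLU for a sine activation. The sketch below uses a hypothetical SineActivation class and a small CNN as a stand-in for ResNet-18. It only replaces ReLUs registered as modules (ReLUs called functionally, e.g. `torch.nn.functional.relu` inside `forward`, are not caught), and it does not apply the SIREN initialization, which the paper derives for MLPs; whether such a swap trains well in convolutional models is not addressed here.

```python
import torch
import torch.nn as nn


class SineActivation(nn.Module):
    """Hypothetical sine activation, sin(w0 * x); not the siren package's Sine class."""

    def __init__(self, w0: float = 1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * x)


def replace_relu_with_sine(module: nn.Module, w0: float = 1.0) -> nn.Module:
    """Recursively swap every nn.ReLU submodule for a SineActivation, in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.ReLU):
            setattr(module, name, SineActivation(w0))
        else:
            replace_relu_with_sine(child, w0)
    return module


# Small CNN standing in for a larger model such as ResNet-18.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.ReLU(),
)
replace_relu_with_sine(model, w0=1.0)
out = model(torch.rand(2, 3, 16, 16))
```

Because `setattr` replaces the attribute that `forward` looks up, the same function also works for models whose blocks store their activation as an attribute (such as `self.relu`) and call it in `forward`.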
