
sumorday commented on June 3, 2024

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn.utils import spectral_norm


class Generator(nn.Module):
    def __init__(self, z_dim, M=4):
        super().__init__()
        self.M = M
        # Project z to a 256 x M x M feature map
        self.linear = nn.Linear(z_dim, M * M * 256)
        # Two stride-2 ConvTranspose2d layers upsample M -> 2M -> 4M,
        # so the default M=4 produces 16x16 images
        self.main = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh())
        self.initialize()

    def initialize(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                init.normal_(m.weight, std=0.02)
                init.zeros_(m.bias)

    def forward(self, z, *args, **kwargs):
        x = self.linear(z)
        # (B, M*M*256) -> (B, 256, M, M) so the conv layers receive 4-D input
        x = x.view(x.size(0), -1, self.M, self.M)
        x = self.main(x)
        return x


class Discriminator(nn.Module):
    def __init__(self, M=32):
        super().__init__()
        self.M = M

        self.main = nn.Sequential(
            # M
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(64),
            nn.Dropout2d(p=0.2),
            # M / 2
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.BatchNorm2d(128),
            nn.Dropout2d(p=0.2),
            nn.Conv2d(128, 10, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True))

        # One stride-2 conv halves M and the last conv has 10 channels,
        # so the flattened feature size is 10 * (M // 2) * (M // 2)
        self.linear = nn.Linear((M // 2) * (M // 2) * 10, 1)
        self.initialize()

    def initialize(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.normal_(m.weight, std=0.02)
                init.zeros_(m.bias)
                # spectral_norm modifies the module in place
                spectral_norm(m)

    def forward(self, x, *args, **kwargs):
        x = self.main(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x

from pytorch-gan-collections.
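The `in_features` of the discriminator's final `nn.Linear` and the generator's output size can both be derived from the output-size formulas given in the PyTorch `Conv2d` and `ConvTranspose2d` documentation. Below is a minimal pure-Python sketch, assuming dilation=1 and output_padding=0 as in the code above; the helper names are hypothetical. Under those assumptions it suggests that the generator with its default M=4 emits 16x16 images, while the discriminator is sized for 32x32 inputs (M=32), so for 32x32 data the generator would need M=8 or a third upsampling layer.

```python
# Output-size arithmetic for the layers above (dilation=1, output_padding=0),
# following the formulas in the PyTorch Conv2d / ConvTranspose2d docs.

def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of nn.Conv2d along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

def conv_transpose2d_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of nn.ConvTranspose2d along one dimension."""
    return (size - 1) * stride - 2 * padding + kernel_size

def generator_out_size(M):
    # linear + view gives an M x M map, then two ConvTranspose2d(k=4, s=2, p=1)
    s = conv_transpose2d_out(M, 4, 2, 1)
    s = conv_transpose2d_out(s, 4, 2, 1)
    # the final Conv2d(k=3, s=1, p=1) keeps the spatial size
    return conv2d_out(s, 3, 1, 1)

def discriminator_flat_features(M, out_channels=10):
    s = conv2d_out(M, 3, 1, 1)   # M
    s = conv2d_out(s, 4, 2, 1)   # M / 2
    s = conv2d_out(s, 3, 1, 1)   # M / 2
    s = conv2d_out(s, 3, 1, 1)   # M / 2, 10 channels
    return out_channels * s * s

print(generator_out_size(4))            # 16
print(generator_out_size(8))            # 32
print(discriminator_flat_features(32))  # 2560
```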

w86763777 commented on June 3, 2024

If you want to use a custom model, you had better write a new one with the same prototype, __init__(self, z_dim) for the generator and __init__(self) for the discriminator, and update net_G_models and net_D_models in the training scripts to include your models.
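A minimal sketch of the registry pattern described above, assuming the training script looks models up by name and calls the generator as `net_G_models[arch](z_dim)` and the discriminator as `net_D_models[arch]()`. The class names `MyGenerator` and `MyDiscriminator`, the key `"my32"`, and `build_models` are hypothetical, and the classes are plain-Python stand-ins (real ones would subclass `torch.nn.Module`) so the sketch runs without PyTorch. What it illustrates: a parameter with a default value, such as `M=32`, still satisfies the zero-argument `__init__(self)` prototype, so `M` does not need to be deleted.

```python
# Hypothetical stand-ins for custom models; real versions would subclass
# torch.nn.Module like the Generator/Discriminator in the first comment.
class MyGenerator:
    def __init__(self, z_dim, M=8):  # matches __init__(self, z_dim)
        self.z_dim = z_dim
        self.M = M

class MyDiscriminator:
    def __init__(self, M=32):  # default keeps it compatible with __init__(self)
        self.M = M

# Hypothetical registries mirroring net_G_models / net_D_models
net_G_models = {"my32": MyGenerator}
net_D_models = {"my32": MyDiscriminator}

def build_models(arch, z_dim):
    # Assumed calling convention of the training script
    net_G = net_G_models[arch](z_dim)
    net_D = net_D_models[arch]()
    return net_G, net_D

G, D = build_models("my32", z_dim=128)
print(G.z_dim, G.M, D.M)  # 128 8 32
```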


sumorday commented on June 3, 2024

Hi, I tried to write a new one with the same prototype, __init__(self, z_dim), for the discriminator, but main.py (dcgan) also needs changes, such as:

self.linear = nn.Linear(M // 16 * M // 16 * 10, 1)

I have to delete M, and then errors like this occur:
Expected 4-dimensional input for 4-dimensional weight [128, 256, 4, 4], but got 2-dimensional input of size [128, 128] instead

Some of the dimensions in the code are 128, so the error may come from the matrix calculations. I have no idea what to do...

If I delete that M, it causes many errors...
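A quick check of the two `in_features` expressions in this thread, assuming 32x32 inputs (M = 32). Python evaluates `//` and `*` left to right at equal precedence, so `M // 2 * M // 2 * 10` means `(((M // 2) * M) // 2) * 10`; for even M this equals `(M // 2) * (M // 2) * 10`, matching the 10 x 16 x 16 map the discriminator in the first comment produces after its single stride-2 convolution. The `M // 16` variant evaluates to 40 (10 x 2 x 2), which would only fit a network that halves the spatial size four times. Separately, the quoted error comes from a convolution layer (weight [128, 256, 4, 4]) receiving a 2-D tensor of size [128, 128], plausibly a flat batch of vectors such as the noise z, rather than the 4-D (N, C, H, W) input it needs; it is not an `nn.Linear` size error.

```python
M = 32

# // and * share precedence and associate left to right:
# M // 2 * M // 2 * 10 == (((M // 2) * M) // 2) * 10
original = M // 2 * M // 2 * 10
explicit = (M // 2) * (M // 2) * 10   # 10 channels x 16 x 16
from_thread = M // 16 * M // 16 * 10  # the variant quoted above

print(original, explicit, from_thread)  # 2560 2560 40
```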

