```python
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel
import torchvision.models as models


class MnistModel(BaseModel):
    def __init__(self, num_classes=10):
        super(MnistModel, self).__init__()
        models.resnet18()  # the returned module is discarded here

    def forward(self, x):
        # conv1, conv2, conv2_drop, fc1 and fc2 are never defined in __init__
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
```
This is what your code looks like.
`torchvision.models.resnet18()` is a function that builds and returns an `nn.Module` object (a ResNet-18 here). In your code the function is called correctly, but the returned object, which holds all of the parameters, is never stored as a class member, so it is not registered as a submodule of your model.
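```python
import torch
import torch.nn as nn
import torchvision.models as models
from base import BaseModel


class Resnet18(BaseModel):
    def __init__(self, num_classes=10):
        super(Resnet18, self).__init__()
        # Store the returned module as a member so its parameters are registered
        self.resnet = models.resnet18()
        self.resnet.avgpool = nn.AdaptiveMaxPool2d(1)
        # Replace the 1000-way ImageNet head with a num_classes-way head
        self.resnet.fc = nn.Linear(512, num_classes)

    def forward(self, x_input):
        output = self.resnet(x_input)
        return torch.sigmoid(output)
```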
This is my example code for doing this. For more information, search Google for 'transfer learning in pytorch'.
from pytorch-template.
Related Issues (20)
- Loss function
- Any plans to support Wandb Hyperparameter Searching?
- Strange bugs occur when the number of Gpus trained and tested is inconsistent
- usage of shuffle=True when using SubsetRandomSampler
- TODO: also configure logging for sub-processes(not master)
- DataLoader example
- python train.py --resume path/to/checkpoint
- Save best model
- Model is moved to GPU after the optimizer is instantiated, resulting in a performance hit.
- Setting 'early_stop: 0' does not disable it
- Passing an iterable from config.json
- Add custom flag & override from CLI
- Some features I have implemented.
- ReduceLROnPlateau lr_scheduler
- Only `data_loader.data_loaders` is plural
- Any support for multiple loss functions?
- Latest checkpoint
- Some thing wrong with add_histogram function
- Adding mesh to Tensorboard not working via TensorboardWriter
- Any plans to support DistributedDataParallel?