

ActonMartin commented on August 25, 2024

I have a GPU, so I set "n_gpu": 1. "base_trainer.py" (line 17) moves the model to CUDA, and "trainer.py" (line 42) moves the input to CUDA. But when I start training with my own Dataset, the following error appears:
Trainable parameters: 31042369
Traceback (most recent call last):
  File "train.py", line 68, in <module>
    main(config)
  File "train.py", line 49, in main
    trainer.train()
  File "E:\kaggle\pytorch-template\base\base_trainer.py", line 66, in train
    result = self._train_epoch(epoch)
  File "E:\kaggle\pytorch-template\trainer\trainer.py", line 47, in _train_epoch
    output = self.model(data)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\kaggle\pytorch-template\model\model.py", line 117, in forward
    x1 = self.inc(x)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\kaggle\pytorch-template\model\model.py", line 42, in forward
    return self.double_conv(x)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\container.py", line 100, in forward
    input = module(input)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\conv.py", line 349, in forward
    return self._conv_forward(input, self.weight)
  File "D:\Anaconda3\lib\site-packages\torch\nn\modules\conv.py", line 346, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same
This confuses me, and I have no idea what causes it.

I found the crux of the problem! It is in the definition of the model's forward method. Here is the original code:
```python
def forward(self, x):
    x = torch.as_tensor(x, dtype=torch.float)
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)
    x = self.up1(x5, x4)
    x = self.up2(x, x3)
    x = self.up3(x, x2)
    x = self.up4(x, x1)
    logits = self.outc(x)
    return logits
```
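The error says the input reaches the first convolution as a CPU tensor (torch.FloatTensor) while the weights are on the GPU (torch.cuda.FloatTensor). There are two ways to fix it:
1. Delete the line "x = torch.as_tensor(x, dtype=torch.float)",
or
2. Move the converted tensor to the CUDA device explicitly,
by changing the line to "x = torch.as_tensor(x, dtype=torch.float).to(device="cuda")".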
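A variant of the second fix, as a minimal sketch rather than the template's actual code: instead of hard-coding "cuda", the input can follow whatever device the model's own weights are on, so the same forward also works on a CPU-only machine. `TinyNet` and its single `inc` layer are hypothetical stand-ins for the U-Net in the post; only the device handling is the point.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the model in the post (one conv layer only)."""

    def __init__(self):
        super().__init__()
        self.inc = nn.Conv2d(1, 4, kernel_size=3, padding=1)

    def forward(self, x):
        # Look up where the weights live instead of hard-coding "cuda".
        param = next(self.parameters())
        # Convert dtype and device in one call; nothing is copied if x already matches.
        x = torch.as_tensor(x).to(device=param.device, dtype=torch.float)
        return self.inc(x)


# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyNet().to(device)

# A CPU float64 batch: wrong dtype, and the wrong device whenever the model is on the GPU.
batch = torch.zeros(2, 1, 8, 8, dtype=torch.float64)
out = model(batch)
```

With this pattern the forward pass no longer depends on the caller having moved the batch to the right device beforehand, at the cost of one parameter lookup per call.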

from pytorch-template.
