The augmentations are, in the end, `nn.Module`s. I believe this is the same behaviour you face when you forward a tensor through a regular model in PyTorch: the input tensor must match the device of the module's parameters, not the other way around. @johnnv1 @shijianjian
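A minimal sketch of the `nn.Module` device semantics referred to above, assuming a recent PyTorch: `.to()` moves registered parameters and buffers, but a tensor stored as a plain attribute stays on the device where it was created. The `ScaledTable` module is hypothetical and only illustrates one common way a device mismatch can appear even when the input is on the module's device; it is not a claim about how Kornia stores its tensors. Without a GPU, the sketch falls back to PyTorch's `meta` device so the effect is still visible.

```python
import torch
from torch import nn


class ScaledTable(nn.Module):
    """Hypothetical toy module (illustration only) holding two constant tables."""

    def __init__(self) -> None:
        super().__init__()
        # Registered buffer: nn.Module.to() moves it together with the module.
        self.register_buffer("table_buffer", torch.ones(8, 8))
        # Plain tensor attribute: nn.Module.to() does NOT move it, so it stays on the CPU.
        self.table_plain = torch.ones(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Safe on any device, because the buffer follows the module.
        return x * self.table_buffer


# Use the GPU if present; otherwise the "meta" device still shows what .to() moved.
target = torch.device("cuda" if torch.cuda.is_available() else "meta")
m = ScaledTable().to(target)
print(m.table_buffer.device)  # on the target device
print(m.table_plain.device)   # cpu: left behind by .to()
```

Registering constant tensors with `register_buffer` (or creating them from the input's device inside `forward`) keeps them in sync with the module's device.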
Yeah, it should work like with `nn.Module`. But when I apply the same logic to a `RandomJPEG` object (create the object, move it to a CUDA device, create an input tensor on the same CUDA device, then call the `RandomJPEG` object with that tensor), I still get an error. Here is an example:
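```python
import torch
from kornia.augmentation import RandomJPEG

device = "cuda"
jpegq = (1.0, 50.0)  # range of JPEG qualities to sample from

# Move the augmentation module to the GPU, as with any nn.Module
aug = RandomJPEG(jpeg_quality=jpegq, p=1.0).to(device)
# Create the input on the same device as the module
example_input = torch.randn((3, 224, 224)).to(device)
res = aug(example_input)  # raises RuntimeError (device mismatch)
```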
And here is the resulting error about mismatched devices:
```
/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3483.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "/home/dmdr/Documents/Code/Python/aaa/ptrainer/tmp.py", line 27, in <module>
    res = aug(example_input)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/base.py", line 210, in forward
    output = self.apply_func(in_tensor, params, flags)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/_2d/base.py", line 129, in apply_func
    output = self.transform_inputs(in_tensor, params, flags, trans_matrix)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/base.py", line 261, in transform_inputs
    output = self.apply_transform(in_tensor, params, flags, transform=transform)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/augmentation/_2d/intensity/jpeg.py", line 56, in apply_transform
    jpeg_output: Tensor = jpeg_codec_differentiable(input, params["jpeg_quality"])
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/utils/image.py", line 231, in _wrapper
    output = f(input, *args, **kwargs)
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 484, in jpeg_codec_differentiable
    y_encoded, cb_encoded, cr_encoded = _jpeg_encode(
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 281, in _jpeg_encode
    y_encoded: Tensor = _quantize(
  File "/home/dmdr/miniconda3/envs/ptrain/lib/python3.10/site-packages/kornia/enhance/jpeg.py", line 177, in _quantize
    quantization_table[:, None] * _jpeg_quality_to_scale(jpeg_quality)[:, None, None, None]
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
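The last frame shows the quantization table being multiplied by a scale derived from `params["jpeg_quality"]`, with one operand on `cuda:0` and the other on `cpu`. One plausible shape for a fix, sketched under the assumption that the sampled parameters are created on the CPU, is to move every tensor in the parameter dict onto the input's device before the transform is applied. `move_params_to_device` is a hypothetical helper, not a Kornia API, and the dict below is a stand-in (only the `"jpeg_quality"` key comes from the traceback; `"p"` is an illustrative non-tensor entry).

```python
from typing import Any, Dict

import torch


def move_params_to_device(params: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `params` with every tensor value on `device`.

    Non-tensor values (flags, floats, ...) are passed through unchanged.
    """
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in params.items()
    }


# Stand-in for sampled augmentation parameters, created on the CPU.
params = {"jpeg_quality": torch.tensor([25.0, 40.0]), "p": 1.0}

# Use the GPU if present; on a CPU-only machine the move is a no-op.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image = torch.rand(2, 3, 16, 16, device=device)

moved = move_params_to_device(params, image.device)
# Every tensor parameter now lives on the same device as the input image.
print({k: v.device for k, v in moved.items() if isinstance(v, torch.Tensor)})
```

A library-side fix could equally generate the parameters directly on the input's device; the sketch only shows the invariant the error is about: every operand of that multiplication must share one device.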
@ditwoo thanks! We'll try to fix it, unless you want to give it a shot.
@edgarriba I can write a PR with a fix.
@ditwoo thanks! Very much appreciated.
Related Issues
- `AugmentationSequential` does not support instance masks shape (N, H, W)
- Improve Image Matching docs
- Bug in normalize_min_max
- Update apply_colormap
- geometry.transform.build_pyramid max_level definition
- cv2.edgePreservingFilter this filter support?
- RANSAC.max_samples_by_conf returns negative numbers
- Cannot compile `torch.jit.script(kornia.geometry.warp_perspective)` because of `KORNIA_CHECK_IS_TENSOR`
- Move WunschLineMatcher to kornia.feature.matching Module
- find_essential and run_5point are not working correctly with batch size>1
- Update pytorch to 2.3.0 on CI
- Refactor `SOLD2Net` to Support Dataclasses for Configuration
- AttributeError: 'list' object has no attribute 'ndim' with RandomTransplantation and batch with keys
- Including Steerers
- mypy error: Argument 4 to "AugmentationSequential" has incompatible type "RandomTransplantation"; expected "_AugmentationBase | ImageSequential"
- NaN values returned during backward pass in axis_angle_to_rotation_matrix function
- No LAFS returns cpu tensor not on same device as input
- RandomMosaic not working with masks?
- implement a deterministic two view scene