Comments (18)
The worker attributes are set up when we initialize the dataloader:
https://github.com/pytorch/pytorch/blob/7c289c2a5c4e2233251565afadc2d95acf64b8c1/torch/utils/data/dataloader.py#L1113-L1128.
Since we are using torch's dataloader; see xla/test/test_train_mp_imagenet.py, lines 235 to 252 at 1651e76.
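For reference, the attributes set there are what torch.utils.data.get_worker_info() returns inside a worker process. A minimal sketch using only the stock PyTorch API (nothing XLA-specific):

import torch

class WorkerShardedDataset(torch.utils.data.IterableDataset):
    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        # info is None in the main process; in a worker it carries the
        # attributes set at DataLoader init: id, num_workers, seed, dataset.
        start = info.id if info is not None else 0
        step = info.num_workers if info is not None else 1
        for i in range(start, 100, step):
            yield i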
from xla.
I think you are looking for these: lines 135 to 192 at 34736f0.
For the up-to-date master API, you can also check https://pytorch.org/xla/master/#module-torch_xla.runtime
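For example, a short sketch against that runtime module (function names as listed in the master docs linked above):

import torch_xla.runtime as xr

print(xr.world_size())      # total number of participating devices
print(xr.global_ordinal())  # this process's rank across all hosts
print(xr.local_ordinal())   # this process's rank within its host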
from xla.
@will-cromar @zpcore do you know where torch.utils.data gets that info? Wondering if we can do some mapping and also support that API.
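One conceivable mapping (a hypothetical sketch, not an existing API): combine the DataLoader's worker info with the XLA ordinal to derive a globally unique shard index:

import torch
import torch_xla.runtime as xr

def global_shard():
    # Hypothetical helper: flatten (device ordinal, worker id) into one
    # shard index spanning all devices and all loader workers.
    info = torch.utils.data.get_worker_info()
    workers = info.num_workers if info is not None else 1
    worker_id = info.id if info is not None else 0
    num_shards = xr.world_size() * workers
    shard_id = xr.global_ordinal() * workers + worker_id
    return shard_id, num_shards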
from xla.
Thanks, those were the functions I was looking for. A cartoon version of my solution is the following:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

class MyDataset(torch.utils.data.IterableDataset):
    def __init__(self):
        super().__init__()
        self.N = 100
        self.data = torch.rand(self.N, 30)

    def __iter__(self):
        # Shard by ordinal: each device yields every world_size-th sample.
        for i in range(self.N):
            if i % xr.world_size() == xm.get_ordinal():
                yield self.data[i]

def _mp_fn_(index):
    device = xm.xla_device()
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
    device_loader = pl.MpDeviceLoader(dataloader, device)
    for epoch in range(3):
        mysum = torch.tensor(0., device=device)
        for batch in device_loader:
            mysum += batch.sum()
        sumsum = xm.all_reduce(xm.REDUCE_SUM, mysum).item()
        print(epoch, sumsum)

if __name__ == '__main__':
    xmp.spawn(_mp_fn_)
This runs fine... my new issue is that on my real data it hangs when I hit the .item() call. mysum here is meant to be the total loss for the data processed on the current device, and sumsum is then the total loss for the epoch (across all devices). Maybe there's a better pattern for getting the total loss?
from xla.
Can you do an xm.mark_step() or torch_xla.sync() before the .item() call? It is always recommended to flush the pending executions before accessing the value of a tensor.
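For example, applied to the loop in the cartoon above (a sketch; the source treats torch_xla.sync() and xm.mark_step() as interchangeable flushes):

for batch in device_loader:
    mysum += batch.sum()
xm.mark_step()  # flush the per-batch graph accumulated so far
# Now .item() only has to execute the small trailing all_reduce graph.
sumsum = xm.all_reduce(xm.REDUCE_SUM, mysum).item()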
from xla.
Hmm, so now it hangs on that mark_step() instead. Well, it gets past the mark_step() on one device, but the other 3 hang.
from xla.
That's... interesting. It usually means the graph is different on each device. Can you dump the HLO following https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations? You should see multiple files.
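For reference, the combination described in that doc can be set before the run; a sketch setting it from Python (variable names as in TROUBLESHOOTING.md; they must be set before torch_xla is imported):

import os

# Annotate the IR/HLO with source info and save the graphs as they execute.
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_HLO_DEBUG"] = "1"
os.environ["XLA_SAVE_TENSORS_FMT"] = "hlo"
os.environ["XLA_SAVE_TENSORS_FILE"] = "/tmp/save.hlo"  # expect one file per device

import torch_xla  # noqa: E402  (import only after the env vars are set)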
from xla.
Hmm, each device could end up processing a (slightly) different number of batches; I suppose that technically makes the graph different? I'll figure out getting the HLO and report back.
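In case it is the uneven batches, one way to even things out (a hypothetical sketch, not a tested fix) would be to wrap past the end of the data so every ordinal yields the same number of samples, and therefore runs the same number of batches and hits the same collectives:

import math
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

class EvenShardDataset(torch.utils.data.IterableDataset):
    # Hypothetical variant of MyDataset above: pads by wrap-around so
    # each device sees exactly ceil(N / world_size) samples.
    def __init__(self, data):
        super().__init__()
        self.data = data

    def __iter__(self):
        n = len(self.data)
        world = xr.world_size()
        per_device = math.ceil(n / world)  # identical count on every ordinal
        for k in range(per_device):
            yield self.data[(k * world + xm.get_ordinal()) % n]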
from xla.
OK HLO files are here. LMK if anything looks suspicious! In the meantime I'll see if I can get eager mode -> compilation going with the nightly build.
from xla.
The nightly build's (2.5.something) torch_xla.distributed.xla_multiprocessing is only giving me access to 1 of 4 devices; is that expected?
from xla.
Hmm, no, that's not expected. I am on nightly, and if I run

python examples/data_parallel/train_resnet_xla_ddp.py

I can see 4 processes:

epoch: 1, step: 190, loss: 6.7231669425964355, rate: 1746.296055355192
epoch: 1, step: 190, loss: 6.705419540405273, rate: 1746.3170991653592
epoch: 1, step: 190, loss: 6.700830459594727, rate: 1745.7355188993108
epoch: 1, step: 190, loss: 6.731178283691406, rate: 1746.154144282245

(each process prints its own loss)
from xla.
BTW I checked your HLO, and the last computation is the same:
HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[])->(f32[])}
%AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
%x.7 = f32[] parameter(0)
%y.8 = f32[] parameter(1)
ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
}
ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[]) -> (f32[]) {
%p1.2 = f32[] parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%p0.1 = f32[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%tuple.3 = (f32[], f32[]) tuple(f32[] %p1.2, f32[] %p0.1), metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%get-tuple-element.4 = f32[] get-tuple-element((f32[], f32[]) %tuple.3), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%get-tuple-element.5 = f32[] get-tuple-element((f32[], f32[]) %tuple.3), index=1, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%all-reduce.10 = (f32[], f32[]) all-reduce(f32[] %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%get-tuple-element.12 = f32[] get-tuple-element((f32[], f32[]) %all-reduce.10), index=1, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
%get-tuple-element.11 = f32[] get-tuple-element((f32[], f32[]) %all-reduce.10), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
ROOT %tuple.13 = (f32[]) tuple(f32[] %get-tuple-element.11)
}
which is just a simple all_reduce. I can't really tell why it hangs. Do you have a repro I can try on my end? The model code can just be dummy model code, or you can use one of my examples in https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py
from xla.
Hi @JackCaoG - I made a minimal branch of my repo here. Hopefully it's straightforward to test with the info in the README. Thanks!
from xla.
Thanks, let me take a look tmr.
from xla.