
Comments (18)

zpcore commented on July 30, 2024

The worker attributes are set up when we initialize the dataloader:
https://github.com/pytorch/pytorch/blob/7c289c2a5c4e2233251565afadc2d95acf64b8c1/torch/utils/data/dataloader.py#L1113-L1128.

Since we are using torch's dataloader:

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=FLAGS.batch_size,
    sampler=train_sampler,
    drop_last=FLAGS.drop_last,
    shuffle=False if train_sampler else True,
    num_workers=FLAGS.num_workers,
    persistent_workers=FLAGS.persistent_workers,
    prefetch_factor=FLAGS.prefetch_factor)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=FLAGS.test_set_batch_size,
    sampler=test_sampler,
    drop_last=FLAGS.drop_last,
    shuffle=False,
    num_workers=FLAGS.num_workers,
    persistent_workers=FLAGS.persistent_workers,
    prefetch_factor=FLAGS.prefetch_factor)
I think it should contain the worker info. I can run a test on the real data to see whether it is there.
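For reference, each worker's attributes are exposed inside the worker process via torch.utils.data.get_worker_info(); here is a minimal sketch (the ShardedIterable dataset is my own illustration, not code from this issue):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info

class ShardedIterable(IterableDataset):
    """Illustrative dataset that splits its indices across dataloader workers."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        info = get_worker_info()  # None in the main process, a WorkerInfo in workers
        worker_id = info.id if info is not None else 0
        num_workers = info.num_workers if info is not None else 1
        # each worker yields a disjoint, round-robin slice of the indices
        yield from range(worker_id, self.n, num_workers)

loader = DataLoader(ShardedIterable(10), batch_size=None, num_workers=2)
print(sorted(int(x) for x in loader))  # every index appears exactly once
```

With num_workers=2, worker 0 yields the even indices and worker 1 the odd ones, so the union covers the dataset without duplicates.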

from xla.

JackCaoG commented on July 30, 2024

I feel like you are looking for

xla/torch_xla/runtime.py

Lines 135 to 192 in 34736f0

@requires_pjrt
def local_process_count() -> int:
  """Returns the number of processes running on this host."""
  return xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_COUNT, int, defval=1)


@requires_pjrt
def global_device_count() -> int:
  """Returns the total number of devices across all processes/hosts."""
  return len(torch_xla._XLAC._xla_get_all_devices())


@requires_pjrt
def world_size() -> int:
  """Returns the total number of processes participating in the job."""
  if torch_xla._XLAC._xla_get_replication_devices_count() == 0:
    return 1
  return global_device_count()


@requires_pjrt
def local_device_count() -> int:
  """Returns the total number of devices on this host.

  Assumes each process has the same number of addressable devices.
  """
  return local_process_count() * addressable_device_count()


@requires_pjrt
def addressable_device_count() -> int:
  """Returns the number of devices visible to this process."""
  return torch_xla._XLAC._xla_num_devices()


@requires_pjrt
def global_ordinal() -> int:
  """Returns global ordinal of this thread within all processes.

  Global ordinal is in range [0, global_device_count). Global ordinals are not
  guaranteed to have any predictable relationship to the TPU worker ID nor are
  they guaranteed to be contiguous on each host."""
  return torch_xla._XLAC._xla_get_default_device_ordinal()


@requires_pjrt
def local_ordinal() -> int:
  """Returns local ordinal of this thread within this host.

  Local ordinal is in range [0, local_device_count)."""
  local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
  devices_per_process = addressable_device_count()
  return local_rank * devices_per_process + xla_device().index


@requires_pjrt
def process_index() -> int:
  return torch_xla._XLAC._xla_get_process_index()

For the up-to-date master API you can also check https://pytorch.org/xla/master/#module-torch_xla.runtime


JackCaoG commented on July 30, 2024

@will-cromar @zpcore do you know where torch.utils.data gets that info? Wondering if we can do some mapping and also support that API.


davidaknowles commented on July 30, 2024

Thanks, those were the functions I was looking for. A cartoon version of my solution is the following:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

class MyDataset(torch.utils.data.IterableDataset):

    def __init__(self):
        super().__init__()
        self.N = 100
        self.data = torch.rand(self.N, 30)

    def __iter__(self):
        # round-robin shard: each replica yields every world_size-th sample
        for i in range(self.N):
            if i % xr.world_size() == xm.get_ordinal():
                yield self.data[i]

def _mp_fn_(index):
    device = xm.xla_device()
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
    device_loader = pl.MpDeviceLoader(dataloader, device)

    for epoch in range(3):
        mysum = torch.tensor(0., device=device)  # per-device running loss
        for batch in device_loader:
            mysum += batch.sum()
        sumsum = xm.all_reduce(xm.REDUCE_SUM, mysum).item()  # total across devices
        print(epoch, sumsum)

if __name__ == '__main__':
    xmp.spawn(_mp_fn_)

This runs fine... my new issue is that on my real data it hangs when I hit the .item(). mysum here is meant to be the total loss for the data processed on the current device, and sumsum is then the total loss for the epoch (across all devices). Maybe there's a better pattern for getting the total loss?
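A side note from me (not raised in the thread yet): with the `i % world_size == ordinal` sharding above, replicas get shards of different sizes whenever N is not a multiple of world_size, and therefore potentially different batch counts; a replica that runs one extra step containing a collective will wait forever for peers that never enter it. A quick pure-Python sanity check (shard_sizes is a hypothetical helper, not part of any API):

```python
def shard_sizes(n_samples, world_size, batch_size):
    """Per-replica sample and batch counts under round-robin sharding."""
    sizes = [len(range(rank, n_samples, world_size)) for rank in range(world_size)]
    batches = [-(-s // batch_size) for s in sizes]  # ceiling division
    return sizes, batches

print(shard_sizes(100, 4, 10))  # ([25, 25, 25, 25], [3, 3, 3, 3]): balanced
print(shard_sizes(121, 4, 10))  # ([31, 30, 30, 30], [4, 3, 3, 3]): replica 0 runs one extra step
```

With N = 100 and 4 devices the cartoon happens to divide evenly, which could explain why it runs fine while the real data hangs.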


JackCaoG commented on July 30, 2024

Can you always do an xm.mark_step() or torch_xla.sync() before you make the .item() call? It is always recommended to flush the pending executions before accessing the value of a tensor.


davidaknowles commented on July 30, 2024

Hmm, so now it hangs on that mark_step() instead. Well, it gets past the mark_step() on one device, but the other 3 hang.


JackCaoG commented on July 30, 2024

that's... interesting. It usually means the graph is different for each device. Can you dump the HLO following https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#common-debugging-environment-variables-combinations? You should see multiple files.


davidaknowles commented on July 30, 2024

Hmm, each device could end up processing a (slightly) different number of batches; I suppose that technically makes the graph different? I'll figure out getting the HLO and report back.
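If uneven batch counts turn out to be the culprit, one common workaround (my sketch, not something proposed in the thread) is to drop the ragged tail so every replica runs exactly the same number of steps; equal_steps below is a hypothetical helper:

```python
def equal_steps(n_samples, world_size, batch_size):
    """Largest step count every replica can run under round-robin sharding."""
    min_shard = n_samples // world_size  # the smallest shard any replica receives
    return min_shard // batch_size       # full batches only, identical everywhere

# Each replica then breaks out of its batch loop after equal_steps(...) steps,
# so collectives like all_reduce are entered the same number of times on every device.
print(equal_steps(121, 4, 10))  # 121 // 4 = 30 samples -> 3 full batches
print(equal_steps(100, 4, 10))  # 100 // 4 = 25 samples -> 2 full batches
```

This trades a few dropped samples per epoch for identical collective schedules across devices, which is what lockstep all_reduce requires.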


davidaknowles commented on July 30, 2024

OK HLO files are here. LMK if anything looks suspicious! In the meantime I'll see if I can get eager mode -> compilation going with the nightly build.


davidaknowles commented on July 30, 2024

The nightly build's (2.5.something) torch_xla.distributed.xla_multiprocessing is only giving me access to 1 of 4 devices, is that expected?


JackCaoG commented on July 30, 2024

hmm no that's not expected, I am on nightly and if I do

python examples/data_parallel/train_resnet_xla_ddp.py

I can see 4 processes

epoch: 1, step: 190, loss: 6.7231669425964355, rate: 1746.296055355192
epoch: 1, step: 190, loss: 6.705419540405273, rate: 1746.3170991653592
epoch: 1, step: 190, loss: 6.700830459594727, rate: 1745.7355188993108
epoch: 1, step: 190, loss: 6.731178283691406, rate: 1746.154144282245

(each process prints their own loss)


JackCaoG commented on July 30, 2024

btw I checked your HLO, and the last computation is the same

HloModule IrToHlo.14, entry_computation_layout={(f32[], f32[])->(f32[])}

%AddComputation.6 (x.7: f32[], y.8: f32[]) -> f32[] {
  %x.7 = f32[] parameter(0)
  %y.8 = f32[] parameter(1)
  ROOT %add.9 = f32[] add(f32[] %x.7, f32[] %y.8)
}

ENTRY %IrToHlo.14 (p0.1: f32[], p1.2: f32[]) -> (f32[]) {
  %p1.2 = f32[] parameter(1), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %p0.1 = f32[] parameter(0), metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %tuple.3 = (f32[], f32[]) tuple(f32[] %p1.2, f32[] %p0.1), metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.4 = f32[] get-tuple-element((f32[], f32[]) %tuple.3), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.5 = f32[] get-tuple-element((f32[], f32[]) %tuple.3), index=1, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %all-reduce.10 = (f32[], f32[]) all-reduce(f32[] %get-tuple-element.4, f32[] %get-tuple-element.5), replica_groups={}, constrain_layout=true, to_apply=%AddComputation.6, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.12 = f32[] get-tuple-element((f32[], f32[]) %all-reduce.10), index=1, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  %get-tuple-element.11 = f32[] get-tuple-element((f32[], f32[]) %all-reduce.10), index=0, metadata={op_type="xla__cross_replica_sum" op_name="xla__cross_replica_sum" source_file="/home/daknowles/.local/lib/python3.10/site-packages/torch/_ops.py" source_line=854}
  ROOT %tuple.13 = (f32[]) tuple(f32[] %get-tuple-element.11)
}

which is just a simple all_reduce... I can't really tell why it hangs. Do you have a repro I can try on my end? The model code can just be dummy model code, or you can use one of my examples in https://github.com/pytorch/xla/blob/master/examples/data_parallel/train_resnet_spmd_data_parallel.py


davidaknowles commented on July 30, 2024

Hi @JackCaoG - I made a minimal branch of my repo here. Hopefully it's straightforward to test with the info in the README. Thanks!


JackCaoG commented on July 30, 2024

Thanks, let me take a look tmr.

