
Comments (7)

Jiayi-Pan commented on July 30, 2024

Solved. To run on a tpu-v3-8 node:

export TPU_HOST_BOUNDS='1,1,1'

For more complicated subsets, configure the following vars (a sketch follows the list):

TPU_CHIPS_PER_HOST_BOUNDS
TPU_HOST_BOUNDS
TPU_VISIBLE_DEVICES
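
For example, a minimal sketch (untested; assumes you want the current process to see only chip 0) that sets these before torch_xla is imported:

import os

# Pretend the host owns a single chip and expose only chip 0.
# (Hypothetical subset values; adjust the bounds for your slice.)
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '1,1,1'
os.environ['TPU_HOST_BOUNDS'] = '1,1,1'
os.environ['TPU_VISIBLE_DEVICES'] = '0'

import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t)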

JackCaoG commented on July 30, 2024

Yeah, that's the env var you need. I will close this issue if there are no further questions.

Jiayi-Pan commented on July 30, 2024

While export TPU_HOST_BOUNDS='1,1,1' works for simple code like

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t)

it breaks in multi-processing settings. On the first node of a tpu-v3-64 pod, with that env var set, the following code fails:

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch

def _mp_fn(index):
  device = xm.xla_device()
  data = torch.randn(2, 2, device=device)
  print(data)

if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())

Output:

jiayipan@t1v-n-bc530acf-w-0:~/prismatic-video-lms$ python example.py 
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.110325   95807 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.110423   95807 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.110434   95807 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.114095   95805 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.114171   95805 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.114181   95805 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.118047   95802 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.118122   95802 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.118132   95802 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1721514333.142126   95806 pjrt_api.cc:100] GetPjrtApi was found for tpu at /home/jiayipan/.local/lib/python3.10/site-packages/libtpu/libtpu.so
I0000 00:00:1721514333.142198   95806 pjrt_api.cc:79] PJRT_Api is set for device type tpu
I0000 00:00:1721514333.142215   95806 pjrt_api.cc:146] The PJRT plugin has PJRT API version 0.46. The framework PJRT API version is 0.46.
concurrent.futures.process._RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 246, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 205, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 59, in _ru
n_thread_per_device
    initializer_fn(local_rank, local_world_size)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 122, in in
itialize_multiprocess
    devices = xm.get_xla_supported_devices()
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/core/xla_model.py", line 99, in get
_xla_supported_devices
    devices = torch_xla._XLAC._xla_get_devices()
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder 
grpc channel to 10.142.0.20:8479.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/jiayipan/prismatic-video-lms/example.py", line 14, in <module>
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 38, in spawn
    return pjrt.spawn(fn, nprocs, start_method, args)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 211, in spawn
    run_multiprocess(spawn_fn, start_method=start_method)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/runtime.py", line 95, in wrapper
    return fn(*args, **kwargs)
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 171, in run_multiprocess
    replica_results = list(
  File "/home/jiayipan/.local/lib/python3.10/site-packages/torch_xla/_internal/pjrt.py", line 172, in <genexpr>
    itertools.chain.from_iterable(
  File "/usr/lib/python3.10/concurrent/futures/process.py", line 570, in _chain_from_iterable_of_lists
    for element in iterable:
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 621, in result_iterator
    yield _result_or_cancel(fs.pop())
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 319, in _result_or_cancel
    return fut.result(timeout)
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 458, in result
    return self.__get_result()
  File "/usr/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: Failed to establish SliceBuilder grpc channel to 10.142.0.20:8479.

Do you have any suggestions on how to fix this?

JackCaoG commented on July 30, 2024

Maybe take a look at https://gist.github.com/skye/f82ba45d2445bb19d53545538754f9a3? I believe for each subprocess you need to set a different TPU_VISIBLE_DEVICES.
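
Roughly something like this might work (an untested sketch; worker.py is a hypothetical placeholder for your per-chip script, and it assumes the shared vars like TPU_CHIPS_PER_PROCESS_BOUNDS, TPU_PROCESS_BOUNDS, and TPU_PROCESS_ADDRESSES are already exported):

import os
import subprocess

# Launch one process per chip, each with its own TPU_VISIBLE_DEVICES,
# TPU_PROCESS_PORT, and CLOUD_TPU_TASK_ID, following the gist's pattern.
ports = ['8476', '8477', '8478', '8479']
procs = []
for i, port in enumerate(ports):
  env = dict(
      os.environ,
      TPU_VISIBLE_DEVICES=str(i),
      TPU_PROCESS_PORT=port,
      CLOUD_TPU_TASK_ID=str(i),
  )
  procs.append(subprocess.Popen(['python', 'worker.py'], env=env))

for p in procs:
  p.wait()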

Jiayi-Pan commented on July 30, 2024

Thanks! I tried using this instead, but it still doesn't work:

export TPU_CHIPS_PER_PROCESS_BOUNDS="1,1,1"

# Set the TPU process bounds
export TPU_PROCESS_BOUNDS="2,2,1"

# Set the TPU process addresses
export TPU_PROCESS_ADDRESSES="localhost:8476,localhost:8477,localhost:8478,localhost:8479"

# Set the visible TPU devices
export TPU_VISIBLE_DEVICES="0"  # "1", "2", "3"

# Set the TPU process port
export TPU_PROCESS_PORT="8476"  # "8477", "8478", "8479"

export CLOUD_TPU_TASK_ID=0

Does this mean we need to provide different env vars to each of the processes xmp.spawn creates? If so, how should we do this?

JackCaoG commented on July 30, 2024

lol @will-cromar I need your help

will-cromar commented on July 30, 2024

You're on the right track. There are two places where we can request information about TPU topology: GCE metadata or environment variables.

If you want to do multiprocessing on one host out of a pod, the best way to do that would be to set all of the topology environment variables as if you were running on one host:

TPU_SKIP_MDS_QUERY=1 # Don't query metadata
TPU_HOST_BOUNDS=1,1,1 # Pretend there's one host in the "pod"
TPU_CHIPS_PER_HOST_BOUNDS=2,2,1 # 4 chips per host
TPU_WORKER_HOSTNAMES=localhost
WORKER_ID=0 # Since there's only one worker in this cluster, index is always 0

If you do that, then xmp.spawn will take care of TPU_PROCESS_BOUNDS, TPU_PROCESS_ADDRESSES, TPU_VISIBLE_DEVICES, etc. The logic for setting all of these lives in tpu.py if you're curious.

Just to be upfront, we can't support manually setting these topology settings in general. The configurations we support are already implemented through xmp.spawn.

Having said that, this particular configuration (skip metadata query and limit the workload to one host) is exactly the configuration used by Kaggle and Colab, which we do support, so you can expect that to keep working.
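
Putting it together, a rough sketch of the repro above with these variables set (untested; assumes a single v3 host with 4 chips):

import os

# Set single-host topology before torch_xla is imported
# (values from the list above).
os.environ['TPU_SKIP_MDS_QUERY'] = '1'             # don't query GCE metadata
os.environ['TPU_HOST_BOUNDS'] = '1,1,1'            # one host in the "pod"
os.environ['TPU_CHIPS_PER_HOST_BOUNDS'] = '2,2,1'  # 4 chips per host
os.environ['TPU_WORKER_HOSTNAMES'] = 'localhost'
os.environ['WORKER_ID'] = '0'                      # only one worker, so always 0

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
  device = xm.xla_device()
  data = torch.randn(2, 2, device=device)
  print(data)

if __name__ == '__main__':
  # xmp.spawn fills in TPU_PROCESS_BOUNDS, TPU_PROCESS_ADDRESSES,
  # TPU_VISIBLE_DEVICES, etc. for each child process.
  xmp.spawn(_mp_fn, args=())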
