
xla's Introduction

XLA

XLA (Accelerated Linear Algebra) is an open-source machine learning (ML) compiler for GPUs, CPUs, and ML accelerators.

The XLA compiler takes models from popular ML frameworks such as PyTorch, TensorFlow, and JAX, and optimizes them for high-performance execution across different hardware platforms including GPUs, CPUs, and ML accelerators.

Get started

If you want to use XLA to compile your ML project, refer to the corresponding documentation for your ML framework (PyTorch, TensorFlow, or JAX).

If you're not contributing code to the XLA compiler, you don't need to clone and build this repo. Everything here is intended for XLA contributors who want to develop the compiler and XLA integrators who want to debug or add support for ML frontends and hardware backends.

Contribute

If you'd like to contribute to XLA, review How to Contribute and then see the developer guide.

Contacts

  • For questions, contact the maintainers - maintainers at openxla.org

Resources

Code of Conduct

While under TensorFlow governance, all community spaces for SIG OpenXLA are subject to the TensorFlow Code of Conduct.

xla's People

Contributors

akuegel, anlunx, berkinilbeyi, bixia1, blakehechtman, cheshire, chr1sj0nes, chsigg, d0k, ddunl, ezhulenev, ghpvnist, hanbinyoon, hawkinsp, jblespiau, jreiffers, jurahul, klucke, majnemer, meheffernan, nouiz, olegshyshkov, pifon2a, pschuh, skye, tensorflower-gardener, timshen91, tyb0807, ukoxyz, yunxing


xla's Issues

CMake build support

XLA builds with Bazel at the moment; is it desirable to also have a CMake build?
Who would benefit from this, and what workflows would this enable that aren't doable or easy with the current Bazel configuration?

crosstool_wrapper_driver_is_not_gcc failed: error executing command

I am building XLA from source on an Ubuntu image, following the instructions here.

yes '' | GCC_HOST_COMPILER_PATH=/usr/bin/gcc-10 CC=/usr/bin/gcc-10 TF_NEED_ROCM=0 TF_NEED_CUDA=1 TF_CUDA_CLANG=0 ./configure

bazel build --test_output=all --spawn_strategy=sandboxed //xla/...

When I run bazel build --test_output=all --spawn_strategy=sandboxed //xla/..., I see the following output:

INFO: Options provided by the client:
  Inherited 'common' options: --isatty=1 --terminal_columns=120
INFO: Reading rc options for 'build' from /root/openxla/xla/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'build' from /root/openxla/xla/.bazelrc:
  'build' options: --define framework_shared_object=true --define tsl_protobuf_header_only=true --define=use_fast_cpp_protos=true --define=allow_oversize_protos=true --spawn_strategy=standalone -c opt --announce_rc --define=grpc_no_ares=true --noincompatible_remove_legacy_whole_archive --enable_platform_specific_config --define=with_xla_support=true --config=short_logs --config=v2 --define=no_aws_support=true --define=no_hdfs_support=true --experimental_cc_shared_library --experimental_link_static_libraries_once=false --incompatible_enforce_config_setting_visibility
INFO: Reading rc options for 'build' from /root/openxla/xla/.tf_configure.bazelrc:
  'build' options: --action_env PYTHON_BIN_PATH=/root/anaconda3/bin/python3 --action_env PYTHON_LIB_PATH=/root/anaconda3/lib/python3.10/site-packages --python_path=/root/anaconda3/bin/python3 --action_env CUDA_TOOLKIT_PATH=/usr/local/cuda-11.2 --action_env TF_CUDA_COMPUTE_CAPABILITIES=3.5,7.0 --action_env LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 --action_env GCC_HOST_COMPILER_PATH=/usr/bin/x86_64-linux-gnu-gcc-10 --config=cuda --test_tag_filters=-benchmark-test,-no_oss,-oss_excluded,-no_gpu,-oss_serial --build_tag_filters=-benchmark-test,-no_oss,-oss_excluded,-no_gpu
INFO: Reading rc options for 'build' from /root/openxla/xla/.bazelrc:
  'build' options: --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils,tensorflow/core/tfrt/utils/debug
INFO: Found applicable config definition build:short_logs in file /root/openxla/xla/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:v2 in file /root/openxla/xla/.bazelrc: --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
INFO: Found applicable config definition build:cuda in file /root/openxla/xla/.bazelrc: --repo_env TF_NEED_CUDA=1 --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda
INFO: Found applicable config definition build:linux in file /root/openxla/xla/.bazelrc: --define=build_with_onednn_v2=true --host_copt=-w --copt=-Wno-all --copt=-Wno-extra --copt=-Wno-deprecated --copt=-Wno-deprecated-declarations --copt=-Wno-ignored-attributes --copt=-Wno-array-bounds --copt=-Wunused-result --copt=-Werror=unused-result --copt=-Wswitch --copt=-Werror=switch --copt=-Wno-error=unused-but-set-variable --define=PREFIX=/usr --define=LIBDIR=$(PREFIX)/lib --define=INCLUDEDIR=$(PREFIX)/include --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --config=dynamic_kernels --experimental_guard_against_concurrent_changes
INFO: Found applicable config definition build:dynamic_kernels in file /root/openxla/xla/.bazelrc: --define=dynamic_loaded_kernels=true --copt=-DAUTOLOAD_DYNAMIC_KERNELS
DEBUG: /root/.cache/bazel/_bazel_root/f13311be1636a046ef2c92173c8bb2a7/external/bazel_tools/tools/cpp/lib_cc_configure.bzl:118:10:
Auto-Configuration Warning: 'TMP' environment variable is not set, using 'C:\Windows\Temp' as default
INFO: Build option --action_env has changed, discarding analysis cache.
INFO: Analyzed 2666 targets (266 packages loaded, 26626 targets configured).
INFO: Found 2666 targets...
ERROR: /root/openxla/xla/xla/backends/profiler/gpu/BUILD:134:13: Compiling xla/backends/profiler/gpu/cuda_test.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/backends/profiler/gpu:cuda_test) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/xla/backends/profiler/gpu/_objs/cuda_test/cuda_test.cu.pic.d ... (remaining 117 arguments skipped)

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
nvcc warning : The 'compute_35', 'compute_37', 'compute_50', 'sm_35', 'sm_37' and 'sm_50' architectures are deprecated, and may be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
external/eigen_archive/Eigen/src/Core/util/XprHelper.h(97): warning: __host__ annotation is ignored on a function("no_assignment_operator") that is explicitly defaulted on its first declaration

The main error seems to be:

ERROR: /root/openxla/xla/xla/backends/profiler/gpu/BUILD:134:13: Compiling xla/backends/profiler/gpu/cuda_test.cu.cc failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/backends/profiler/gpu:cuda_test) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/xla/backends/profiler/gpu/_objs/cuda_test/cuda_test.cu.pic.d ... (remaining 117 arguments skipped)

System information

  • Ubuntu20.04

  • Python version: 3.10.9

  • GCC/Compiler version (if compiling from source):10.3.0

  • CUDA version:11.2.0

  • cuDNN version:8.1

Status of Flash-Attention on TPUs

Hello, Flash Attention is a method to produce tiled and fused kernels such that the tiled parameters can fit onto the device SRAM.

May I ask to what degree this technique has been applied to TPUs?

I am aware that there is ongoing work on improving tiling via gml_st, but it doesn't seem to be a blocker for this work, since we already seem to support rectangular tiling well in XLA (?).

May I also know the status of integration with gml_st?

Resources

  1. Triton example implementation
  2. https://github.com/HazyResearch/flash-attention
  3. https://github.com/lucidrains/flash-attention-jax

Configure layout assignment for transposes

By default, layout assignment tries to assign a layout to transposes that makes them a bitcast. This layout is then propagated inside the HloComputation, which means that if it does not cancel out with another transpose, there may be a copy op in some other part of the computation. This copy will later be turned back into a transpose.

In T5X, we observed that the default XLA layout assignment gives us many transposes in each transformer layer, which results in perf overheads. We tried to fix this problem by manually transposing the T5X network input; however, the manually inserted transpose was ignored by XLA, which again creates many transposes in each layer.

@akuegel created a PR #1485 which can stop our manually inserted transpose from being ignored. It turns out that this flag really can help us get better perf. I built a unit test with one T5X encoder layer, default runtime without manual transpose is 11.02ms. Runtime with manual transpose but without Adrian's PR is 11.04ms. After I apply both manual transpose and Adrian's PR (XLA_FLAGS=--xla_gpu_transpose_to_bitcast=false), runtime becomes 10.63ms. Clearly, Adrian's PR is helpful.

BTW, you can find my unit test HLO from here https://drive.google.com/drive/u/1/folders/1MRsPJqF7M-DoFRT3Gw8kYKHSoPAKPDUi

Unfortunately, Adrian told us that his PR cannot be merged because it's controversial and would restrict what they can do in the compiler. To get better perf, I'm filing this GitHub issue to discuss what other possible solutions we can have.

One suggested solution is to use a no-op custom call as a "barrier" so that transposes do not get propagated. However, we still do not know how to do this easily without affecting performance. Another idea that has been proposed is allowing a metadata annotation on the transpose.

We do not have a concrete solution yet; I hope we can nail down this problem by discussing it here. Thanks a lot.

Explore performance of XLA:CPU on ARM.

@sherhut @d0k @jreiffers

It would be interesting to benchmark XLA:CPU Next on ARM. I am starting this issue to track the progress and also to share information about the code location.

XLA:CPU uses MLIR tiling/fusion/vectorization transformations that exist in both OpenXLA and TF repos.

1. XLA:CPU compiler contains two important parts

  • HloXlaRuntimePipeline MLIR pipeline that goes from HLO to Linalg + tHLO, then performs tiling/fusion and buffer allocation/optimizations and emits structured control flow with scalars, vectors and memrefs.

  • XlaCpuCompilationPipeline that lowers the result of hlo-xla-runtime-pipeline to LLVM.

2. Tiling, fusion and vectorization.

CpuTilingPipeline finds fusion clusters e.g. map(matmul(transpose)), reduce(map); tiles the root, fuses all consumers in and then vectorizes or scalarizes the loop bodies. There are many tests that fuse tHLO/Linalg ops in tests/Dialect/gml_st/cpu_tiling. This pipeline has options that affect tile sizes.

3. Vector optimizations and lowering to SCF.

LowerVectorsPass is launched after bufferization. It rewrites higher-level vector ops, e.g. vector.contract, vector.multi_reduction; optimizes vector.transfer_read/write ops and then lowers the result to SCF by unrolling the vectors.

4. Enabling MLIR pipeline for AOT compilation.

tf_library rule should have mlir_components set to "HloLowering".

Using the XLA compiler as a library from CMake

What is the recommended way of including the XLA compiler as a library in a CMake build system?

There seems to be no Bazel target to install/package the libraries and include files.
I am looking for something like

make install

Auto Clustering Leading to Invalid Argument Error

Hi XLA Experts,

We are using Tensorflow (2.4) together with Horovod (0.23) to do distributed training.
We turned on auto clustering via tf.config.optimizer.set_jit(True). However, it throws the following error:

324227 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
324228 [14]<stderr>:    tmp_logs = self.train_function(iterator)
324229 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
324230 [14]<stderr>:    result = self._call(*args, **kwds)
324231 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/def_function.py", line 956, in _call
324232 [14]<stderr>:    filtered_flat_args)
324233 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/function.py", line 2943, in __call__
324234 [14]<stderr>:    filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access
324235 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/function.py", line 1919, in _call_flat
324236 [14]<stderr>:    ctx, args, cancellation_manager=cancellation_manager))
324237 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/function.py", line 560, in call
324238 [14]<stderr>:    ctx=ctx)
324239 [14]<stderr>:  File "/opt/code-fetcher-system/src/jymbii-pc-v2-azkaban_90f8ba721605647216ab36f5d75f950ef4e7b509633d1439819d6a8e67629db9/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
324240 [14]<stderr>:    inputs, attrs, num_outputs)
324241 [14]<stderr>:tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
324242 [14]<stderr>:  (0) Invalid argument:  Trying to assign variable with wrong dtype. Expected INVALID got float
324243 [14]<stderr>:    [[{{node cond/else/_1/cond/StatefulPartitionedCall/Variable_320/cond/else/_22838/Variable_320/cond/Assign}}]]
324244 [14]<stderr>:    [[cond/else/_1/cond/StatefulPartitionedCall/assert_greater_equal/Assert/AssertGuard/branch_executed/_26438/_6409]]
324245 [14]<stderr>:  (1) Invalid argument:  Trying to assign variable with wrong dtype. Expected INVALID got float
324246 [14]<stderr>:    [[{{node cond/else/_1/cond/StatefulPartitionedCall/Variable_320/cond/else/_22838/Variable_320/cond/Assign}}]]
324247 [14]<stderr>:0 successful operations.
324248 [14]<stderr>:0 derived errors ignored. [Op:__inference_fn_with_cond_272149]
324249 [14]<stderr>:
324250 [14]<stderr>:Function call stack:
324251 [14]<stderr>:fn_with_cond -> fn_with_cond
324252 [14]<stderr>:

I am not sure if this is the right place for me to ask this question, but it would greatly help if you could take a quick look and suggest how I can further debug. Thank you in advance!
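For reference, a minimal debugging sketch, assuming the failure is tied to auto-clustering: toggle clustering and dump the XLA modules so the failing cluster (the Variable_320 cond/Assign above) can be inspected. The dump path is an illustrative assumption, not from this setup.

import os
# Assumption: the dump directory is illustrative; set before TensorFlow initializes XLA.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import tensorflow as tf

tf.config.optimizer.set_jit(False)  # baseline: confirm a training step passes without clustering
# ... run one training step ...

tf.config.optimizer.set_jit(True)   # re-enable auto clustering to reproduce the error
# ... run one training step and inspect the dumped HLO modules in /tmp/xla_dump ...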

Remove wrapped downcasts from PJRT C API

This cost me a fair bit of time today chasing down what appeared to be a jump to a stray address (which happened to be in code that I own, with a backtrace that made no logical sense). I tracked it down to the PJRT C API wrapped accessors, which presume that the other side of the C API is just a C++ object that can be blindly downcast to. It looks like sometime in the last N days, IFRT started querying num replicas and partitions, which are some of the last things using this.

There are TODOs in the code to stop doing this. It's a really, really bad idea, and it would be great to prioritize making this kind of bug impossible.

In the meantime, I am carrying this patch, which gets me limping along again.

diff --git a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
index 65f233c12e4..96c2cddcd31 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_c_api_client.h
@@ -413,8 +413,11 @@ class PjRtCApiExecutable : public PjRtExecutable {
   PjRtCApiExecutable(const PJRT_Api* c_api, PJRT_Executable* executable);
 
   absl::string_view name() const override;
-  int num_replicas() const override { return wrapped()->num_replicas(); }
-  int num_partitions() const override { return wrapped()->num_partitions(); }
+  // int num_replicas() const override { return wrapped()->num_replicas(); }
+  // int num_partitions() const override { return wrapped()->num_partitions(); }
+
+  int num_replicas() const override { return 1; }
+  int num_partitions() const override { return 1; }
 
   int64_t SizeOfGeneratedCodeInBytes() const override;
 
@@ -443,8 +446,11 @@ class PjRtCApiLoadedExecutable : public PjRtLoadedExecutable {
 
   PjRtClient* client() const override { return client_; }
   absl::string_view name() const override { return executable_->name(); }
-  int num_replicas() const override { return wrapped()->num_replicas(); }
-  int num_partitions() const override { return wrapped()->num_partitions(); }
+  // int num_replicas() const override { return wrapped()->num_replicas(); }
+  // int num_partitions() const override { return wrapped()->num_partitions(); }
+
+  int num_replicas() const override { return 1; }
+  int num_partitions() const override { return 1; }
 
   int64_t SizeOfGeneratedCodeInBytes() const override {
     return executable_->SizeOfGeneratedCodeInBytes();

[XLA:GPU] Significant memory usage increase after adding performance modeling in fusion merger

There's a significant memory consumption increase after torch_xla updating its TF pin (pytorch/xla#4815). After some debugging, I found the following commit introduces this regression:

commit 5c5eeaf0cefc40cff80072fb36735dc400d712f0
Author: Ilia Sergachev <[email protected]>
Date:   Wed Nov 2 01:45:59 2022 -0700

    [XLA:GPU] Replace heuristics in the fusion merger with performance modeling.
    
    PiperOrigin-RevId: 485532049

After this change, the buffer assignment of bert-base-uncased increased from 34.04GiB to 54.55GiB.

Removing these lines seems to solve this problem: https://github.com/tensorflow/tensorflow/blob/959db87d6c468ea6319a1e4cf5979abe70284bbd/tensorflow/compiler/xla/service/gpu/fusion_merger.cc#L268-L274.

xla debug files can be found here:
good: https://drive.google.com/file/d/1YPHDEi29xld1dtwngMoGzub_-gicFXwV/view?usp=share_link
bad: https://drive.google.com/file/d/1Unjc3IzZuJ4clfB8ZU5FZzChg_azR-Qi/view?usp=share_link

Cannot build CPU version

Could someone please help with initial build? Manual seems to be outdated https://github.com/openxla/xla/blob/main/docs/developer_guide.md

Repro steps:

git clone https://github.com/openxla/xla.git
cd xla
docker pull tensorflow/build:latest-python3.9
docker run --name xla -w /xla -it -d --rm -v $PWD:/xla tensorflow/build:latest-python3.9 bash
docker exec xla ./configure
docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed //xla/...

1 hour later:

ERROR: /xla/xla/tools/BUILD:160:14: Linking xla/tools/dumped_computation_to_text failed: (Exit 1): gcc failed: error executing command (from target //xla/tools:dumped_computation_to_text) /usr/bin/gcc @bazel-out/k8-opt/bin/xla/tools/dumped_computation_to_text-2.params

Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging
bazel-out/k8-opt/bin/xla/backends/interpreter/_objs/platform/platform.o:platform.cc:function stream_executor::interpreter::XlaInterpreterPlatform::GetExecutor(stream_executor::StreamExecutorConfig const&): error: undefined reference to 'stream_executor::ExecutorCache::GetOrCreate(stream_executor::StreamExecutorConfig const&, std::function<absl::lts_20230125::StatusOr<std::unique_ptr<stream_executor::StreamExecutor, std::default_delete<stream_executor::StreamExecutor> > > ()> const&)'
bazel-out/k8-opt/bin/xla/backends/interpreter/_objs/platform/platform.o:platform.cc:function std::_Rb_tree<int, std::pair<int const, stream_executor::ExecutorCache::Entry>, std::_Select1st<std::pair<int const, stream_executor::ExecutorCache::Entry> >, std::less<int>, std::allocator<std::pair<int const, stream_executor::ExecutorCache::Entry> > >::_M_erase(std::_Rb_tree_node<std::pair<int const, stream_executor::ExecutorCache::Entry> >*): error: undefined reference to 'stream_executor::ExecutorCache::Entry::~Entry()'
collect2: error: ld returned 1 exit status

Tried to run it manually from the container along with 'bazel clean --expunge' - no luck.

How to catch XlaRuntimeError in python

I'm trying to catch XlaRuntimeError in Python, but the error class is not visible, and the module that defines this exception is also not visible, so we are unable to import the exception.
The presumed module that defines XlaRuntimeError is tensorflow.compiler.xla.xla_client.

Also, it seems there is currently no way to import the exceptions thrown by XLA in Python. The expectation is to give front-end frameworks the flexibility to import exceptions thrown by XLA, e.g.
from xla.exceptions import *
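For reference, a minimal sketch of how the exception can be caught today when the front end is JAX, assuming jaxlib.xla_extension.XlaRuntimeError is the class that surfaces (as in the tracebacks elsewhere on this page); the TensorFlow path may differ, which is exactly the portability gap described above.

import jax
import jax.numpy as jnp
from jaxlib.xla_extension import XlaRuntimeError  # assumption: jaxlib-based front end

def run():
    # Placeholder XLA-backed call; any compile- or run-time failure here
    # surfaces as XlaRuntimeError.
    return jax.jit(lambda x: x + 1)(jnp.ones((2, 2)))

try:
    run()
except XlaRuntimeError as err:
    print("caught XLA runtime error:", err)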

glibc malloc: process could monopolise host memory when compiling

Context

After compiling a deep and large computation graph, the process monopolises a large amount of memory (more than 2.5GB). This problem is encountered with jax==0.4.6 and XLA from the TensorFlow repo (43e9d313548ded301fa54f25a4192d3bcb123330).

It could be damaging when the process is long-running.

Cause

With the glibc allocator, allocation of small objects (malloc(size) with size<=1032) could make the process never return a large portion of the heap memory to the system.

Two simple cases that generate this behaviour:

  • Case 0: allocation of a lot of small objects
  • Case 1: a small object is allocated while the heap size is increasing, then this object is quickly deallocated

Case 0

{
    // ~1M allocations of 1032 bytes each; they are freed at end of scope,
    // yet (per the report) the memory is not returned to the system.
    std::vector<std::vector<char>> v(1 << 20, std::vector<char>(1032, 0));
}

Case 1

{
    // we use a variable to ensure the deallocation occurs only at end of scope
    std::vector<char> o(1032 + 1, 0);
    std::vector<std::vector<char>> v(1 << 20, o);
    {
        std::vector<char> tmp(1032, 0);
    } 
}

Possible solutions

One solution is to invoke malloc_trim(...), but this is not portable. Another is to use an alternative allocator (jemalloc 5.3.0 does not exhibit the faulty behaviour). Finally, I tried to set a maximum size for the data segment; unfortunately, malloc becomes unusable if only this option is configured (maybe there are additional malloc environment variables to configure as well).
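A minimal sketch of the malloc_trim(...) workaround mentioned above, assuming a glibc-based Linux host (it is not portable): ask glibc to return free heap pages to the system from inside the running process once the large compilation has finished.

import ctypes
import ctypes.util

libc = ctypes.CDLL(ctypes.util.find_library("c"))
if hasattr(libc, "malloc_trim"):
    # malloc_trim(0) returns 1 if any memory was released back to the OS.
    released = libc.malloc_trim(0)
    print("malloc_trim released memory:", bool(released))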

Remarks

  1. This behaviour is present because glibc is built by default with tcache enabled (tested with versions 2.37 and 2.31 on an x86_64 Linux machine). When a small object is deallocated, the chunk is put in the tcache if there are available slots, thus preventing consolidation.
  2. With XLA_FLAGS="--xla_cpu_use_xla_runtime", the following line generates heap memory distention when compiling a large and deep computation graph: service/copy_insertion.cc#L2176

When enabling low-precision model training with XLA, why do half and bfloat16 perform differently?

When training a given model with different dtypes (half and bfloat16), I'd like to inspect the performance behavior of model training with XLA.

For a specific part of the original TF graph, XLA exhibits different splitting strategies under the two data types, which means that the size of the fused kernels and the number of MHLO instructions they contain also differ. I am very curious about this. I would like to ask what factors or HLO passes affect the size of a fused kernel under these two data types.
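To make the comparison concrete, here is a hedged sketch of one way to inspect the difference: dump the optimized HLO for the same jit-compiled function at each dtype and compare the fusion boundaries. The toy function below is an illustrative stand-in, not the model from this report, and the dump directory is an assumption.

import os
# Assumption: dump directory is illustrative; set before XLA initializes.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dtype_dump"

import tensorflow as tf

@tf.function(jit_compile=True)
def layer(x, w):
    return tf.nn.relu(tf.matmul(x, w))

for dtype in (tf.half, tf.bfloat16):
    x = tf.cast(tf.random.normal([128, 256]), dtype)
    w = tf.cast(tf.random.normal([256, 256]), dtype)
    layer(x, w)  # compiles one module per dtype; both dumps land in /tmp/xla_dtype_dump

# Diff the *after_optimizations* HLO files of the f16 and bf16 modules to see
# how the fused kernels and instruction counts differ.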

Linking fails with undefined symbol initxla_extension

Building fails with:

ERROR: /home/petkantchin/ws/openxla/xla/repo/xla/python/BUILD:922:21: Linking xla/python/xla_extension.so failed: (Exit 1): gcc failed: error executing command /usr/bin/gcc @bazel-out/k8-opt/bin/xla/python/xla_extension.so-2.params
ld.lld: error: version script assignment of 'global' to symbol 'initxla_extension' failed: symbol not defined
ld.lld: error: version script assignment of 'global' to symbol 'init_xla_extension' failed: symbol not defined

My setup is

ubuntu 20.04
python 3.11.3
bazel 5.3.0

configure

export GCC_HOST_COMPILER_PATH=$(which gcc-10)
export CC=$GCC_HOST_COMPILER_PATH
export PYTHON_BIN_PATH=$(which python)
export PYTHON_LIB_PATH=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
export TF_NEED_ROCM=0
export TF_NEED_CUDA=0
export TF_DOWNLOAD_CLANG=0
export CC_OPT_FLAGS="-Wno-sign-compare"
python configure.py

build

bazel build //xla/...

On top of fad77e0.

`UniformIntDistribution` is exclusive in its upper bound

Copied from a TensorFlow ticket I raised and closed

Issue Type

Bug

Source

source (I actually use this binary but that isn't compiled by Google)

Tensorflow Version

2.8

Custom Code

Yes

OS Platform and Distribution

Ubuntu 20.04

Mobile device

n/a

Python version

n/a

Bazel version

2.4.1

GCC/Compiler version

Unknown, between 7.5 and 9.3

CUDA/cuDNN version

CUDA Version: 11.6

GPU model and memory

NVIDIA GeForce GTX 1070

Current Behaviour?

UniformIntDistribution appears to be inclusive in its lower bound, but exclusive in its upper bound. I think this is wrong for a number of reasons:

  • There is no simple way to get uniform samples that include the maximum U64.
  • Setting both bounds to be the same produces seemingly undefined behaviour*
  • Its sister function UniformFloatingPointDistribution appears to be inclusive in its upper bound. I came to this conclusion as sampling that function with equal bounds returns the common bound, rather than, say NaN. The two functions thus have incongruous behaviour.

*Example samples for equal bounds of 0 and 0

[
      11077253088097075545
    , 13897614985444391724
    , 18164676955841373932
    , 376028057765135569
    , 12082511028297777851
    , 1029834464974463124
    , 12380146588138085609
    , 8561165853724746330
    , 9319215302267380863
    , 16134235052671214906
]

Standalone code to reproduce the issue

void Test() {
    xla::XlaBuilder builder("");
    // Broadcast constant lower/upper bounds (0 and 1) to shape {1000}.
    auto zero = xla::ConstantR0<uint64_t>(&builder, 0);
    zero = xla::BroadcastInDim(zero, {1000}, {});
    auto one = xla::ConstantR0<uint64_t>(&builder, 1);
    one = xla::BroadcastInDim(one, {1000}, {});
    // Key and state for the ThreeFry bit generator.
    auto key = xla::ConstantR0<uint64_t>(&builder, 0);
    auto state = xla::ConstantR0<uint64_t>(&builder, {0});
    auto shape = xla::ShapeUtil::MakeShape(xla::U64, {1000});

    // Draw 1000 samples with lower bound 0 and upper bound 1.
    auto sample = xla::UniformIntDistribution(
        key, state, xla::ThreeFryBitGenerator, zero, one, shape
    ).value;
    // If the upper bound were inclusive, at least one sample should equal 1.
    auto anyEqOne = xla::Any(xla::Eq(sample, one));

    auto computation = builder.Build(anyEqOne).ConsumeValueOrDie();
    auto res =
        xla::ClientLibrary::LocalClientOrDie()
        ->ExecuteAndTransfer(computation, {})
        .ConsumeValueOrDie()
        .ToString();

    std::cout << res << std::endl;
}

Relevant log output

pred[] false

Most latency of weight gradient all-reduce is exposed

With the latest implementation of Latency Hiding Scheduling, we observe that most of the weight gradient all-reduce latency is still exposed (ref slides 6 and 7 here).

Here is a brief summary of our observations.

  • dgrads and wgrads are still scheduled separately for DP-only runs (ref slide 6). This is suboptimal because this way dgrads can never be used to overlap the all-reduce kernels.
  • wgrads calculation order does not match the wgrads all-reduce order.
    You can see that the encoder wgrads calculation never gets overlapped with any AR kernel. I think we can have an order (for example due to control dependencies) for AR kernels, but the AR order should match the wgrad calculation order, so that we can start an AR right after its corresponding wgrads are calculated.
  • Currently, I see that TP-AR and wgrads AR are scheduled in the same stream (all of them are in stream 47). I think it's better to put them in two different streams to avoid interference. Otherwise, they can block each other and create some exposure. For example, in the 11B T5X run, dgrads and wgrads compute are scheduled together in bwd (ref slide 7), but all wgrads AR are scheduled at the very end of bwd. If we can put wgrads AR in a different stream, they can start earlier (right after their corresponding wgrads are calculated).

Currently, exposed weight gradient all-reduce latency is 15%-25% of T5X training step runtime, so it's very critical to fix this issue. Thanks a lot.

JAX built against XLA crashes on Pascal GPUs in WSL after commit 16e953a921c150dcaf012299a52d771e4bb425c2

I've encountered an issue where JAX built against XLA crashes on Pascal GPUs while running under Windows Subsystem for Linux (WSL). This issue seems to have started after the recent commit 16e953a.

Environment

XLA commit: 16e953a
JAX commit: ae4f1fcb66d13d6606525c0810bd9a0aba087f0c
GPU: I confirmed that it happens on 1060 and 1080 Ti
WSL version: 1.1.6.0
Ubuntu version: Ubuntu 22.04.2 LTS
NVIDIA driver version: 531.41
CUDA toolkit version: 11.8
cuDNN version: 8.8.1

Steps to reproduce:

Install the latest version of JAX.
With CUDA 11:

jax.random.normal(jax.random.PRNGKey(0),(5,5))

With CUDA 12:
The example from CUDA 11 also fails, and, in addition, any attempt to create an array, e.g.

jax.numpy.zeros((5,5))

I also created an issue in the JAX repo: google/jax#15260

How to build in debug?

How should I go about building in debug? I am having trouble getting debug symbols into the binaries.
I tried

export CC=$(which gcc-10)
bazel build \
  --config=monolithic \
  --config=dbg \
  --spawn_strategy=local \
  --copt=-g \
  --linkopt=-g \
  --strip=never \
  //xla/tools:run_hlo_module

When I try loading the resulting binary with gdb I get

$ gdb bazel-bin/xla/tools/run_hlo_module
GNU gdb (GDB) 12.1

....

(No debugging symbols found in bazel-bin/xla/tools/run_hlo_module)

My OS is Ubuntu 20.04.

GPU fusion kernel with transpose operator shows bad data locality of I/O tensors

We observed a case where XLA's fusion kernel is very slow due to bad data locality of DRAM access for I/O tensors. The case is from the stable diffusion model.

The fused kernel has a theoretical read traffic of 2.64GB and a write traffic of 0.66GB.
But from the Nsight Compute profiler results, we observed that the kernel has 22.14GB of total DRAM read traffic, which is 8.4x the theoretical traffic. This makes the fusion kernel quite slow; the kernel time is about 44ms, which takes about 24% of the total step time.

It seems that the bad data locality is due to the transpose operator in the fusion graph.

Attaching the HLO graph and profiler results here: https://drive.google.com/drive/folders/14KNN4vLKtVNzXhYjnmv3_lJEAViDlgcE

Please check the fused_computation.1072 kernel in HLO dump file xla-nightly-float32-n4/module_0798.pjit_update.sm_8.0_gpu_after_optimizations

XLA from openxla/xla cannot be built as a Bazel subrepository

I'm attempting to switch JAX to use the openxla/xla copy of XLA, but unfortunately it appears that the repository cannot be built as a Bazel subrepository, unlike the copy in the TensorFlow repository.

Simple repro:

$ git clone https://github.com/openxla/xla.git /tmp/xla
$ cd /tmp/xla && git checkout test
$ cd /tmp
$ mkdir myrepo && cd myrepo
$ cat > WORKSPACE <<EOF
local_repository(
   name = "org_tensorflow",
   path = "/tmp/xla",
)

load("@org_tensorflow//:workspace3.bzl", "tf_workspace3")
tf_workspace3()

load("@org_tensorflow//:workspace2.bzl", "tf_workspace2")
tf_workspace2()

load("@org_tensorflow//:workspace1.bzl", "tf_workspace1")
tf_workspace1()

load("@org_tensorflow//:workspace0.bzl", "tf_workspace0")
tf_workspace0()

EOF

$ touch BUILD
$ bazel build --experimental_repo_remote_exec @org_tensorflow//xla:util

ERROR: Error computing the main repository mapping: at /.../68526d8af9166f1b3b7478289d0f17d6/external/org_tensorflow/workspace2.bzl:14:6: Label '//tools/toolchains/embedded/arm-linux:arm_linux_toolchain_configure.bzl' is invalid because 'tools/toolchains/embedded/arm-linux' is not a package; perhaps you meant to put the colon here: '//:tools/toolchains/embedded/arm-linux/arm_linux_toolchain_configure.bzl'?

There are at least two things going wrong.

  • references like
load("@//tools/toolchains:cpus/aarch64/aarch64_compiler_configure.bzl", "aarch64_compiler_configure")

refer to the enclosing repository because of the leading @. But there is no such file in the enclosing repository: we want to get these from the XLA repository.

  • If we fix the first problem, then we next hit:
ERROR: ... :: Error ...: no such package '@tsl//tsl': The repository's path is "third_party/tsl" (absolute: "/.../myrepo/third_party/tsl") but it does not exist or is not a directory.

This is because workspace2.bzl uses local_repository to create a @tsl repository, but local_repository per the Bazel documentation requires "This must be a path to the directory containing the repository's WORKSPACE file. The path can be either absolute or relative to the main repository's WORKSPACE file.". That won't work if XLA is itself a submodule.

Please fix? This blocks using the OpenXLA repository from JAX.

Questions about VectorLoad in LLVM IR dumped by XLA

Hi, recently I've been trying to analyze the performance of softmax with OpenXLA. Here is my simple test code.

from jax import numpy as jnp
import numpy as np
import jax

@jax.jit
def softmax(x):
    return jax.nn.softmax(x * 0.01)

arr = np.random.randn(16,4096,4096).astype(jnp.float16)

jax.profiler.start_trace("./log")
for i in range(100):
    softmax(arr)
jax.profiler.stop_trace()

After enabling XLA and setting export XLA_FLAGS="--xla_dump_to=./output", I got two different LLVM IRs: one is ir-no-opt.ll, the other is ir-with-opt.ll.

When comparing the two LLVM IRs, I found that certain LLVM passes have converted the scalar loads into a vector load, as shown below.

ir-no-opt.ll

%37 = getelementptr inbounds half, ptr %arg2, i32 %linear_index1
  %38 = load half, ptr %37, align 2, !invariant.load !11

ir-with-opt.ll

 %arg37 = addrspacecast ptr %arg3 to ptr addrspace(1)
  %arg25 = addrspacecast ptr %arg2 to ptr addrspace(1)
  %arg13 = addrspacecast ptr %arg1 to ptr addrspace(1)
  %arg01 = addrspacecast ptr %arg0 to ptr addrspace(1)
  %0 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !8
  %1 = shl nuw nsw i32 %0, 10
  %2 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !9
  %3 = shl nuw nsw i32 %2, 2
  %linear_index_base = or i32 %1, %3
  %4 = lshr i32 %0, 2
  %5 = zext i32 %linear_index_base to i64
  %6 = getelementptr half, ptr addrspace(1) %arg25, i64 %5
  %7 = load <4 x half>, ptr addrspace(1) %6, align 8, !invariant.load !10
  %8 = extractelement <4 x half> %7, i32 0
  %9 = extractelement <4 x half> %7, i32 1
  %10 = extractelement <4 x half> %7, i32 2
  %11 = extractelement <4 x half> %7, i32 3

Question: which LLVM pass performed this transformation?

By the way, I have tried to dump all LLVM passes by setting export XLA_FLAGS="--xla_gpu_dump_llvmir", but didn't find any useful clues.

Any clues or ideas would be highly appreciated. Thanks.

Use of preprocessor directive in function-like macro argument list is undefined behavior

The #if directives nested in TF_ASSIGN_OR_RETURN do not conform to the C standard.

TF_ASSIGN_OR_RETURN(auto tensor_w,
                    CreateCudnnTensor(filter_dims, filter_strides, 'w',
                                      input_type, vector_size, vector_dim,
                                      /*is_virtual=*/false,
#if (CUDNN_VERSION >= 8800 && TF_ENABLE_CUDNN_FRONTEND)
                                      tensor_ordering_type
#else
                                      is_reordered_nchw_vect
#endif
                                      ));

And this causes compilation issues with MSVC:

warning C5101: use of preprocessor directive in function-like macro argument list is undefined behavior

FYI:
https://wiki.sei.cmu.edu/confluence/display/c/PRE32-C.+Do+not+use+preprocessor+directives+in+invocations+of+function-like+macros
https://stackoverflow.com/questions/64452244/c-c-how-should-preprocessor-directive-work-on-macros-argument-list

Rounding modifier required for instruction 'cvt'

Getting the following error when trying to run code on an A100 that I was able to run successfully on a TPU. A more detailed issue is open in the JAX repo.

Traceback (most recent call last):
  File "/home/keremturgutlu/t5x/t5x/train.py", line 835, in <module>
    config_utils.run(main)
  File "/home/keremturgutlu/t5x/t5x/config_utils.py", line 214, in run
    gin_utils.run(main)
  File "/home/keremturgutlu/t5x/t5x/gin_utils.py", line 129, in run
    app.run(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/keremturgutlu/t5x/t5x/train.py", line 788, in main
    _main(argv)
  File "/home/keremturgutlu/t5x/t5x/train.py", line 830, in _main
    train_using_gin()
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/keremturgutlu/t5x/t5x/train.py", line 614, in train
    trainer.compile_train(dummy_batch)
  File "/home/keremturgutlu/t5x/t5x/trainer.py", line 538, in compile_train
    self._compiled_train_step = self._partitioner.compile(
  File "/home/keremturgutlu/t5x/t5x/partitioning.py", line 805, in compile
    return partitioned_fn.lower(*args).compile()
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/stages.py", line 600, in compile
    self._lowering.compile(**kw),
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/envs/myenv/lib/python3.9/site-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: ptxas exited with non-zero error code 65280, output: ptxas /var/tmp/tempfile-gpu-4159c6f2-29193-5fb1040a6f770, line 234; error   : Rounding modifier required for instruction 'cvt'

The PJRT C API cannot be used from C

The pjrt_c_api.h header is not C compatible. I tried to compile a C source file with an empty main function, and the compiler spat out these errors:

external/xla/xla/stream_executor/tpu/libtftpu.h:24:40: error: expected specifier-qualifier-list before 'decltype'
   24 | #define TFTPU_ADD_FN_IN_STRUCT(FnName) decltype(FnName)* FnName##Fn;
      |                                        ^~~~~~~~
external/xla/xla/stream_executor/tpu/libtftpu.h:52:3: note: in expansion of macro 'TFTPU_ADD_FN_IN_STRUCT'
   52 |   TFTPU_ADD_FN_IN_STRUCT(TfTpu_Initialize);
      |   ^~~~~~~~~~~~~~~~~~~~~~
In file included from external/xla/xla/pjrt/c/pjrt_c_api.h:23,
                 from test/pjrt_c_test.c:2:
external/xla/xla/stream_executor/tpu/c_api_decl.h:24:8: error: expected identifier or '(' before string constant
   24 | extern "C" {
      |        ^~~
In file included from test/pjrt_c_test.c:2:
external/xla/xla/pjrt/c/pjrt_c_api.h:783:42: error: unknown type name 'PJRT_Chunk'
  783 | typedef PJRT_Error* (*PJRT_SendCallback)(PJRT_Chunk* chunk,
      |                                          ^~~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:795:3: error: unknown type name 'PJRT_SendCallback'
  795 |   PJRT_SendCallback send_callback;
      |   ^~~~~~~~~~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:820:23: error: expected ':', ',', ';', '}' or '__attribute__' before '=' token
  820 |   size_t num_send_ops = 0;
      |                       ^
In file included from external/xla/xla/pjrt/c/pjrt_c_api.h:19,
                 from test/pjrt_c_test.c:2:
external/xla/xla/pjrt/c/pjrt_c_api.h:26:3: error: 'PJRT_ExecuteOptions' {aka 'struct PJRT_ExecuteOptions'} has no member named 'launch_id'
   26 |   offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field)
      |   ^~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:30:38: note: in expansion of macro 'PJRT_STRUCT_SIZE'
   30 |   const size_t sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field);
      |                                      ^~~~~~~~~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:828:1: note: in expansion of macro 'PJRT_DEFINE_STRUCT_TRAITS'
  828 | PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id);
      | ^~~~~~~~~~~~~~~~~~~~~~~~~
In file included from test/pjrt_c_test.c:2:
external/xla/xla/pjrt/c/pjrt_c_api.h:26:63: error: 'PJRT_ExecuteOptions' {aka 'struct PJRT_ExecuteOptions'} has no member named 'launch_id'
   26 |   offsetof(struct_type, last_field) + sizeof(((struct_type*)0)->last_field)
      |                                                               ^~
external/xla/xla/pjrt/c/pjrt_c_api.h:30:38: note: in expansion of macro 'PJRT_STRUCT_SIZE'
   30 |   const size_t sname##_STRUCT_SIZE = PJRT_STRUCT_SIZE(sname, last_field);
      |                                      ^~~~~~~~~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:828:1: note: in expansion of macro 'PJRT_DEFINE_STRUCT_TRAITS'
  828 | PJRT_DEFINE_STRUCT_TRAITS(PJRT_ExecuteOptions, launch_id);
      | ^~~~~~~~~~~~~~~~~~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:995:3: error: unknown type name 'Int64List'
  995 |   Int64List dimensions;         // out
      |   ^~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:996:3: error: unknown type name 'BoolList'
  996 |   BoolList dynamic_dimensions;  // out
      |   ^~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:1000:3: error: unknown type name 'XLA_Layout'
 1000 |   XLA_Layout layout;  // out
      |   ^~~~~~~~~~
external/xla/xla/pjrt/c/pjrt_c_api.h:1146:3: error: unknown type name 'PJRT_Chunk'
 1146 |   PJRT_Chunk* chunk;
      |   ^~~~~~~~~~

The source file would compile correctly if I changed it to C++.

Will XLA support dynamic shapes?

XLA recompiles every time a new shape appears. Will XLA support dynamic shapes? Is there a roadmap?
Or could you please tell me where to find the roadmap for XLA?

How to easily enable OpenXLA in TensorFlow?

Hi, we are interested in using OpenXLA (MHLO or StableHLO) in TensorFlow. Is there an easy way to enable it for testing in TensorFlow? Either a Python or C++ API is fine. We just want to quickly get a sense of whether OpenXLA is faster or slower than traditional XLA.
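A minimal sketch of a quick A/B comparison from Python, under the assumption that enabling XLA compilation for one function is enough for a first speed check: tf.function(jit_compile=True) compiles the function with XLA, while a plain tf.function uses the regular TF executor. Note this flag does not specifically select an MHLO or StableHLO path; it simply turns on XLA compilation for that function.

import time
import tensorflow as tf

def body(x, w):
    return tf.nn.relu(tf.matmul(x, w))

tf_fn = tf.function(body)                     # regular TensorFlow execution
xla_fn = tf.function(body, jit_compile=True)  # XLA-compiled execution

x = tf.random.normal([1024, 1024])
w = tf.random.normal([1024, 1024])
for name, fn in [("tf", tf_fn), ("xla", xla_fn)]:
    fn(x, w).numpy()                          # warm-up / trigger compilation
    start = time.time()
    for _ in range(10):
        fn(x, w).numpy()                      # .numpy() forces the result, so timing is honest
    print(name, (time.time() - start) / 10)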

Triton does not seem to be buildable outside of Google

The Issues section is not enabled in the openxla/triton repo, so I am reporting it here. This is related to openxla/triton and cannot be upstreamed.

I need to apply the following patch before producing a jaxlib wheel on Windows:

diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
index b2913e5..ba78043 100644
--- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp
+++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp
@@ -25,7 +25,7 @@
 #include "llvm/IRReader/IRReader.h"
 #include "llvm/Linker/Linker.h"
 #include "llvm/Support/SourceMgr.h"
-#include "third_party/py/triton/google/find_cuda.h"
+// #include "third_party/py/triton/google/find_cuda.h"
 #include <dlfcn.h>
 #include <filesystem>
 #include <iterator>
@@ -164,8 +164,9 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
       }
       return std::filesystem::path(fileinfo.dli_fname);
     }();
-    static const auto runtime_path = (
-        fs::path(PathToLibdevice()) / "libdevice.10.bc");
+    static const auto runtime_path =
+        this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
+        "lib" / "libdevice.10.bc";
     if (fs::exists(runtime_path)) {
       externLibs.try_emplace(libdevice, runtime_path.string());
     } else {

And it seems that the find_cuda.h is only available within google, according to
https://github.com/openxla/triton/blob/5b63e5b265a2ff9784b084d901b9feff5a4fc0fc/BUILD#L486-L488

Build error on main

Steps:
Follow developer guide: https://github.com/openxla/xla/blob/main/docs/developer_guide.md

Result:

...
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/runtime/archive/0aaa6e679847a4eeb407136e7b0bcef93ec652e6.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading: 0 packages loaded
...
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/a52054cfa29d665c43141c66c20a7b8f7a96b546.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading: 0 packages loaded

What is `PJRT_Executable_OptimizedProgram` used for?

Hi, I'm currently working on the implementation of PJRT_Executable_OptimizedProgram in the open-xla-pjrt-plugin, and I have a question about its expected behavior and use.

From looking at the uses of OptimizedProgram in xla, it seems that it is only used in GetHloModules, which is the function that defines the Python binding hlo_modules. That Python binding is then called in the deprecated function compiler_ir, which is called in the deprecated JAX tests that fail when PJRT_Executable_OptimizedProgram is not defined.

Since the Python function and tests that use OptimizedProgram are deprecated, are there plans to also deprecate this method from the PJRT plugin?

Also, other than for debugging by dumping the IR at some point through the lowering pipeline, what is this method used for?

@skye

CUDA illegal memory access on H100 from triton gemm

Running t5x on H100, a CUDA illegal memory access (IMA) error was hit. The error can be reproduced by running the attached HLO:

bazel-bin/tensorflow/compiler/xla/tools/run_hlo_module --platform=gpu ./train.txt

CUDA coredump file points to the location of IMA as:

CUDA Exception: Warp Out-of-range Address
The exception was triggered at PC 0x7f36c09b68b0
[Current focus set to CUDA kernel 0, grid 1532, block (8,0,0), thread (32,0,0), device 0, sm 0, warp 2, lane 0]
#0  0x00007f36c09b6970 in triton_gemm_dot_1<<<(16384,1,1),(128,1,1)>>> ()

train.txt

So disabling Triton GEMM helps to work around the issue.

Build error: โ€˜ASSERT_OK_AND_ASSIGNโ€™ was not declared in this scope

I have a build error locally right now somehow:

xla/pjrt/gpu/se_gpu_pjrt_client_test.cc:366:3: error: โ€˜ASSERT_OK_AND_ASSIGNโ€™ was not declared in this scope; did you mean โ€˜TF_ASSERT_OK_AND_ASSIGNโ€™?
  366 |   ASSERT_OK_AND_ASSIGN(
      |   ^~~~~~~~~~~~~~~~~~~~
      |   TF_ASSERT_OK_AND_ASSIGN

I don't know where this ASSERT_OK_AND_ASSIGN macro is supposed to be defined.
There are 3 hits in the codebase right now (all in the same file):

$ git grep ASSERT_OK_AND_ASSIGN | grep -v TF_ASSERT_OK_AND_ASSIGN
xla/pjrt/gpu/se_gpu_pjrt_client_test.cc:  ASSERT_OK_AND_ASSIGN(
xla/pjrt/gpu/se_gpu_pjrt_client_test.cc:  ASSERT_OK_AND_ASSIGN(
xla/pjrt/gpu/se_gpu_pjrt_client_test.cc:  ASSERT_OK_AND_ASSIGN(

Incorrect device pointer usage in multi-GPU XLA implementation for FP8 GEMM scaling factor

When attempting to run FP8 cublasLt GEMM on multiple GPUs in a Hopper system, we encountered a CUDA illegal memory access (IMA) error. This error can be replicated by executing the provided script (cuda_ima.py) on two GPUs:

python cuda_ima.py --d 4096

In the above command, the --d flag specifies the dimension of the GEMM operation. It is worth noting that larger matrix sizes are more likely to trigger the error.

The compute-sanitizer has detected that the illegal memory access (IMA) occurs when accessing the scaling factor pointer. The error message provided is as follows:

========= Invalid __global__ read of size 4 bytes
=========     at 0x7f8adde55d00 in sm90_xmma_gemm_e5m2e4m3f32_e5m2e4m3f32_f32_tn_n_tilesize128x128x128_warpgroupsize2x1x1_algo2_execute_segment_k_off_kernel__5x_cublas

The scaling factor pointers on the 2 GPUs are 0x7f8adda55d00 (device 0) and 0x7f8adde55d00 (device 1), respectively. But the cuBLASLt log shows:

  computeDesc=[
      computeType=COMPUTE_32F
      scaleType=R_32F
      aScalePointer=0x7f8adda55d00 <-- Device 0
      bScalePointer=0x7f8adda55d00 <-- Device 0
      cScalePointer=0x7f8adde55d00 <-- Device 1
      dScalePointer=0x7f8adde55d00 <-- Device 1
  ]

where the cScalePointer and dScalePointer are clearly wrong.

XLA incorrectly optimizes scalar bool*float to select()

google/jax#15492

Repro from JAX:

In [1]: print(jax.jit(lambda x, y: x*y).lower(f, inf).as_text(dialect="hlo"))
HloModule jit__lambda_, entry_computation_layout={(pred[],f32[])->f32[]}

ENTRY main.5 {
  Arg_0.1 = pred[] parameter(0), sharding={replicated}
  convert.3 = f32[] convert(Arg_0.1)
  Arg_1.2 = f32[] parameter(1), sharding={replicated}
  ROOT multiply.4 = f32[] multiply(convert.3, Arg_1.2)
}

In[2]: print(jax.jit(lambda x, y: x*y).lower(f, inf).compile().as_text())
HloModule jit__lambda_, entry_computation_layout={(pred[],f32[])->f32[]}, allow_spmd_sharding_propagation_to_output={true}

ENTRY %main.5 (Arg_0.1: pred[], Arg_1.2: f32[]) -> f32[] {
  %Arg_0.1 = pred[] parameter(0), sharding={replicated}
  %Arg_1.2 = f32[] parameter(1), sharding={replicated}
  %constant.1 = f32[] constant(0)
  ROOT %select = f32[] select(pred[] %Arg_0.1, f32[] %Arg_1.2, f32[] %constant.1), metadata={op_name="jit(<lambda>)/jit(main)/mul" source_file="<ipython-input-3-2c0439f60e63>" source_line=1}
}

i.e., during optimization, XLA has changed a multiplication by a scalar bool into a select. This is incorrect because it does not have correct inf/nan semantics: 0 * inf should be nan, not 0.
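A hedged repro sketch of the semantics point, assuming f is a scalar bool and inf is float infinity (the original inputs are not shown in the issue): after the convert, False becomes 0.0, and 0.0 * inf must be nan under IEEE rules, so a rewrite to select(pred, y, 0) produces the wrong value.

import jax
import jax.numpy as jnp

f = jnp.bool_(False)        # assumed input
inf = jnp.float32(jnp.inf)  # assumed input

# Reportedly returns 0.0 because the optimized HLO uses select() instead of multiply.
print(jax.jit(lambda x, y: x * y)(f, inf))
# IEEE-correct reference: 0.0 * inf is nan.
print(jnp.float32(0.0) * inf)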

PR process improvements

  • Run PRs against JAX tests on GitHub
  • Run PRs against Tensorflow tests on GitHub
  • Automated notification of rollbacks on PRs
  • Documentation on Copybara quirks
  • Unify triggering of CI under Github Actions. Right now Kokoro triggers automatically in some cases, but a single label should be all that's needed to trigger CI, and this should happen automatically for all members of the organization.

Mac M1 build failure with `--config=monolithic`: no such target '@local_config_rocm//rocm:hipfft'

I'm running the docker build on a Mac M1, which isn't documented, but I managed to get some of the way there by adding --platform linux/x86_64/v8 to my docker run command. I've used the default configuration, but with the option --config=monolithic in

docker exec xla bazel build --test_output=all --spawn_strategy=sandboxed --nocheck_visibility --config=monolithic //xla/...

I'm seeing the error

ERROR: /xla/xla/stream_executor/rocm/BUILD:205:11: no such target '@local_config_rocm//rocm:hipfft': target 'hipfft' not declared in package 'rocm' defined by /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/local_config_rocm/rocm/BUILD and referenced by '//xla/stream_executor/rocm:hipfft_if_static'
ERROR: Analysis of target '//xla/stream_executor/rocm:hipfft_if_static' failed; build aborted: Analysis failed

which is surprising since I used the default config and so wasn't expecting ROCm to be involved.

The full logs are

Extracting Bazel installation...
Starting local Bazel server and connecting to it...
INFO: Options provided by the client:
  Inherited 'common' options: --isatty=0 --terminal_columns=80
INFO: Reading rc options for 'build' from /xla/.bazelrc:
  Inherited 'common' options: --experimental_repo_remote_exec
INFO: Reading rc options for 'build' from /etc/bazel.bazelrc:
  'build' options: --action_env=DOCKER_CACHEBUSTER=1680717592238941475 --host_action_env=DOCKER_HOST_CACHEBUSTER=1680717592321081121
INFO: Reading rc options for 'build' from /xla/.bazelrc:
  'build' options: --define framework_shared_object=true --define tsl_protobuf_header_only=true --define=use_fast_cpp_protos=true --define=allow_oversize_protos=true --spawn_strategy=standalone -c opt --announce_rc --define=grpc_no_ares=true --noincompatible_remove_legacy_whole_archive --enable_platform_specific_config --define=with_xla_support=true --config=short_logs --config=v2 --define=no_aws_support=true --define=no_hdfs_support=true --experimental_cc_shared_library --experimental_link_static_libraries_once=false --incompatible_enforce_config_setting_visibility
INFO: Reading rc options for 'build' from /xla/.tf_configure.bazelrc:
  'build' options: --action_env PYTHON_BIN_PATH=/usr/bin/python3 --action_env PYTHON_LIB_PATH=/usr/lib/python3/dist-packages --python_path=/usr/bin/python3 --config=nonccl --test_tag_filters=-benchmark-test,-no_oss,-oss_excluded,-gpu,-oss_serial --build_tag_filters=-benchmark-test,-no_oss,-oss_excluded,-gpu
INFO: Reading rc options for 'build' from /xla/.bazelrc:
  'build' options: --deleted_packages=tensorflow/compiler/mlir/tfrt,tensorflow/compiler/mlir/tfrt/benchmarks,tensorflow/compiler/mlir/tfrt/jit/python_binding,tensorflow/compiler/mlir/tfrt/jit/transforms,tensorflow/compiler/mlir/tfrt/python_tests,tensorflow/compiler/mlir/tfrt/tests,tensorflow/compiler/mlir/tfrt/tests/ir,tensorflow/compiler/mlir/tfrt/tests/analysis,tensorflow/compiler/mlir/tfrt/tests/jit,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_tfrt,tensorflow/compiler/mlir/tfrt/tests/lhlo_to_jitrt,tensorflow/compiler/mlir/tfrt/tests/tf_to_corert,tensorflow/compiler/mlir/tfrt/tests/tf_to_tfrt_data,tensorflow/compiler/mlir/tfrt/tests/saved_model,tensorflow/compiler/mlir/tfrt/transforms/lhlo_gpu_to_tfrt_gpu,tensorflow/core/runtime_fallback,tensorflow/core/runtime_fallback/conversion,tensorflow/core/runtime_fallback/kernel,tensorflow/core/runtime_fallback/opdefs,tensorflow/core/runtime_fallback/runtime,tensorflow/core/runtime_fallback/util,tensorflow/core/tfrt/eager,tensorflow/core/tfrt/eager/backends/cpu,tensorflow/core/tfrt/eager/backends/gpu,tensorflow/core/tfrt/eager/core_runtime,tensorflow/core/tfrt/eager/cpp_tests/core_runtime,tensorflow/core/tfrt/gpu,tensorflow/core/tfrt/run_handler_thread_pool,tensorflow/core/tfrt/runtime,tensorflow/core/tfrt/saved_model,tensorflow/core/tfrt/graph_executor,tensorflow/core/tfrt/saved_model/tests,tensorflow/core/tfrt/tpu,tensorflow/core/tfrt/utils
INFO: Found applicable config definition build:short_logs in file /xla/.bazelrc: --output_filter=DONT_MATCH_ANYTHING
INFO: Found applicable config definition build:v2 in file /xla/.bazelrc: --define=tf_api_version=2 --action_env=TF2_BEHAVIOR=1
INFO: Found applicable config definition build:nonccl in file /xla/.bazelrc: --define=no_nccl_support=true
INFO: Found applicable config definition build:monolithic in file /xla/.bazelrc: --define framework_shared_object=false --define tsl_protobuf_header_only=false --experimental_link_static_libraries_once=false
INFO: Found applicable config definition build:linux in file /xla/.bazelrc: --host_copt=-w --copt=-Wno-all --copt=-Wno-extra --copt=-Wno-deprecated --copt=-Wno-deprecated-declarations --copt=-Wno-ignored-attributes --copt=-Wno-array-bounds --copt=-Wunused-result --copt=-Werror=unused-result --copt=-Wswitch --copt=-Werror=switch --copt=-Wno-error=unused-but-set-variable --define=PREFIX=/usr --define=LIBDIR=$(PREFIX)/lib --define=INCLUDEDIR=$(PREFIX)/include --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include --cxxopt=-std=c++17 --host_cxxopt=-std=c++17 --config=dynamic_kernels --experimental_guard_against_concurrent_changes
INFO: Found applicable config definition build:dynamic_kernels in file /xla/.bazelrc: --define=dynamic_loaded_kernels=true --copt=-DAUTOLOAD_DYNAMIC_KERNELS
Loading: 
Loading: 0 packages loaded
DEBUG: /xla/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'tf_runtime' because it already exists.
DEBUG: /xla/third_party/repo.bzl:132:14: 
Warning: skipping import of repository 'llvm-raw' because it already exists.
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/runtime/archive/0aaa6e679847a4eeb407136e7b0bcef93ec652e6.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 0 packages loaded
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/99fc6ec34cc1b023a837830d266fbbd523a509c3.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading: 0 packages loaded
Loading: 0 packages loaded
Loading: 5 packages loaded
    currently loading: xla/python/tpu_driver/client ... (14 packages)
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Analyzing: 2545 targets (101 packages loaded, 0 targets configured)
Analyzing: 2545 targets (120 packages loaded, 1588 targets configured)
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/openxla/stablehlo/archive/a2c36eb790c5e70109cf3c2b55f43dcdc779727e.zip failed: class java.io.FileNotFoundException GET returned 404 Not Found
Analyzing: 2545 targets (156 packages loaded, 1896 targets configured)
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/pybind/pybind11_abseil/archive/2c4932ed6f6204f1656e245838f4f5eae69d2e29.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Analyzing: 2545 targets (226 packages loaded, 2989 targets configured)
ERROR: /xla/xla/stream_executor/rocm/BUILD:205:11: no such target '@local_config_rocm//rocm:hipfft': target 'hipfft' not declared in package 'rocm' defined by /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/local_config_rocm/rocm/BUILD and referenced by '//xla/stream_executor/rocm:hipfft_if_static'
INFO: Repository com_google_ortools instantiated at:
  /xla/WORKSPACE:19:15: in <toplevel>
  /xla/workspace2.bzl:84:21: in workspace
  /xla/workspace2.bzl:48:20: in _tf_repositories
  /xla/third_party/repo.bzl:136:21: in tf_http_archive
Repository rule _tf_http_archive defined at:
  /xla/third_party/repo.bzl:89:35: in <toplevel>
INFO: Repository com_google_benchmark instantiated at:
  /xla/WORKSPACE:19:15: in <toplevel>
  /xla/workspace2.bzl:70:19: in workspace
  /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/tsl/workspace2.bzl:613:28: in workspace
  /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/tsl/workspace2.bzl:43:14: in _initialize_third_party
  /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/tsl/third_party/benchmark/workspace.bzl:9:20: in repo
  /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/tsl/third_party/repo.bzl:136:21: in tf_http_archive
Repository rule _tf_http_archive defined at:
  /root/.cache/bazel/_bazel_root/e4ab50d61a21943a819d1e092972a817/external/tsl/third_party/repo.bzl:89:35: in <toplevel>
ERROR: Analysis of target '//xla/stream_executor/rocm:hipfft_if_static' failed; build aborted: Analysis failed
INFO: Elapsed time: 139.788s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (228 packages loaded, 9941 targets configured)
FAILED: Build did NOT complete successfully (228 packages loaded, 9941 targets configured)

hlo_sharding_util_test fails to build with g++ 11.3.0

Changes in a781713 introduce this error:

xla/hlo/utils/hlo_sharding_util_test.cc:167:60: error: call of overloaded 'TileAssignment(<brace-enclosed initializer list>)' is ambiguous
  167 |   EXPECT_EQ(result.tile_assignment(), TileAssignment({4, 1}));
      |                                                            ^
In file included from ./xla/hlo/ir/hlo_sharding.h:34,
                 from ./xla/hlo/ir/hlo_instruction.h:48,
                 from ./xla/hlo/ir/hlo_computation.h:35,
                 from ./xla/hlo/utils/hlo_sharding_util.h:28,
                 from xla/hlo/utils/hlo_sharding_util_test.cc:16:
./xla/hlo/ir/tile_assignment.h:173:12: note: candidate: 'xla::TileAssignment::TileAssignment(absl::lts_20230125::Span<const long int>)'
  173 |   explicit TileAssignment(absl::Span<const int64_t> dims)
      |            ^~~~~~~~~~~~~~
./xla/hlo/ir/tile_assignment.h:172:12: note: candidate: 'xla::TileAssignment::TileAssignment(xla::IotaTileAssignment)'
  172 |   explicit TileAssignment(IotaTileAssignment iota) : iota_(std::move(iota)) {}
      |            ^~~~~~~~~~~~~~
In file included from external/com_google_googletest/googletest/include/gtest/gtest.h:67,
                 from external/com_google_googletest/googlemock/include/gmock/internal/gmock-internal-utils.h:50,
                 from external/com_google_googletest/googlemock/include/gmock/gmock-actions.h:145,
                 from ./xla/test.h:42,
                 from xla/hlo/utils/hlo_sharding_util_test.cc:23:
xla/hlo/utils/hlo_sharding_util_test.cc: In member function 'virtual void xla::hlo_sharding_util::{anonymous}::HloShardingUtilTest_GetManualSubgroupSharding_ManualOnly_Test::TestBody()':

when compiling with

$ g++ --version
g++ (Ubuntu 11.3.0-1ubuntu1~22.04.1) 11.3.0

Very slow constant folding of very-large integer arrays, for instance when working with sparse matrices

This relates to the JAX issue #14655: copying in various details from that thread below.

I've got a use case where I'd like to store the nonzero entries of a very large sparse matrix, and then access them later during a machine learning training loop. Unfortunately, using JIT compilation results in constant-folding of this array, making it extremely slow on large problems. Here's an MWE that runs on my laptop and captures the typical behavior:

import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

n = 10000000

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices # shape (n,2)
    def product(other):
        matrix = BCOO((jnp.ones(n), nonzeroes), shape=(n, n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

operator = build_sparse_linear_operator()

def fn(x):
    return operator(jnp.ones(n) / x).sum()

fn(1.0) # executes in 0.1s
jax.jit(fn)(1.0) # executes in almost one minute

Calling the function without JIT executes in about a tenth of a second, but calling it with JIT takes almost a minute. On larger problems in the codebase which prompted this MWE, I have had it crash due to running out of memory after about an hour. This produces warnings similar to the following:

Constant folding an instruction is taking > 8s:

  slice.22 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

Switching to the following JAX code bypasses the issue:

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices # shape (n,2)
    def product(other):
        nz = _optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n,n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

From this, the problem seems to be that XLA for some reason tries to constant-fold nonzeroes in spite of its large size, and then runs out of resources while trying to do so. I haven't yet been able to replicate this for float arrays, so I'm not sure whether or not the issue is integer-specific.
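
For reference, the same workaround can be written with the public optimization barrier API; a minimal sketch, assuming a JAX version that exposes jax.lax.optimization_barrier (older releases only had the internal helper used above), with a smaller n so it runs quickly:

import jax
import jax.numpy as jnp
import jax.experimental.sparse as sparse
from jax.experimental.sparse import BCOO

n = 1_000_000  # smaller than the original MWE

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)
    def product(other):
        # The barrier keeps XLA from constant-folding the large index array.
        nz = jax.lax.optimization_barrier(nonzeroes)
        matrix = BCOO((jnp.ones(n), nz), shape=(n, n), indices_sorted=True, unique_indices=True)
        return matrix @ other
    return product

operator = build_sparse_linear_operator()
fn = jax.jit(lambda x: operator(jnp.ones(n) / x).sum())
print(fn(1.0))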

Suboptimal SPMD partitioning

I'm trying to understand the behavior of the partitioner on simple examples.

Using a simple example with 8 devices:

print(jax.devices())
# [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
#  CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

And placing the input data as:

    y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
    z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))

I get the expected sharding:

[Screenshot: visualization of the two input shardings]

When running a simple dot, I get the expected sharding of the output:

[Screenshot: visualization of the dot output sharding]

Each device will compute a dot producing a slice of the output:

ENTRY main.4_spmd {
  param = f32[2048,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  param.1 = f32[8192,4096]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
  ROOT dot = f32[2048,4096]{1,0} dot(param, param.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[4,2]0,1,2,3,4,5,6,7}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="est_sharding.py" source_line=27}
}

Now if I force a resharding of the output:

def f(y, z):
    d = jnp.dot(y, z)
    return jax.lax.with_sharding_constraint(d, sharding.reshape(2, 4))
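
For completeness, here is a minimal end-to-end sketch of the setup used in this issue. It assumes 8 host devices (e.g. via --xla_force_host_platform_device_count=8) and jax.sharding.PositionalSharding; array sizes follow the HLO below but are otherwise illustrative:

import jax
import jax.numpy as jnp
from jax.sharding import PositionalSharding

sharding = PositionalSharding(jax.devices())  # 8 devices

x = jnp.ones((8192, 8192), dtype=jnp.float32)
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))  # rows sharded 4-way
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))  # cols sharded 2-way

def f(y, z):
    d = jnp.dot(y, z)
    return jax.lax.with_sharding_constraint(d, sharding.reshape(2, 4))

print(jax.jit(f).lower(y, z).compile().as_text())  # inspect the partitioned HLO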

[Screenshot: visualization of the resharded dot output]

Ideally we'd likely still compute "small" slices of the dot output based on the placement of the input data and then move the data around to form the new slices; right now, however, we instead all-reduce everything and run the full dot on every device before resharding.

Here is the module before SPMD partitioning:

HloModule jit_f, entry_computation_layout={(f32[8192,8192]{1,0}, f32[8192,8192]{1,0})->f32[8192,8192]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.5 {
  Arg_0.1 = f32[8192,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  Arg_1.2 = f32[8192,8192]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
  dot.3 = f32[8192,8192]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,4]0,1,2,3,4,5,6,7}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  ROOT copy = f32[8192,8192]{1,0} copy(dot.3), sharding={devices=[2,4]0,1,2,3,4,5,6,7}, metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7}) resource_env=ResourceEnv(Mesh(device_ids=[], axis_names=()), ()) unconstrained_dims={}]" source_file="test_sharding.py" source_line=31}
}

And after:

HloModule jit_f, entry_computation_layout={(f32[2048,8192]{1,0}, f32[8192,4096]{1,0})->f32[4096,2048]{1,0}}, allow_spmd_sharding_propagation_to_output={true}

ENTRY main.5_spmd {
  param = f32[2048,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  all-gather.1 = f32[8192,8192]{1,0} all-gather(param), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  param.1 = f32[8192,4096]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
  all-gather = f32[8192,8192]{1,0} all-gather(param.1), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  dot = f32[8192,8192]{1,0} dot(all-gather.1, all-gather), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  constant = s32[8]{0} constant({0, 0, 0, 0, 4096, 4096, 4096, 4096}), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  partition-id = u32[] partition-id()
  dynamic-slice = s32[1]{0} dynamic-slice(constant, partition-id), dynamic_slice_sizes={1}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  reshape = s32[] reshape(dynamic-slice), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  constant.1 = s32[8]{0} constant({0, 2048, 4096, 6144, 0, 2048, 4096, 6144}), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, partition-id), dynamic_slice_sizes={1}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  reshape.1 = s32[] reshape(dynamic-slice.1), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  dynamic-slice.2 = f32[4096,2048]{1,0} dynamic-slice(dot, reshape, reshape.1), dynamic_slice_sizes={4096,2048}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  ROOT copy.1 = f32[4096,2048]{1,0} copy(dynamic-slice.2), metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7}) resource_env=ResourceEnv(Mesh(device_ids=[], axis_names=()), ()) unconstrained_dims={}]" source_file="test_sharding.py" source_line=31}
}

The final module before codegen looks like this (omitting the fusion details, but we see the all-reduce and the dot):

ENTRY main.5_spmd {
  constant.6 = f32[] constant(0)
  call = f32[8192,8192]{1,0} call(constant.6), to_apply=parallel_broadcast
  copy.2 = f32[8192,8192]{1,0} copy(call)
  param = f32[2048,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
  partition-id.1 = u32[] partition-id()
  replica-id = u32[] replica-id()
  fusion.2 = f32[8192,8192]{1,0} fusion(copy.2, param, partition-id.1, replica-id), kind=kLoop, calls=fused_computation.2
  all-reduce = f32[8192,8192]{1,0} all-reduce(fusion.2), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
  copy.3 = f32[8192,8192]{1,0} copy(call)
  param.1 = f32[8192,4096]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
  fusion.1 = f32[8192,8192]{1,0} fusion(copy.3, param.1, partition-id.1, replica-id), kind=kLoop, calls=fused_computation.1
  all-reduce.1 = f32[8192,8192]{1,0} all-reduce(fusion.1), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=add.1
  dot = f32[8192,8192]{1,0} dot(all-reduce, all-reduce.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
  ROOT call.1 = f32[4096,2048]{1,0} call(dot, partition-id.1), to_apply=parallel_fusion
}

[CUDA] Kernel launch failure due to large Triton grid dimensions

See: google/jax#16286

JAX code:

import jax
from jax import jit, vmap
import jax.numpy as jnp

@jit
def f(adj, mat):
    return adj @ mat / jnp.sum(adj, axis=1)[:, jnp.newaxis]

adj = jnp.ones((1024 * 100, 10, 10), dtype=bool)
mat = jnp.ones((1024 * 100, 10, 100), dtype=float)

jax.jit(vmap(f))(adj, mat)

XLA error:

jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch CUDA kernel: triton_gemm_dot_0 with block dimensions: 128x1x1 and grid dimensions: 4x1x102400 and shared memory size: 65536: CUDA_ERROR_INVALID_VALUE: invalid argument

Those grid dimensions are indeed unreasonably large: CUDA caps the y and z grid dimensions at 65535, so a z extent of 102400 exceeds the limit, hence the CUDA_ERROR_INVALID_VALUE.
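
As a possible stopgap (not a fix), one can try steering XLA away from the Triton GEMM emitter so the batched dot takes the non-Triton path; a sketch, assuming the xla_gpu_enable_triton_gemm debug option is available in the installed jaxlib/XLA build:

import os
# XLA_FLAGS must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_enable_triton_gemm=false"

import jax
from jax import jit, vmap
import jax.numpy as jnp

@jit
def f(adj, mat):
    return adj @ mat / jnp.sum(adj, axis=1)[:, jnp.newaxis]

adj = jnp.ones((1024 * 100, 10, 10), dtype=bool)
mat = jnp.ones((1024 * 100, 10, 100), dtype=float)

jax.jit(vmap(f))(adj, mat)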

HLO dump:

HloModule jit_f, entry_computation_layout={(pred[102400,10,10]{2,1,0}, f32[102400,10,100]{2,1,0})->f32[102400,10,100]{2,1,0}}, allow_spmd_sharding_propagation_to_output={true}

region_0.3 {
  Arg_0.4 = s32[] parameter(0)
  Arg_1.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(Arg_0.4, Arg_1.5), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/reduce_sum[axes=(2,)]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
}

f.7 {
  Arg_0.8 = pred[102400,10,10]{2,1,0} parameter(0)
  convert.11 = f32[102400,10,10]{2,1,0} convert(Arg_0.8), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/convert_element_type[new_dtype=float32 weak_type=False]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  Arg_1.9 = f32[102400,10,100]{2,1,0} parameter(1)
  dot.12 = f32[102400,10,100]{2,1,0} dot(convert.11, Arg_1.9), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/dot_general[dimension_numbers=(((2,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  convert.13 = s32[102400,10,10]{2,1,0} convert(Arg_0.8), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/convert_element_type[new_dtype=int32 weak_type=False]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  constant.10 = s32[] constant(0)
  reduce.14 = s32[102400,10]{1,0} reduce(convert.13, constant.10), dimensions={2}, to_apply=region_0.3, metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/reduce_sum[axes=(2,)]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  reshape.15 = s32[102400,10,1]{2,1,0} reshape(reduce.14), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/broadcast_in_dim[shape=(102400, 10, 1) broadcast_dimensions=(0, 1)]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  convert.16 = f32[102400,10,1]{2,1,0} convert(reshape.15), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/convert_element_type[new_dtype=float32 weak_type=False]" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  broadcast.17 = f32[102400,10,1]{2,1,0} broadcast(convert.16), dimensions={0,1,2}, metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/div" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  reshape.18 = f32[102400,10]{1,0} reshape(broadcast.17), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/div" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  broadcast.19 = f32[102400,10,100]{2,1,0} broadcast(reshape.18), dimensions={0,1}, metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/div" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
  ROOT divide.20 = f32[102400,10,100]{2,1,0} divide(dot.12, broadcast.19), metadata={op_name="jit(f)/jit(main)/vmap(jit(f))/div" source_file="experimental/users/phawkins/jax/grid.py" source_line=13}
}

ENTRY main.22 {
  Arg_0.1 = pred[102400,10,10]{2,1,0} parameter(0), sharding={replicated}
  Arg_1.2 = f32[102400,10,100]{2,1,0} parameter(1), sharding={replicated}
  ROOT call.21 = f32[102400,10,100]{2,1,0} call(Arg_0.1, Arg_1.2), to_apply=f.7
}

`UniformFloatingPointDistribution` incorrect behaviour at infinite bounds

Copied from a TensorFlow issue I raised and closed

Issue Type

Bug

Source

source (I actually use this binary, but that isn't compiled by Google)

Tensorflow Version

2.8

Custom Code

Yes

OS Platform and Distribution

Ubuntu 20.04

Mobile device

n/a

Python version

n/a

Bazel version

2.4.1

GCC/Compiler version

Unknown, between 7.5 and 9.3

CUDA/cuDNN version

CUDA Version: 11.6

GPU model and memory

NVIDIA GeForce GTX 1070

Current Behaviour?

`UniformFloatingPointDistribution` produces the following samples for the following bounds

1) -inf 0 -> nan
2) 0 +inf -> inf
3) -inf -inf -> nan
4) +inf +inf -> nan

I believe 1) is incorrect and inconsistent with 2), which is correct. I believe 3) and 4) should return -inf and +inf respectively: any sample between one +inf and another +inf is +inf, and since there is no way to specify _different_ +infs (so that the bounds actually differ), I think it makes sense to treat the bounds +inf and +inf as distinct, and the same for -inf and -inf.

Standalone code to reproduce the issue

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"

void Test() {
    xla::XlaBuilder builder("");
    auto posinf = xla::MaxValue(&builder, xla::F64);
    auto neginf = xla::MinValue(&builder, xla::F64);
    auto zero = xla::ConstantR0<double>(&builder, 0.0);
    auto key = xla::ConstantR0<uint64_t>(&builder, 0);
    auto state = xla::ConstantR0<uint64_t>(&builder, {0});
    auto shape = xla::ShapeUtil::MakeShape(xla::F64, {});

    auto sample = xla::UniformFloatingPointDistribution(
        key, state, xla::ThreeFryBitGenerator, neginf, 0, shape // replace bounds as appropriate
    ).value;

    auto computation = builder.Build(sample).ConsumeValueOrDie();
    auto res =
        xla::ClientLibrary::GetOrCreateLocalClient(tensorflow::GPUMachineManager())  // I'm also seeing this on CPU
        .ConsumeValueOrDie()
        ->ExecuteAndTransfer(computation, {})
        .ConsumeValueOrDie()
        .ToString();

    std::cout << res << std::endl;
}

Relevant log output

f64[] -nan
f64[] inf
f64[] -nan
f64[] -nan

for each test case 1-4

Advertising bazel-contrib/rules_cuda

rules_cuda is a community effort to add CUDA support to Bazel.

It currently supports:

  1. Linux and Windows
  2. nvcc+gcc or nvcc+msvc with CUDA >= 10.0, or standalone Clang CUDA
  3. relocatable device code (rdc) and device link time optimization (dlto)

It is also implemented purely in Starlark, and toolchains can be configured with a cc_toolchain-style DSL. Attacking the TF repo might be too ambitious at the moment; I think OpenXLA might be a good fit for it.

How to prevent instructions from being CSE'd?

Hi! Thanks for the repo! I am currently experimenting with XLA optimization passes and am interested in trying out customized optimizations. In order to implement these optimizations, we need to add some instructions immediately prior to the SPMD pass. However, we have observed that these instructions are being deleted by the subsequent CSE passes.

We were wondering if there are any flags that could be used to disable optimization on specific instructions. We appreciate any guidance or suggestions you may have regarding this matter.

Batched TriangularSolve of singular matrix returns incorrect results

Here's a short repro in JAX that more or less passes the inputs directly to TriangularSolveOp:

import jax
import jax.numpy as jnp

def solve(x, y):
  return jax.lax.linalg.triangular_solve(x, y, left_side=True)

x = jnp.array([[1., 1.], [0., 0.]])
y = jnp.array([[1], [1.]])

print(solve(x, y))
# [[-inf]
#  [ inf]]
print(solve(x[None], y[None])[0])
# [[0.]
#  [1.]]

I would expect the second output to be identical to the first. This appears to be the root cause of the issue reported in google/jax#15429
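
For what it's worth, the unbatched result is what plain IEEE back substitution gives for this singular upper-triangular system; a small NumPy sketch, for illustration only (not XLA's actual TriangularSolve implementation):

import numpy as np

# Solve x @ sol = y by back substitution for the upper-triangular, singular x above.
x = np.array([[1., 1.], [0., 0.]])
y = np.array([1., 1.])

sol = np.empty(2)
sol[1] = y[1] / x[1, 1]                       # 1 / 0 -> inf (divide-by-zero warning)
sol[0] = (y[0] - x[0, 1] * sol[1]) / x[0, 0]  # (1 - inf) / 1 -> -inf
print(sol)  # [-inf  inf], matching the unbatched triangular_solve output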
