I'm trying to understand the behavior of the SPMD partitioner on simple examples.
Using a simple example with 8 devices:
# Show the devices the example runs on (the recorded output follows).
# NOTE(review): 8 CPU devices are presumably created with
# XLA_FLAGS=--xla_force_host_platform_device_count=8 -- not shown here; confirm.
print(jax.devices())
# [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3),
# CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
And placing the input data as follows:
# NOTE(review): `x` and `sharding` are defined outside this snippet; from the
# HLO, `x` appears to be an f32[8192, 8192] array -- confirm.
# y: rows split 4 ways over the (4, 2) device grid and replicated along the
# second grid axis, so each row block lives on one device pair {0,1}, {2,3}, ...
y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
# z: columns split 2 ways and replicated along the first grid axis, so the two
# column halves live on devices {0,2,4,6} and {1,3,5,7} respectively.
z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
I get the expected input shardings:
When running a simple dot, I get the expected sharding of the output:
Each device computes a local dot on its shards of the inputs (f32[2048,8192] x f32[8192,4096]), producing one f32[2048,4096] slice of the output:
ENTRY main.4_spmd {
param = f32[2048,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
param.1 = f32[8192,4096]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
ROOT dot = f32[2048,4096]{1,0} dot(param, param.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[4,2]0,1,2,3,4,5,6,7}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="est_sharding.py" source_line=27}
}
Now, if I force a resharding of the output with `with_sharding_constraint`:
def f(y, z):
    """Compute ``jnp.dot(y, z)`` and constrain the product to a 2x4 device tiling.

    Left unconstrained, the dot's output comes out tiled (4, 2) over the
    devices. Here the constraint requests ``sharding.reshape(2, 4)`` instead,
    which forces the partitioner to reshard the result.
    """
    # NOTE(review): `sharding` is a module-level sharding over the 8 devices,
    # defined outside this snippet (presumably a PositionalSharding) -- confirm.
    product = jnp.dot(y, z)
    target_sharding = sharding.reshape(2, 4)
    return jax.lax.with_sharding_constraint(product, target_sharding)
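Ideally, we would still compute "small" slices of the dot output based on the placement of the input data (e.g. the same f32[2048,4096] tiles as above), and then move the data around to form the new slices. However, right now the partitioner instead gathers both operands in full on every device, runs the full f32[8192,8192] dot everywhere, and only then slices out each device's shard. The all-gathers show up as all-reduces in the final module.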
Here is the module before SPMD partitioning:
HloModule jit_f, entry_computation_layout={(f32[8192,8192]{1,0}, f32[8192,8192]{1,0})->f32[8192,8192]{1,0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY main.5 {
Arg_0.1 = f32[8192,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
Arg_1.2 = f32[8192,8192]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
dot.3 = f32[8192,8192]{1,0} dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,4]0,1,2,3,4,5,6,7}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
ROOT copy = f32[8192,8192]{1,0} copy(dot.3), sharding={devices=[2,4]0,1,2,3,4,5,6,7}, metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7}) resource_env=ResourceEnv(Mesh(device_ids=[], axis_names=()), ()) unconstrained_dims={}]" source_file="test_sharding.py" source_line=31}
}
And after SPMD partitioning:
HloModule jit_f, entry_computation_layout={(f32[2048,8192]{1,0}, f32[8192,4096]{1,0})->f32[4096,2048]{1,0}}, allow_spmd_sharding_propagation_to_output={true}
ENTRY main.5_spmd {
param = f32[2048,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
all-gather.1 = f32[8192,8192]{1,0} all-gather(param), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
param.1 = f32[8192,4096]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
all-gather = f32[8192,8192]{1,0} all-gather(param.1), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, dimensions={1}, use_global_device_ids=true, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
dot = f32[8192,8192]{1,0} dot(all-gather.1, all-gather), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
constant = s32[8]{0} constant({0, 0, 0, 0, 4096, 4096, 4096, 4096}), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
partition-id = u32[] partition-id()
dynamic-slice = s32[1]{0} dynamic-slice(constant, partition-id), dynamic_slice_sizes={1}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
reshape = s32[] reshape(dynamic-slice), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
constant.1 = s32[8]{0} constant({0, 2048, 4096, 6144, 0, 2048, 4096, 6144}), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
dynamic-slice.1 = s32[1]{0} dynamic-slice(constant.1, partition-id), dynamic_slice_sizes={1}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
reshape.1 = s32[] reshape(dynamic-slice.1), metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
dynamic-slice.2 = f32[4096,2048]{1,0} dynamic-slice(dot, reshape, reshape.1), dynamic_slice_sizes={4096,2048}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
ROOT copy.1 = f32[4096,2048]{1,0} copy(dynamic-slice.2), metadata={op_name="jit(f)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[2,4]0,1,2,3,4,5,6,7}) resource_env=ResourceEnv(Mesh(device_ids=[], axis_names=()), ()) unconstrained_dims={}]" source_file="test_sharding.py" source_line=31}
}
The final module before codegen looks like this (fusion bodies omitted). Note the all-reduces, which are the lowered all-gathers, and the full-size dot:
ENTRY main.5_spmd {
constant.6 = f32[] constant(0)
call = f32[8192,8192]{1,0} call(constant.6), to_apply=parallel_broadcast
copy.2 = f32[8192,8192]{1,0} copy(call)
param = f32[2048,8192]{1,0} parameter(0), sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}
partition-id.1 = u32[] partition-id()
replica-id = u32[] replica-id()
fusion.2 = f32[8192,8192]{1,0} fusion(copy.2, param, partition-id.1, replica-id), kind=kLoop, calls=fused_computation.2
all-reduce = f32[8192,8192]{1,0} all-reduce(fusion.2), channel_id=2, replica_groups={{0,2,4,6},{1,3,5,7}}, use_global_device_ids=true, to_apply=add
copy.3 = f32[8192,8192]{1,0} copy(call)
param.1 = f32[8192,4096]{1,0} parameter(1), sharding={devices=[1,2,4]0,2,4,6,1,3,5,7 last_tile_dim_replicate}
fusion.1 = f32[8192,8192]{1,0} fusion(copy.3, param.1, partition-id.1, replica-id), kind=kLoop, calls=fused_computation.1
all-reduce.1 = f32[8192,8192]{1,0} all-reduce(fusion.1), channel_id=1, replica_groups={{0,1},{2,3},{4,5},{6,7}}, use_global_device_ids=true, to_apply=add.1
dot = f32[8192,8192]{1,0} dot(all-reduce, all-reduce.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(f)/jit(main)/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="test_sharding.py" source_line=27}
ROOT call.1 = f32[4096,2048]{1,0} call(dot, partition-id.1), to_apply=parallel_fusion
}