Running max_poolXd through its decomposition is expensive in Thunder. The torch executor should be able to run these as a single ATen call on the forward pass, as well as on the backward pass via a custom grad_transform.
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast()
def augmented_forward_fn(arg):
    """Thunder-generated augmented forward trace for a decomposed max pool.

    Appears to implement max_pool2d with a 3x3 window, stride 2, padding 1
    (9 window positions, pad of 1 on each spatial side, conv stride (2,)) by:
    pad with -inf -> grouped convolution against a one-hot "patch extraction"
    kernel -> amax over the window dimension.

    Returns the pooled output plus the saved-for-backward tensors/ints
    (residuals) consumed by the backward trace.
    """
    # arg: "cuda:0 f16[32, 3, 224, 224]"
    # Pad H and W by 1 on each side with -inf so padded positions never win the max.
    t0 = prims.pad(arg, -float('inf'), [(0, 0, 0), (0, 0, 0), (1, 1, 0), (1, 1, 0)]) # t0: "cuda:0 f16[32, 3, 226, 226]"
    # Build a 9x9 identity matrix via eq() of a row iota against a column iota.
    t1 = ltorch.arange(9, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1: "cuda:0 i64[9]"
    # t1 = prims.iota(9, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1: "cuda:0 i64[9]"
    t2 = prims.broadcast_in_dim(t1, [9, 1], [0]) # t2: "cuda:0 i64[9, 1]"
    t3 = prims.broadcast_in_dim(t1, [1, 9], [1]) # t3: "cuda:0 i64[1, 9]"
    t4 = prims.broadcast_in_dim(t2, (9, 9), (0, 1)) # t4: "cuda:0 i64[9, 9]"
    t5 = prims.broadcast_in_dim(t3, (9, 9), (0, 1)) # t5: "cuda:0 i64[9, 9]"
    t6 = prims.eq(t4, t5) # t6: "cuda:0 b8[9, 9]"
    t7 = prims.convert_element_type(t6, dtypes.float16) # t7: "cuda:0 f16[9, 9]"
    # Reshape the identity into per-channel one-hot 3x3 filters: each of the 9
    # output maps copies exactly one position of the 3x3 window.
    t8 = prims.reshape(t7, (1, 9, 1, 3, 3)) # t8: "cuda:0 f16[1, 9, 1, 3, 3]"
    t9 = prims.broadcast_in_dim(t8, (3, 9, 1, 3, 3), (0, 1, 2, 3, 4)) # t9: "cuda:0 f16[3, 9, 1, 3, 3]"
    t10 = prims.reshape(t9, (27, 1, 3, 3)) # t10: "cuda:0 f16[27, 1, 3, 3]"
    # Grouped (groups=3, i.e. depthwise per input channel) conv with stride 2
    # extracts all 9 window elements per output location.
    t11 = prims.convolution(t0, t10, None, (2,), (0,), (1,), False, (0, 0), 3) # t11: "cuda:0 f16[32, 27, 112, 112]"
    # Split the 27 maps back into (channel=3, window=9) and reduce the window
    # dimension with amax in fp32 for accuracy, then cast back to fp16.
    t12 = prims.reshape(t11, (32, 3, 9, 112, 112)) # t12: "cuda:0 f16[32, 3, 9, 112, 112]"
    t13 = prims.convert_element_type(t12, dtypes.float32) # t13: "cuda:0 f32[32, 3, 9, 112, 112]"
    t14 = prims.amax(t13, (2,)) # t14: "cuda:0 f32[32, 3, 112, 112]"
    t15 = prims.convert_element_type(t14, dtypes.float16) # t15: "cuda:0 f16[32, 3, 112, 112]"
    # Residuals: the one-hot kernel (t10), the fp32 max (t14), the fp32 window
    # values (t13), and a tuple of saved integer shape/pad parameters.
    return {'output': t15, 'flat_args': [arg], 'flat_output': (t15,)}, ((t10, t14, t13), (0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 3, 32, 27, 112, 112, 32, 3, 9, 112, 112, 2))
# NOTE(review): backward trace fragment — the enclosing `def` header is missing
# from this paste, so the names bound by its signature (i27..i109, t0, t2, t15,
# t17, t19, t21, ...) are presumably the saved residuals from the forward
# trace plus the incoming gradient; confirm against the full backward trace.
# Negate saved pad/slice extents; the negated values are used below as
# negative pad amounts (i.e. cropping pads) inside nvFusion1.
i239 = operator.neg(i86) # i239
# i239 = prims.neg(i86) # i239
del i86
i257 = operator.neg(i27) # i257
# i257 = prims.neg(i27) # i257
del i27
i258 = operator.neg(i28) # i258
# i258 = prims.neg(i28) # i258
del i28
i259 = operator.neg(i30) # i259
# i259 = prims.neg(i30) # i259
del i30
i260 = operator.neg(i31) # i260
# i260 = prims.neg(i31) # i260
del i31
i261 = operator.neg(i33) # i261
# i261 = prims.neg(i33) # i261
del i33
i262 = operator.neg(i34) # i262
# i262 = prims.neg(i34) # i262
del i34
i263 = operator.neg(i36) # i263
# i263 = prims.neg(i36) # i263
del i36
i264 = operator.neg(i37) # i264
# i264 = prims.neg(i37) # i264
del i37
# Broadcast t19 (presumably the incoming output gradient — TODO confirm) across
# a new size-1 dim, then expand to (i104, i105, i106, i107, i108).
t303 = torch.unsqueeze(t19, 2) # t303
# t303 = ltorch.unsqueeze(t19, 2) # t303
# t303 = prims.broadcast_in_dim(t19, [32, 3, 1, 2, 2], [0, 1, 3, 4]) # t303
del t19
t220 = Tensor.expand(t303, [32, 3, 1, 2, 2]) # t220
# t220 = ltorch.expand(t303, [32, 3, 1, 2, 2]) # t220
# t220 = prims.broadcast_in_dim(t303, (32, 3, 1, 2, 2), (0, 1, 2, 3, 4)) # t220
del t303
t221 = Tensor.expand(t220, (i104, i105, i106, i107, i108)) # t221
# t221 = ltorch.expand(t220, (i104, i105, i106, i107, i108)) # t221
# t221 = prims.broadcast_in_dim(t220, (i104, i105, i106, i107, i108), (0, 1, 2, 3, 4)) # t221
del t220
# Permute/reshape t15 (presumably the saved (27,1,3,3) one-hot conv kernel —
# TODO confirm) into a (3, 9, 3, 3) per-group layout for the grad convolutions.
t233 = torch.permute(t15, (1, 0, 2, 3)) # t233
# t233 = ltorch.permute(t15, (1, 0, 2, 3)) # t233
# t233 = prims.transpose(t15, (1, 0, 2, 3)) # t233
del t15
t234 = torch.reshape(t233, [1, i91, 9, 3, 3]) # t234
# t234 = ltorch.reshape(t233, [1, i91, 9, 3, 3]) # t234
# t234 = prims.reshape(t233, (1, i91, 9, 3, 3)) # t234
del t233
t235 = torch.permute(t234, (1, 0, 2, 3, 4)) # t235
# t235 = ltorch.permute(t234, (1, 0, 2, 3, 4)) # t235
# t235 = prims.transpose(t234, (1, 0, 2, 3, 4)) # t235
del t234
t236 = torch.reshape(t235, [3, 9, 3, 3]) # t236
# t236 = ltorch.reshape(t235, [3, 9, 3, 3]) # t236
# t236 = prims.reshape(t235, (3, 9, 3, 3)) # t236
del t235
# Fused region: compare saved window values (t17 -> t18) against the expanded
# max (t221), divide the gradient evenly among tied maxima (sum of eq as the
# tie count), and produce t230 (the per-window-position gradient, fp16) plus
# t282 (t0 re-padded with zeros by the saved extents i9/i10).
[t230, t282] = nvFusion0(i10, i104, i105, i106, i107, i108, i109, i9, t0, t17, t21, t221)
# t18 = prims.convert_element_type(t17, dtypes.float32) # t18
# t282 = prims.pad(t0, 0.0, [(0, 0, 0), (0, 0, 0), (i9, 3, 0), (i10, 3, 0)]) # t282
# t217 = prims.convert_element_type(t21, dtypes.float32) # t217
# t218 = prims.broadcast_in_dim(t217, [32, 3, 1, 2, 2], [0, 1, 3, 4]) # t218
# t219 = prims.broadcast_in_dim(t218, (i104, i105, i106, i107, i108), (0, 1, 2, 3, 4)) # t219
# t222 = prims.eq(t18, t221) # t222
# t223 = prims.sum(t222, (i109,)) # t223
# t224 = prims.broadcast_in_dim(t223, [32, 3, 1, 2, 2], [0, 1, 3, 4]) # t224
# t225 = prims.convert_element_type(t222, dtypes.float32) # t225
# t226 = prims.mul(t219, t225) # t226
# t227 = prims.broadcast_in_dim(t224, (32, 3, 9, 2, 2), (0, 1, 2, 3, 4)) # t227
# t228 = prims.convert_element_type(t227, dtypes.float32) # t228
# t229 = prims.div(t226, t228) # t229
# t230 = prims.convert_element_type(t229, dtypes.float16) # t230
del i10, i104, i105, i106, i107, i108, i109, i9, t0, t17, t21, t221
# Rearrange the zero-padded input t282 into (3, 32, 11, 11) — the layout used
# as the "input" of the weight-gradient convolution at t288 below.
t283 = torch.permute(t282, (1, 0, 2, 3)) # t283
# t283 = ltorch.permute(t282, (1, 0, 2, 3)) # t283
# t283 = prims.transpose(t282, (1, 0, 2, 3)) # t283
del t282
t284 = torch.reshape(t283, [i16, 3, 32, 11, 11]) # t284
# t284 = ltorch.reshape(t283, [i16, 3, 32, 11, 11]) # t284
# t284 = prims.reshape(t283, (i16, 3, 32, 11, 11)) # t284
del t283
t285 = torch.permute(t284, (1, 0, 2, 3, 4)) # t285
# t285 = ltorch.permute(t284, (1, 0, 2, 3, 4)) # t285
# t285 = prims.transpose(t284, (1, 0, 2, 3, 4)) # t285
del t284
t286 = torch.reshape(t285, [3, 32, 11, 11]) # t286
# t286 = ltorch.reshape(t285, [3, 32, 11, 11]) # t286
# t286 = prims.reshape(t285, (3, 32, 11, 11)) # t286
del t285
# Input-gradient path: flatten the per-window gradient back to conv-output
# layout, zero-pad, flip the kernel, and run a stride-1 convolution — the
# standard transposed-convolution formulation of conv input grad.
t231 = torch.reshape(t230, (i97, i98, i99, i100)) # t231
# t231 = ltorch.reshape(t230, (i97, i98, i99, i100)) # t231
# t231 = prims.reshape(t230, (i97, i98, i99, i100)) # t231
del t230, i97, i98, i99, i100
t232 = torch_pad_prim_impl(t231, 0.0, [(0, 0, 0), (0, 0, 0), (0, 0, 1), (0, 0, 1)]) # t232
del t231
t237 = torch.flip(t236, (2, 3)) # t237
# t237 = ltorch.flip(t236, (2, 3)) # t237
# t237 = prims.flip(t236, (2, 3)) # t237
del t236
t238 = torch.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91) # t238
# t238 = ltorch.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91) # t238
# t238 = prims.convolution(t232, t237, None, (1,), [2, 2], (i87, i87), False, (i89, i90), i91) # t238
del t232, t237, i87, i89, i90, i91
# Fused region: crop t238 back to the pre-padding input shape via negative
# pads and slices, then mask with t2 via where(t2, ..., 0.0) — presumably
# undoing the forward -inf padding; confirm against the full trace.
[t270] = nvFusion1(i239, i257, i258, i259, i260, i261, i262, i263, i264, t2, t238)
# t241 = prims.pad(t238, 0.0, [(0, 0, 0), (0, 0, 0), (i239, 0, 0), (i239, 0, 0)]) # t241
# t265 = prims.pad(t241, 0.0, [(i257, i258, 0), (i259, i260, 0), (i261, i262, 0), (i263, i264, 0)]) # t265
# t266 = prims.slice(t265, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1]) # t266
# t267 = prims.slice(t266, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1]) # t267
# t268 = prims.slice(t267, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1]) # t268
# t269 = prims.slice(t268, [0, 0, 0, 0], [32, 3, 3, 3], [1, 1, 1, 1]) # t269
# t270 = prims.where(t2, t269, 0.0) # t270
del i239, i257, i258, i259, i260, i261, i262, i263, i264, t2, t238
# Weight-gradient path: convolve the rearranged padded input (t286) with the
# permuted gradient (t287), then permute back to the kernel layout.
t287 = torch.permute(t270, (1, 0, 2, 3)) # t287
# t287 = ltorch.permute(t270, (1, 0, 2, 3)) # t287
# t287 = prims.transpose(t270, (1, 0, 2, 3)) # t287
del t270
t288 = torch.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16) # t288
# t288 = ltorch.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16) # t288
# t288 = prims.convolution(t286, t287, None, (i11, i12), (0,), (i7, i8), False, (i14, i15), i16) # t288
del t286, t287, i11, i12, i7, i8, i14, i15, i16
t289 = torch.permute(t288, (1, 0, 2, 3)) # t289
# t289 = ltorch.permute(t288, (1, 0, 2, 3)) # t289
# t289 = prims.transpose(t288, (1, 0, 2, 3)) # t289
del t288
# First return slot is None (no gradient for a non-tensor/non-differentiable
# input position); t289 is the computed gradient.
return (None, t289)
We could have pooling layers as a prim as well, but I don't think that's a necessity at this point.