def test_default():
    """Run one image through the RVM ONNX matting model and display the result.

    Loads ``1.jpg``, resizes it to 1920x1080, runs a single inference step
    with zero-initialised recurrent states, then shows the alpha matte
    (``pha``) and the foreground (``fgr``) in OpenCV windows until a key is
    pressed.

    Raises:
        FileNotFoundError: if ``1.jpg`` cannot be read by OpenCV.
    """
    sess = ort.InferenceSession('rvm_mobilenetv3_fp32.onnx')
    # sess = ort.InferenceSession('rvm_mobilenetv3_1920_default.onnx')

    # Initial recurrent states; dtype must match the model's dtype.
    rec = [np.zeros([1, 1, 1, 1], dtype=np.float32)] * 4
    # The downsample ratio must always be FP32.
    downsample_ratio = np.array([0.25], dtype=np.float32)

    src = cv2.imread("1.jpg")
    if src is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("1.jpg")
    src = cv2.resize(src, (1920, 1080))
    # OpenCV decodes to BGR; the model is fed RGB.
    src = cv2.cvtColor(src, cv2.COLOR_BGR2RGB)
    # The src tensor has shape [B, C, H, W], scaled to the 0..1 range so that
    # it matches the 0..1 outputs that are multiplied by 255 below.
    src = np.transpose(src, (2, 0, 1)).astype(np.float32) / 255.0
    src = np.ascontiguousarray(np.expand_dims(src, 0))
    print(src.shape)

    # The trailing outputs are the updated recurrent states; presumably they
    # would be fed back in for the next frame of a video (unused here).
    fgr, pha, *rec = sess.run([], {
        'src': src,
        'r1i': rec[0],
        'r2i': rec[1],
        'r3i': rec[2],
        'r4i': rec[3],
        'downsample_ratio': downsample_ratio,
    })

    # Convert [1, C, H, W] float outputs to [H, W, C] uint8 images.
    pha = (pha * 255).astype(np.uint8)
    pha = np.squeeze(pha, 0)
    pha = np.transpose(pha, [1, 2, 0])

    fgr = (fgr * 255).astype(np.uint8)
    print(fgr.shape)
    fgr = np.squeeze(fgr, 0)
    fgr = np.ascontiguousarray(np.transpose(fgr, [1, 2, 0]))
    # cv2.imshow interprets 3-channel images as BGR, so convert back.
    fgr = cv2.cvtColor(fgr, cv2.COLOR_RGB2BGR)

    cv2.imshow("pha", pha)
    cv2.imshow("FGR", fgr)
    cv2.waitKey(0)