Comments (1)
问题解决了~
`import cv2
import numpy as np
import os
import time
import mxnet as mx
from mxnet.contrib import onnx as onnx_mxnet
import onnx
from onnx import numpy_helper
import onnxruntime as ort
from onnx import helper
from onnx import TensorProto
#mxnet2onnx
def mxnet2onnx_test():
sym = './model-r34-amf/model-symbol.json'
params = './model-r34-amf/model-0000.params'
#NCHW
input_shape = [(1,3,112,112)]
onnx_file = './model-r34-amf/onnx/modelnew_onnx.onnx'
#返回转换后的onnx模型的路径
converted_model_path = onnx_mxnet.export_model(sym, params, input_shape, np.float, onnx_file)
#Check the model
onnx.checker.check_model(onnx.load(onnx_file))
print('The model is checked!')
#onnx model inferrence
def onnx_inferred_demo():
onnx_file = './model-r34-amf/onnx/modelnew2_onnx.onnx'
image_path='0.jpg'
img = cv2.imread(image_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#img=(img-127.5)
#img=img*0.0078125
img = np.transpose( img, (2,0,1) ) #HWC->CHW
ort_session = ort.InferenceSession(onnx_file)
input_name = ort_session.get_inputs()[0].name #'data'
outputs = ort_session.get_outputs()[0].name #'fc1'
input_blob = np.expand_dims(img, axis=0).astype(np.float32) #NCHW
out = ort_session.run([outputs], input_feed={input_name: input_blob})
print(out[0])
def createGraphMemberMap(graph_member_list):
member_map=dict();
for n in graph_member_list:
member_map[n.name]=n;
return member_map
#修改onnx的结构
def onnx_modify_demo():
onnx_file = './model-r34-amf/onnx/modelnew_onnx.onnx'
model = onnx.load(onnx_file)
print("load done~")
graph = model.graph
initializer_map = createGraphMemberMap(graph.initializer)
input_map = createGraphMemberMap(graph.input) #'data' 'scalar_op1' 'scalar_op2'
node_map = createGraphMemberMap(graph.node) #'_minusscalar0' '_mulscalar0'
output_map = createGraphMemberMap(graph.output) #'fc1'
#print(input_map)
#print(input_map.keys())
#====data, double to float====
print("data, double to float")
name='data'
#print(input_map[name])
graph.input.remove(input_map[name])
new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, [1,3,112,112])
graph.input.extend([new_nv])
#input_map = createGraphMemberMap(graph.input)
#print(input_map[name])
#====Sub, double to float====
print("Sub, double to float")
#修改input的dtype
input_map = createGraphMemberMap(graph.input)
name='scalar_op2'
graph.input.remove(input_map[name])
new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, [1])
graph.input.extend([new_nv])
#同时要修改initializer中初始值的type
initializer_map = createGraphMemberMap(graph.initializer)
graph.initializer.remove(initializer_map[name])
new_nv = helper.make_tensor(name, TensorProto.FLOAT, [1],[127.5])
graph.initializer.extend([new_nv])
#====Mul, double to float====
print("Mul, double to float")
input_map = createGraphMemberMap(graph.input)
name='scalar_op3'
graph.input.remove(input_map[name])
new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, [1])
graph.input.extend([new_nv])
initializer_map = createGraphMemberMap(graph.initializer)
graph.initializer.remove(initializer_map[name])
new_nv = helper.make_tensor(name, TensorProto.FLOAT, [1],[0.0078125])
graph.initializer.extend([new_nv])
#====FC1, double to float====
print("FC1, double to float")
output_map = createGraphMemberMap(graph.output)
name='fc1'
graph.output.remove(output_map[name])
new_nv = helper.make_tensor_value_info(name, TensorProto.FLOAT, [1,512])
graph.output.extend([new_nv])
#===PReLU slop===
#C--->C*1*1
print("PReLU slop, C--->C*1*1")
for input_name in input_map.keys():
if input_name.endswith('relu0_gamma') or input_name.endswith('relu1_gamma'):
print(input_name)
input_shape = input_map[input_name].type.tensor_type.shape.dim
input_dim_val=input_shape[0].dim_value
print(input_dim_val)
graph.input.remove(input_map[input_name])
new_nv = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [input_dim_val,1,1])
graph.input.extend([new_nv])
initializer_map = createGraphMemberMap(graph.initializer)
graph.initializer.remove(initializer_map[input_name])
weight_array = numpy_helper.to_array(initializer_map[input_name])
print(weight_array.shape)
weight_array=weight_array.astype(np.float) #np.float32与initializer中不匹配
#===可以查看权重输出!!!
b=[]
for w in weight_array:
b.append(w)
new_nv = helper.make_tensor(input_name, TensorProto.FLOAT, [input_dim_val,1,1], b)
graph.initializer.extend([new_nv])
for node in model.graph.node:
if (node.op_type == "BatchNormalization"):
for attr in node.attribute:
if (attr.name == "spatial"):
attr.i = 1
onnx.save(model, './model-r34-amf/onnx/modelnew2_onnx.onnx')
if name == 'main':
# mxnet2onnx_test() #==mxnet2onnx
# onnx_modify_demo() #===onnx修改===
onnx_inferred_demo() #===onnx前向推导===
`
from arcface_retinaface_mxnet2onnx.
Related Issues (17)
- 运行mxnet2onnx_demo报错 HOT 2
- onnx.onnx_cpp2py_export.checker.ValidationError HOT 1
- retinaface output shape HOT 6
- The onnx input shape is fixed regardless of scales like mxnet? HOT 2
- No conversion function registered for op type null yet
- RetinaFace-R50 转换onnx报错,请问这个如何解决呢 HOT 4
- 下载链接反了
- Arcface Onnx转换问题修改了fc1的数据类型最后显示 dimension mismatch? HOT 1
- MXNetError: Error in operator pre_fc1: Shape inconsistent
- input_map HOT 1
- How to convert the output to coordinates ? HOT 1
- Official Retinaface mnet25 models conversion HOT 20
- Have you tried to convert alignment algorithm? HOT 4
- 您好,请问有LResNet34E-IR,ArcFace@ms1m-refine-v1转onnx之后的模型吗? HOT 1
- Modify json SoftMaxActivition to softmax also convert error HOT 2
- TensorProto (tensor name: rf_c3_upsamplingroi) should contain one and only one value field. HOT 2
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from arcface_retinaface_mxnet2onnx.