Comments (8)
It is difficult to answer without knowing the model. Could you try with simplify=False, just to make sure this function is not involved? Also, what do you mean by 15%? Do you mean that 15% of the observations in a batch do not match the expected outputs? If so, 85% are correct and the model was probably converted correctly.
First of all, thank you for your answer. By 15% I mean that the original model reaches 95% accuracy, while the converted ONNX model reaches only 15% at inference time; I get the same result with simplify=False. Whether or not I enable a dynamic batch dimension, I only get the same accuracy as the original model when inference uses the batch size set during the conversion phase (or batch_size=1). Here's my inference code:
```python
import argparse

import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm

from dataset import tobacco_dataset
from Model.swin_transformer_DSU import *  # expected to provide SwinTransformer and STAN_OSFGR

parser = argparse.ArgumentParser("Training")
parser.add_argument('--image_size', type=int, default=448)
parser.add_argument('--batch_size', type=int, default=32)
args = parser.parse_args()
img_size = args.image_size
results = dict()

# known_dataset, unknown_dataset and weight_path[i] are defined in code not shown here.
knownloader = DataLoader(dataset=known_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
unknownloader = DataLoader(dataset=unknown_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)

F = SwinTransformer(img_size=448, patch_size=4, in_chans=3, num_classes=1000,
                    embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                    window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                    # window_size=14, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                    drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                    norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                    use_checkpoint=False
                    # use_checkpoint=True
                    )  # the feature dim is 1024
net = STAN_OSFGR(F, num_classes=8)
net.load_state_dict(torch.load(weight_path[i]))
net = net.cuda()
net.eval()

# Load and validate the exported model, then open an ONNX Runtime session on the GPU.
onnx_model_name = r"./Rmodel.onnx"
onnx_model = onnx.load(onnx_model_name)
onnx.checker.check_model(onnx_model)
session = ort.InferenceSession(onnx_model_name, providers=['CUDAExecutionProvider'])
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

correct_ori, total_ori = 0, 0
correct_fp16, total_fp16 = 0, 0
correct_fp32, total_fp32 = 0, 0
correct_onnx, total_onnx = 0, 0
torch.cuda.empty_cache()
_pred_u, _labels = [], []
_pred_ori, _pred_uori = [], []
_pred_onnx, _pred_uonnx = [], []
batch_num = 0
for batch_idx, (data, labels, _) in enumerate(tqdm(knownloader, desc='know')):
    img, labels = data.cuda(), labels.cuda()
    with torch.no_grad():
        # PyTorch reference prediction
        outputs_LSTM = net(img)
        batch_num += 1
        predictions_ori = outputs_LSTM.max(1)[1]
        total_ori += labels.size(0)
        correct_ori += (predictions_ori == labels).sum().item()
        _pred_ori.append(outputs_LSTM.cpu().numpy())
        _labels.append(labels.cpu().numpy())

        # ONNX Runtime prediction on the same batch
        input_data = img.cpu().numpy()
        outputs_onnx = session.run(None, {input_name: input_data})
        predictions_onnx = np.argmax(outputs_onnx[0], axis=1)
        total_onnx += labels.size(0)
        correct_onnx += int((predictions_onnx == labels.cpu().numpy()).sum())
        _pred_onnx.append(outputs_onnx[0])

acc_ori = float(correct_ori) * 100. / float(total_ori)
acc_onnx = float(correct_onnx) * 100. / float(total_onnx)
print('net_ori Acc: {:.5f}'.format(acc_ori))
print('net_onnx Acc: {:.5f}'.format(acc_onnx))
```
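Since the accuracy only matches at the batch size used during export, one thing worth checking is whether the exported input really has a symbolic batch dimension and, if it does not, making sure the model only ever sees full batches. The sketch below assumes onnxruntime's `session.get_inputs()[0].shape`, which as far as I know reports fixed dimensions as ints and symbolic ones as strings or `None`; the helpers `batch_dim_is_dynamic` and `run_with_fixed_batch` are hypothetical. Zero-padding the last batch is only safe if the model treats every row independently, which is exactly what is in doubt here, so treat this as a diagnostic rather than a fix.

```python
import numpy as np


def batch_dim_is_dynamic(shape):
    """Return True if the first (batch) dimension is not a fixed integer.

    Assumes the shape list comes from onnxruntime's NodeArg.shape, where
    fixed dims are ints and symbolic dims are strings or None.
    """
    return not isinstance(shape[0], int)


def run_with_fixed_batch(run, x, fixed_bs):
    """Call `run` on chunks of exactly `fixed_bs` rows, zero-padding the last chunk.

    `run` maps an array of shape (fixed_bs, ...) to logits of shape
    (fixed_bs, num_classes). Padding rows are dropped from each result,
    so the output has exactly one row per real input row.
    """
    outputs = []
    for start in range(0, len(x), fixed_bs):
        chunk = x[start:start + fixed_bs]
        n = len(chunk)
        if n < fixed_bs:
            pad = np.zeros((fixed_bs - n,) + chunk.shape[1:], dtype=chunk.dtype)
            chunk = np.concatenate([chunk, pad], axis=0)
        outputs.append(np.asarray(run(chunk))[:n])
    return np.concatenate(outputs, axis=0)
```

With the script above, this would be called as `run_with_fixed_batch(lambda b: session.run(None, {input_name: b})[0], input_data, 32)`.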
So the model is probably correct, since 85% of the observations are correctly classified by ONNX. Could you compare the highest probabilities instead of only the labels? When two classes have similar probabilities, a small numerical difference can change which one comes out on top.
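Comparing the top probabilities rather than only the predicted labels could look like the sketch below. It assumes the network returns raw logits of shape (N, num_classes); `top2_margin` is a hypothetical helper that returns, per row, the top-1 class, its softmax probability, and the gap to the runner-up, so near-ties where a small numerical difference could flip the prediction stand out.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def top2_margin(logits):
    """Return (top-1 class, top-1 probability, top-1 minus top-2 probability) per row."""
    probs = softmax(np.asarray(logits, dtype=np.float64))
    order = np.argsort(probs, axis=1)  # ascending, so the last column is the top-1 class
    rows = np.arange(len(probs))
    top1 = order[:, -1]
    p1 = probs[rows, top1]
    p2 = probs[rows, order[:, -2]]
    return top1, p1, p1 - p2
```

Rows with a very small margin are the ones where ONNX and PyTorch could plausibly disagree without the conversion being wrong; if the failing rows have large margins, rounding is an unlikely explanation.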
I mean that 85% of the observations are misclassified. Yet when I run inference with batch_size=1, the result is close to the original model's, and I can't understand what is wrong.
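One way to narrow this down, assuming the ONNX model was exported with a dynamic batch axis, is to feed the same inputs once as a whole batch and once sample by sample and compare the two results. If they disagree, the exported graph is handling the batch dimension differently from the PyTorch model. `batch_consistency` is a hypothetical helper; `run` can wrap either `session.run` or the PyTorch model.

```python
import numpy as np


def batch_consistency(run, x):
    """Compare one batched call against one call per sample.

    `run` maps inputs of shape (N, ...) to logits of shape (N, C). For a
    model whose rows are independent, both paths should agree up to
    floating-point noise. Returns the per-sample max absolute difference
    and whether the argmax agrees for each sample.
    """
    batched = np.asarray(run(x))
    single = np.concatenate([np.asarray(run(x[i:i + 1])) for i in range(len(x))], axis=0)
    max_abs = np.abs(batched - single).max(axis=1)
    same_pred = batched.argmax(axis=1) == single.argmax(axis=1)
    return max_abs, same_pred
```

A model that leaks information across rows (in the test below, a toy that subtracts the batch mean) shows non-zero differences, while a row-independent one does not.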
Can you check the discrepancies between the probabilities on the failing observations?
Do you mean the difference between the probabilities the ONNX model predicts for the failing observations and those predicted by the original model?
Between the probabilities predicted by the ONNX model and those predicted by the original model, for the observations that fail the first test.
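Checking those discrepancies could be sketched as follows, assuming the per-batch logits collected in the script (`_pred_ori` and `_pred_onnx`) are first joined with `np.concatenate`. Here "failing" is taken to mean that the ONNX top-1 class differs from the PyTorch one, which separates conversion errors from ordinary model errors; `failing_discrepancies` is a hypothetical helper.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def failing_discrepancies(logits_ori, logits_onnx):
    """Return indices where ONNX and PyTorch disagree on the top-1 class,
    and the max absolute probability difference for each of those rows."""
    p_ori = softmax(np.asarray(logits_ori, dtype=np.float64))
    p_onnx = softmax(np.asarray(logits_onnx, dtype=np.float64))
    idx = np.flatnonzero(p_ori.argmax(axis=1) != p_onnx.argmax(axis=1))
    max_abs = np.abs(p_ori[idx] - p_onnx[idx]).max(axis=1)
    return idx, max_abs
```

Usage with the script's lists: `idx, gap = failing_discrepancies(np.concatenate(_pred_ori), np.concatenate(_pred_onnx))`. Large gaps point to a conversion problem; tiny gaps point to near-ties.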
Okay, thanks. I'll check it out