
Comments (8)

xadupre commented on July 22, 2024

It is difficult to answer without knowing the model. Could you try with simplify=False, just to make sure this function is not involved? And what do you mean by 15%? Does it mean that 15% of the observations in a batch do not match the expected outputs? In that case 85% are correct and the model is probably converted correctly.
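The two readings of "15%" can be told apart by computing both the accuracy against the labels and the agreement between the PyTorch and ONNX predictions. A minimal sketch in plain Python, assuming predictions are already class indices; the function names `accuracy` and `agreement` and the toy data are hypothetical:

```python
def accuracy(preds, labels):
    """Fraction of predictions that match the ground-truth labels."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def agreement(preds_a, preds_b):
    """Fraction of observations on which two models predict the same class."""
    if len(preds_a) != len(preds_b):
        raise ValueError("prediction lists must have the same length")
    return sum(a == b for a, b in zip(preds_a, preds_b)) / len(preds_a)


# Toy example with made-up predictions (not data from this thread).
labels = [0, 1, 2, 1]
torch_preds = [0, 1, 2, 0]
onnx_preds = [0, 1, 1, 0]
print(accuracy(torch_preds, labels))        # 0.75
print(accuracy(onnx_preds, labels))         # 0.5
print(agreement(torch_preds, onnx_preds))   # 0.75
```

Low accuracy with high agreement would mean the conversion is faithful and the problem lies elsewhere; low agreement means the two models really behave differently.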

from onnx.

ffxxjj commented on July 22, 2024

First of all, thank you for your answer. By 15% I mean that the original model reaches 95% accuracy, while the converted ONNX model reaches only 15% at inference time. I get the same result with simplify=False. Whether or not dynamic batching is enabled, I only get the same accuracy as the original model when I run inference with batch_size=1 or with the batch size that was used during model conversion. Here is my inference code:

import argparse
import numpy as np
import onnx
import onnxruntime as ort
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm
from dataset import tobacco_dataset
from Model.swin_transformer_DSU import *  # SwinTransformer, STAN_OSFGR


parser = argparse.ArgumentParser("Training")

parser.add_argument('--image_size', type=int, default=448)
parser.add_argument('--batch_size', type=int, default=32)

args = parser.parse_args()
img_size = args.image_size
results = dict()
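
# Not shown in the original post: the definitions of weight_path,
# known_dataset and unknown_dataset, and the enclosing block that the
# indented code below belongs to (apparently a loop over i, given weight_path[i]).
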
    knownloader = DataLoader(dataset=known_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)
    unknownloader = DataLoader(dataset=unknown_dataset, batch_size=args.batch_size, shuffle=False, num_workers=2)


    F = SwinTransformer(img_size=448, patch_size=4, in_chans=3, num_classes=1000,
                        embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32],
                        window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                        # window_size=14, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                        drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                        norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                        use_checkpoint=False
                        # use_checkpoint=True
                        )  # the feature dim is 1024

    net = STAN_OSFGR(F, num_classes=8)

    net.load_state_dict(torch.load(weight_path[i]))

    net = net.cuda()
    net.eval()
    onnx_model_name = r"./Rmodel.onnx"
    # Validate the exported graph before running it.
    onnx_model = onnx.load(onnx_model_name)
    onnx.checker.check_model(onnx_model)

    session = ort.InferenceSession(onnx_model_name, providers=['CUDAExecutionProvider'])
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    correct_ori, total_ori = 0, 0
    correct_fp16, total_fp16 = 0, 0
    correct_fp32, total_fp32 = 0, 0
    correct_onnx, total_onnx = 0, 0

    torch.cuda.empty_cache()

    _pred_u, _labels = [], []
    _pred_ori, _pred_uori = [], []
    _pred_onnx, _pred_uonnx = [], []

    batch_num = 0
    for batch_idx, (data, labels, _) in enumerate(tqdm(knownloader, desc='know')):
        img, labels = data.cuda(), labels.cuda()

        # PyTorch reference prediction
        with torch.no_grad():
            outputs_LSTM = net(img)
        batch_num += 1
        predictions_ori = outputs_LSTM.data.max(1)[1]
        total_ori += labels.size(0)
        correct_ori += (predictions_ori == labels.data).sum()
        _pred_ori.append(outputs_LSTM.data.cpu().numpy())
        _labels.append(labels.data.cpu().numpy())

        # ONNX Runtime prediction on the same batch.
        # With drop_last=False (the DataLoader default), the final batch can be
        # smaller than args.batch_size.
        input_data = img.cpu().numpy()
        outputs_onnx = session.run(None, {input_name: input_data})
        predictions_onnx = np.argmax(outputs_onnx[0], axis=1)
        total_onnx += labels.size(0)
        correct_onnx += (predictions_onnx == labels.data.cpu().numpy()).sum()
        _pred_onnx.append(outputs_onnx[0])



    acc_ori = float(correct_ori) * 100. / float(total_ori)
    acc_onnx = float(correct_onnx) * 100. / float(total_onnx)

    print('net_ori Acc: {:.5f}'.format(acc_ori))
    print('net_onnx Acc: {:.5f}'.format(acc_onnx))
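Before comparing accuracies, it may help to confirm that the exported input actually has a dynamic batch axis. ONNX Runtime reports input shapes via `session.get_inputs()[0].shape`; a fixed integer in position 0 means the graph expects exactly that batch size. The helper below is a hypothetical sketch that only inspects such a shape list, assuming symbolic or unknown dimensions are reported as strings or `None`:

```python
def dynamic_dims(shape):
    """Return the positions of dimensions that are not fixed integers.

    `shape` is assumed to look like ONNX Runtime's NodeArg.shape
    (e.g. session.get_inputs()[0].shape): fixed dimensions are ints,
    symbolic or unknown ones are strings or None.
    """
    return [i for i, d in enumerate(shape) if not isinstance(d, int)]


print(dynamic_dims([32, 3, 448, 448]))            # [] -> batch size fixed at 32
print(dynamic_dims(["batch_size", 3, 448, 448]))  # [0] -> dynamic batch axis
```

If axis 0 is fixed, re-exporting with the `dynamic_axes` argument of `torch.onnx.export` is the usual way to declare it as variable. A dynamic input dimension does not by itself guarantee that every internal operator handles other batch sizes correctly.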


xadupre commented on July 22, 2024

So the model is probably right, since 85% of the observations are correctly classified by ONNX. Could you check the highest probabilities instead of just the predicted label? Sometimes, if two classes have similar probabilities, a small difference may change the top one.
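One way to check whether near-ties explain the flipped labels is to look at the gap between the two highest class probabilities. The sketch below assumes the model outputs raw logits (so a softmax is applied first); `near_ties` and its threshold are hypothetical illustrations:

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top2_margin(logits):
    """Gap between the largest and second-largest class probability."""
    probs = sorted(softmax(logits), reverse=True)
    return probs[0] - probs[1]


def near_ties(batch_logits, threshold=0.05):
    """Indices of rows whose top-2 probabilities differ by less than `threshold`.

    The 0.05 default is an arbitrary illustrative value, not a recommendation.
    """
    return [i for i, row in enumerate(batch_logits) if top2_margin(row) < threshold]


logits = [
    [4.0, 0.0, 0.0],    # confident prediction for class 0
    [1.0, 1.02, -3.0],  # classes 0 and 1 almost tied
]
print(near_ties(logits))  # [1]
```

If most misclassified rows are not near-ties, a small numerical difference is unlikely to be the whole explanation.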


ffxxjj commented on July 22, 2024

No, I mean that 85% of the observations are misclassified. When I run inference with batch_size=1, however, the results are close to those of the original model, and I can't understand what is wrong.
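Since the results depend on the batch size, a direct test is to run the same samples once as a batch and once one at a time, then compare the outputs row by row. This sketch keeps the model behind a hypothetical `run_fn` callable so it stays framework-agnostic; the two toy models only illustrate a batch-independent and a batch-dependent graph:

```python
def compare_batched_vs_single(run_fn, samples, atol=1e-4):
    """Return indices of samples whose batched output differs from their
    batch-size-1 output by more than `atol` (max absolute difference).

    run_fn (hypothetical): takes a list of samples and returns one output
    row per sample, as a sequence of floats.
    """
    batched = run_fn(samples)
    mismatched = []
    for i, sample in enumerate(samples):
        single = run_fn([sample])[0]
        diff = max(abs(a - b) for a, b in zip(batched[i], single))
        if diff > atol:
            mismatched.append(i)
    return mismatched


# Toy stand-ins for an exported model (illustrative only).
def batch_independent(batch):
    # Each output row depends only on its own sample.
    return [[2.0 * x, -x] for x in batch]


def batch_dependent(batch):
    # Each output row also depends on the other samples in the batch.
    mean = sum(batch) / len(batch)
    return [[x - mean, mean] for x in batch]


print(compare_batched_vs_single(batch_independent, [1.0, 2.0, 3.0]))  # []
# Sample 2.0 equals the batch mean, so only indices 0 and 2 change.
print(compare_batched_vs_single(batch_dependent, [1.0, 2.0, 3.0]))    # [0, 2]
```

With ONNX Runtime, `run_fn` could be something like `lambda batch: session.run(None, {input_name: np.stack(batch)})[0]`, assuming the exported input accepts both batch size 1 and `len(batch)`. A non-empty result for a model that should treat samples independently would suggest the problem is in the exported graph rather than in near-ties between classes.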


xadupre commented on July 22, 2024

Can you check the discrepancies for the probabilities on the failing observations?


ffxxjj commented on July 22, 2024

Do you mean the difference between the probabilities predicted by the ONNX model and those predicted by the original model for the failing observations?


xadupre commented on July 22, 2024

Between the probabilities predicted by the ONNX model and those predicted by the original model, for the observations which fail the first test.
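The comparison described here can be sketched as follows: for every observation the ONNX model gets wrong, report the largest absolute difference between the PyTorch and ONNX class probabilities. This assumes both models return raw logits and that labels are class indices; `failing_discrepancies` is a hypothetical helper name:

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def failing_discrepancies(torch_logits, onnx_logits, labels):
    """For each observation the ONNX model misclassifies, return
    (index, largest absolute gap between PyTorch and ONNX probabilities)."""
    report = []
    for i, (t_row, o_row, y) in enumerate(zip(torch_logits, onnx_logits, labels)):
        o_probs = softmax(o_row)
        # argmax; ties resolve to the first index, like np.argmax
        onnx_pred = max(range(len(o_probs)), key=o_probs.__getitem__)
        if onnx_pred != y:
            t_probs = softmax(t_row)
            gap = max(abs(a - b) for a, b in zip(t_probs, o_probs))
            report.append((i, gap))
    return report


# Toy logits (made up): the ONNX output for the second row is flipped.
torch_logits = [[3.0, 0.0], [0.0, 3.0]]
onnx_logits = [[3.0, 0.0], [3.0, 0.0]]
labels = [0, 1]
for idx, gap in failing_discrepancies(torch_logits, onnx_logits, labels):
    print(idx, round(gap, 3))
```

A small gap on the failing rows would point towards near-ties between classes; a large gap, as in the second toy row, suggests the two models are not computing the same function on those inputs.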


ffxxjj commented on July 22, 2024

Okay, thanks. I'll check it out.
