Comments (4)

traveller59 commented on September 26, 2024

The example in the README generates an engine with two outputs, so you need to allocate memory for the second output and add it to the bindings.
I recommend writing a simple high-level API based on allocate_buffers in common.py from the TensorRT samples to reduce these kinds of bugs. You can get the shapes, dtypes, and names of all inputs and outputs from an engine instance and use them to write that API.
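A minimal sketch of the kind of helper suggested above, loosely modeled on allocate_buffers from the TensorRT samples (not a copy of it). It assumes the older TensorRT binding API (num_bindings, get_binding_name, get_binding_shape, get_binding_dtype, binding_is_input, trt.nptype), which newer TensorRT releases replace with a tensor-name API. BindingSpec, binding_specs and this allocate_buffers are hypothetical names; the engine is passed in duck-typed so the shape/dtype logic can be checked without a GPU.

```python
from dataclasses import dataclass
from typing import Callable, List

import numpy as np


@dataclass
class BindingSpec:
    index: int
    name: str
    shape: tuple  # per-sample shape (implicit-batch engines omit the batch dim)
    dtype: np.dtype
    is_input: bool

    def nbytes(self, batch_size: int = 1) -> int:
        return batch_size * int(np.prod(self.shape)) * self.dtype.itemsize


def binding_specs(engine, to_numpy_dtype: Callable) -> List[BindingSpec]:
    """Collect name, shape, dtype and direction of every engine binding.

    Assumes the older binding API; with real TensorRT pass trt.nptype
    as to_numpy_dtype.
    """
    return [
        BindingSpec(
            index=i,
            name=engine.get_binding_name(i),
            shape=tuple(engine.get_binding_shape(i)),
            dtype=np.dtype(to_numpy_dtype(engine.get_binding_dtype(i))),
            is_input=engine.binding_is_input(i),
        )
        for i in range(engine.num_bindings)
    ]


def allocate_buffers(engine, batch_size: int = 1):
    """Allocate one page-locked host array and one device buffer per binding."""
    import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
    import pycuda.driver as cuda
    import tensorrt as trt

    host, device, bindings = {}, {}, []
    for spec in binding_specs(engine, trt.nptype):
        count = batch_size * int(np.prod(spec.shape))
        host[spec.name] = cuda.pagelocked_empty(count, spec.dtype)
        device[spec.name] = cuda.mem_alloc(host[spec.name].nbytes)
        bindings.append(int(device[spec.name]))  # ordered by binding index
    return host, device, bindings
```

With this, every output the engine declares gets a buffer automatically, so adding a second output (as in the README example) no longer requires editing the allocation code by hand.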

from torch2trt.

oscarriddle commented on September 26, 2024

Hi @traveller59,
Thanks for your reply.
After updating the test script to handle both outputs, the error disappeared.

import time

import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit  # creates a CUDA context on import
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def get_engine(path):
    # Deserialize the engine written by the conversion script (engine.serialize())
    with open(path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        return runtime.deserialize_cuda_engine(f.read())


if __name__ == '__main__':
    # NOTE: np.random.rand returns float64, while the engine input is float32
    img = np.random.rand(1, 3, 299, 299)
    img = np.ascontiguousarray(img)
    engine = get_engine('test.engine')
    stream = cuda.Stream()
    context = engine.create_execution_context()

    # Host buffers for the two engine outputs (softmax and sigmoid, 1000 classes each)
    output = np.empty(1000, dtype=np.float32)
    output2 = np.empty(1000, dtype=np.float32)
    # Device buffers: one input and two outputs
    d_input = cuda.mem_alloc(img.nbytes)
    d_output = cuda.mem_alloc(output.nbytes)
    d_output2 = cuda.mem_alloc(output2.nbytes)
    bindings = [int(d_input), int(d_output), int(d_output2)]

    for i in range(100):
        a1 = time.time()
        cuda.memcpy_htod_async(d_input, img, stream)
        context.execute_async(1, bindings, stream.handle, None)
        cuda.memcpy_dtoh_async(output, d_output, stream)
        cuda.memcpy_dtoh_async(output2, d_output2, stream)
        stream.synchronize()
        a2 = time.time()
        print('Batch {}-th, Time {}ms'.format(i, (a2 - a1) * 1000))
    print(output)
    print(output2)

I inspected the arrays output and output2: output2 is all zeros, and output contains only 0s and 1s, as below:

[1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. ... 1. 1. 1. 1. 0. 1. 0. 0.]

Meanwhile, I generated a random input of the same shape with torch.rand().cuda() and fed it into the PyTorch model; the output tensor is as below:

tensor([[-1.1718e+00,  4.6266e-01,  1.0212e+00,  2.2632e-01,  4.7583e-01,
         -6.3883e-01,  1.0015e-01,  1.5616e+00,  1.0254e+00,  1.1679e+00,
          1.8992e+00,  2.9743e+00,  3.4913e+00,  2.1630e+00,  2.3985e+00,
          2.6715e+00,  3.0376e+00,  1.4133e+00,  3.0594e+00,  2.1019e+00,
          1.1495e+00,  5.6692e+00,  3.9588e+00,  3.3769e+00,  2.3353e+00,
         -1.0645e+00, -7.9425e-01, -8.1189e-01, -1.1870e+00, -5.1096e-01,
         -8.8577e-02,  2.2844e-01, -1.5711e+00, -2.0256e-01, -3.3481e-01,
...

It seems my conversion configuration is incorrect. Could you give some advice?
Thanks

Below is my conversion script:

import tensorrt as trt
import torch
import torchvision
import torch2trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

net = torchvision.models.inception_v3(pretrained=True).eval()
inputs = torch.rand(1, 3, 299, 299)
# Trace the PyTorch model, excluding parameters that match AuxLogits
graph_pth = torch2trt.GraphModule(net, inputs, param_exclude=".*AuxLogits.*")
torch_mode_out = graph_pth(inputs)

def toy_example(x):
    return torch.softmax(x, 1), torch.sigmoid(x)

graph_pth_toy = torch2trt.GraphModule(toy_example, torch_mode_out)
probs, sigmoid = graph_pth_toy(torch_mode_out, verbose=True)

with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as trt_net:
    builder.max_workspace_size = 1 << 30
    with torch2trt.trt_network(trt_net):  # must use this to enter trt mode
        img = trt_net.add_input(name="image", shape=[3, 299, 299], dtype=trt.float32)
        trt_mode_out = graph_pth(img, verbose=True)  # call graph_pth like a torch module
        trt_mode_out, sigmoid = graph_pth_toy(trt_mode_out)
    trt_mode_out.name = "output_softmax"
    sigmoid.name = "output_sigmoid"
    trt_net.mark_output(tensor=trt_mode_out)
    trt_net.mark_output(tensor=sigmoid)
    engine = builder.build_cuda_engine(trt_net)
    with open("test.engine", "wb") as f:
        f.write(engine.serialize())
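The inference script hardcodes bindings = [int(d_input), int(d_output), int(d_output2)], which assumes the engine's binding indices follow the order image, output_softmax, output_sigmoid. Since this conversion script names its tensors, one option is to look each index up by name instead. A minimal sketch, assuming the older engine.get_binding_index / get_binding_name API (get_binding_index returns -1 for an unknown name); order_bindings is a hypothetical helper, not part of torch2trt or TensorRT.

```python
from typing import Dict, List


def order_bindings(engine, pointers_by_name: Dict[str, int]) -> List[int]:
    """Place each device pointer at the binding index TensorRT assigned to its name."""
    bindings = [None] * engine.num_bindings
    for name, ptr in pointers_by_name.items():
        index = engine.get_binding_index(name)  # -1 if the engine has no such binding
        if index < 0:
            raise KeyError("engine has no binding named {!r}".format(name))
        bindings[index] = int(ptr)
    # Every binding must receive a buffer before execution
    missing = [engine.get_binding_name(i) for i, p in enumerate(bindings) if p is None]
    if missing:
        raise ValueError("no buffer supplied for bindings: {}".format(missing))
    return bindings
```

In the inference script this would be called as order_bindings(engine, {"image": d_input, "output_softmax": d_output, "output_sigmoid": d_output2}), so a different index order inside the engine cannot silently swap the two output buffers.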


traveller59 commented on September 26, 2024

You need to convert the image to np.float32.
Please don't use the raw API; you can try my high-level API in the newest code.
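A small NumPy-only sketch of why the dtype matters in the script above: np.random.rand produces float64 (8 bytes per element), so d_input is allocated for and filled with float64 bytes, while a float32 engine input reads that memory as float32. The reinterpretation below is an illustration of that mismatch, not TensorRT code; the names misread and img32 are hypothetical.

```python
import numpy as np

img = np.random.rand(1, 3, 299, 299)  # dtype float64, as in the original script

# Roughly what a float32 input binding sees when float64 bytes are copied in:
# the first half of the raw buffer reinterpreted as float32 values.
misread = np.frombuffer(img.tobytes(), dtype=np.float32)[: img.size].reshape(img.shape)

# The fix: cast to float32 (and make contiguous) before allocating and copying.
img32 = np.ascontiguousarray(img, dtype=np.float32)
```

In the inference script that means using np.ascontiguousarray(img, dtype=np.float32) in place of np.ascontiguousarray(img), so that img.nbytes, d_input, and the host-to-device copy all match the engine's float32 input.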


oscarriddle commented on September 26, 2024

Hi, I tried your newest inference code and got the output results.
I noticed you compare results by their norm to check consistency.
I also fed exactly the same input into the original PyTorch model, as below:

import time

import numpy as np
import torch
import torchvision

net = torchvision.models.inception_v3(pretrained=True).eval()
img = np.load('input_raw.bin.npy')  # same input that was fed to the TensorRT engine
inputs = torch.from_numpy(img).float()  # make sure the input is float32
model = net.cuda()
inp = inputs.cuda()
with torch.no_grad():
    for i in range(1):
        a1 = time.time()
        out = model(inp)
        torch.cuda.synchronize()  # wait for the GPU so the timing is meaningful
        a2 = time.time()
        print('{}, {}, {}'.format(i, out, a2 - a1))

But I got a different result from the TensorRT engine. Could you comment on how to address this issue?
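One thing worth ruling out before comparing: the engine built by the conversion script outputs softmax and sigmoid of the logits (toy_example), while model(inp) returns raw logits, so the two results are not directly comparable. A minimal NumPy sketch that applies the same post-processing to the PyTorch logits and then measures the relative L2 (norm) error; the function names and the 1e-3 tolerance are assumptions, not values from the thread, and a batch size of 1 is assumed.

```python
import numpy as np


def softmax(x, axis=-1):
    z = x - np.max(x, axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def relative_error(actual, expected):
    """||actual - expected|| / ||expected|| on flattened float64 copies."""
    a = np.asarray(actual, dtype=np.float64).ravel()
    b = np.asarray(expected, dtype=np.float64).ravel()
    return float(np.linalg.norm(a - b) / max(np.linalg.norm(b), 1e-12))


def compare_with_engine(torch_logits, trt_softmax, trt_sigmoid, tol=1e-3):
    """Apply the engine's post-processing to the PyTorch logits, then compare."""
    logits = np.asarray(torch_logits, dtype=np.float64).reshape(-1)  # batch size 1
    errors = {
        "softmax": relative_error(trt_softmax, softmax(logits)),
        "sigmoid": relative_error(trt_sigmoid, sigmoid(logits)),
    }
    return errors, all(err < tol for err in errors.values())
```

Usage would look like errors, ok = compare_with_engine(out.cpu().numpy(), softmax_host, sigmoid_host), where softmax_host and sigmoid_host are the host arrays copied back from the engine's output_softmax and output_sigmoid bindings.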

