torch2tflite's Introduction

PyTorch to TensorFlow Lite Converter

Converts a whole PyTorch model into TensorFlow Lite

PyTorch -> ONNX -> TensorFlow 2 -> TFLite

Install the package first:

python3 setup.py install

Args

  • --torch-path Path to the local PyTorch model. Save the whole model, e.g. torch.save(model, PATH)
  • --tflite-path Save path for the TensorFlow Lite model
  • --target-shape Model input shape used to create the static graph (default: (224, 224, 3))
  • --sample-file Path to a sample image file. If the model is not a computer-vision model, leave this empty and pass only --target-shape
  • --seed Seeds the RNG that produces random input data when --sample-file does not exist
  • --log=INFO Shows what happens behind the scenes
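
A minimal sketch of how these flags could be parsed with argparse. This is not the project's actual parser: the helper names build_parser and parse_cli, and the defaults chosen for --seed and --log, are assumptions for illustration. The output flag is spelled --tflite-path, as in the usage commands.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags."""
    parser = argparse.ArgumentParser(prog="torch2tflite.converter")
    parser.add_argument("--torch-path", required=True,
                        help="Path to a whole model saved with torch.save(model, PATH)")
    parser.add_argument("--tflite-path", required=True,
                        help="Where to write the TensorFlow Lite model")
    # Three integers, e.g. "--target-shape 224 224 3".
    parser.add_argument("--target-shape", type=int, nargs=3, default=[224, 224, 3],
                        help="Static input shape")
    parser.add_argument("--sample-file", default=None,
                        help="Optional sample image; omit for non-vision models")
    # Defaults below are assumptions, not documented values.
    parser.add_argument("--seed", type=int, default=None,
                        help="Seed for random input when no sample file is given")
    parser.add_argument("--log", default="WARNING", help="Logging level, e.g. INFO")
    return parser


def parse_cli(argv):
    """Parse argv and expose the shape as a flat tuple of ints."""
    args = build_parser().parse_args(argv)
    args.target_shape = tuple(args.target_shape)
    return args
```

Converting the shape to a flat tuple of integers in one place keeps later code (random input generation, resizing) from having to guess whether it received a list, a tuple, or something nested.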

Basic usage of the script

To test with sample file:

python3 -m torch2tflite.converter \
    --torch-path tests/mobilenetv2_model.pt \
    --tflite-path mobilenetv2.tflite \
    --sample-file sample_image.png \
    --target-shape 224 224 3

To test with random input to check gradients:

python3 -m torch2tflite.converter \
    --torch-path tests/mobilenetv2_model.pt \
    --tflite-path mobilenetv2.tflite \
    --target-shape 224 224 3 \
    --seed 10

torch2tflite's People

Contributors

omerferhatt

torch2tflite's Issues

No module named 'models'

I tried this code to convert my custom trained YOLOv5 model (.pt)
It returns this message

2021-05-31 10:26:03.181164: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
No module named 'models'

How can I fix this?
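
One plausible cause, offered as an assumption rather than a confirmed diagnosis: torch.save(model, PATH) pickles the whole model object, and pickle records classes by their module path. Loading then needs the module that defines the model class (here `models`, which matches a package in the YOLOv5 repository) to be importable, for example by running from the repository root or adding it to sys.path. The sketch below reproduces the same failure mode with the standard pickle module only; `demo_models` and `Net` are hypothetical stand-ins.

```python
import pickle
import sys
import types

# Build a throwaway module named "demo_models" that holds one class.
# Both names are hypothetical stand-ins for a model-definition package.
demo_module = types.ModuleType("demo_models")


class Net:
    def __init__(self):
        self.weights = [0.1, 0.2]


# Make the class look as if it had been defined inside demo_models.
Net.__module__ = "demo_models"
Net.__qualname__ = "Net"
demo_module.Net = Net
sys.modules["demo_models"] = demo_module

# Pickling stores a reference "demo_models.Net", not the class source.
blob = pickle.dumps(Net())

# Simulate loading in a process where the defining module is not importable.
del sys.modules["demo_models"]
try:
    pickle.loads(blob)
    load_error = None
except ModuleNotFoundError as exc:
    load_error = str(exc)

# Once the module is importable again, the same bytes load fine.
sys.modules["demo_models"] = demo_module
restored = pickle.loads(blob)
```

Here `load_error` holds a "No module named ..." message of the same kind as the one reported above, and `restored` is a normal `Net` instance.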

OSError: SavedModel file does not exist at: ./converter/tf_model/{saved_model.pbtxt|saved_model.pb}

Hi,

Thank you very much for sharing this code. I have tried applying it to my torch model, the conversion steps from torch to ONNX and ONNX to TF work fine, but it fails at the last step (TF to TFlite) with the following message:

Traceback (most recent call last):
  File "converter.py", line 58, in <module>
    main()
  File "converter.py", line 22, in main
    convert(torch_model_path=args.torch_model_path,
  File "/home/A49175/torch2tflite-master/converter/torch_to_tflite.py", line 157, in convert
    tf_to_tf_lite(tf_path=TF_PATH, tf_lite_path=tf_lite_model_path)
  File "/home/A49175/torch2tflite-master/converter/torch_to_tflite.py", line 74, in tf_to_tf_lite
    converter = tf.lite.TFLiteConverter.from_saved_model(tf_path) # Path to the SavedModel directory
  File "/home/A49175/.conda/envs/tflite-converter/lib/python3.8/site-packages/tensorflow/lite/python/lite.py", line 399, in from_saved_model
    saved_model = _load(saved_model_dir, tags)
  File "/home/A49175/.conda/envs/tflite-converter/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 578, in load
    return load_internal(export_dir, tags)
  File "/home/A49175/.conda/envs/tflite-converter/lib/python3.8/site-packages/tensorflow/python/saved_model/load.py", line 588, in load_internal
    loader_impl.parse_saved_model_with_debug_info(export_dir))
  File "/home/A49175/.conda/envs/tflite-converter/lib/python3.8/site-packages/tensorflow/python/saved_model/loader_impl.py", line 56, in parse_saved_model_with_debug_info
    saved_model = _parse_saved_model(export_dir)
  File "/home/A49175/.conda/envs/tflite-converter/lib/python3.8/site-packages/tensorflow/python/saved_model/loader_impl.py", line 110, in parse_saved_model
    raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
OSError: SavedModel file does not exist at: ./converter/tf_model/{saved_model.pbtxt|saved_model.pb}

Do you have any idea what causes this problem? I am using the package versions specified in requirements.txt.

Thank you!
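
The error text says the loader found neither saved_model.pb nor saved_model.pbtxt under ./converter/tf_model. Because that path is relative, it resolves against the current working directory, so the result depends on where the script is launched from. A small check like the following (a sketch; the helper name saved_model_present is hypothetical) can confirm whether the export directory actually contains a SavedModel before the TFLite step runs.

```python
from pathlib import Path


def saved_model_present(export_dir):
    """Return True if export_dir holds saved_model.pb or saved_model.pbtxt."""
    export_dir = Path(export_dir)
    candidates = ("saved_model.pb", "saved_model.pbtxt")
    return any((export_dir / name).is_file() for name in candidates)


if __name__ == "__main__":
    target = Path("./converter/tf_model")
    # Print the absolute path so a wrong working directory is easy to spot.
    print(target.resolve(), "->", saved_model_present(target))
```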

Issue in the model converting

Hi, I got the error below when the code reaches torch.onnx.export. I suspect the model isn't loaded completely, but I'm not sure. Could you provide a sample model so I can verify my environment? Thanks!
'dict' object has no attribute 'training'

TypeError: 'tuple' object cannot be interpreted as an integer

I created the Colab notebook below and am getting the following "TypeError: 'tuple' object cannot be interpreted as an integer" error.
Could you please help me with this?
https://colab.research.google.com/drive/19JeOkrxlGP6KtkfGbrkO87COc4pvSZSE?usp=sharing#scrollTo=vdeytDxjFdZ3

Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/content/torch2tflite/torch2tflite/converter.py", line 187, in <module>
    args.seed
  File "/content/torch2tflite/torch2tflite/converter.py", line 40, in __init__
    self.sample_data = self.load_sample_input(sample_file_path, target_shape, seed, normalize)
  File "/content/torch2tflite/torch2tflite/converter.py", line 121, in load_sample_input
    data = np.random.random(target_shape).astype(np.float32)
  File "mtrand.pyx", line 434, in numpy.random.mtrand.RandomState.random
  File "mtrand.pyx", line 425, in numpy.random.mtrand.RandomState.random_sample
  File "_common.pyx", line 291, in numpy.random._common.double_fill
TypeError: 'tuple' object cannot be interpreted as an integer
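
The last frames suggest that target_shape reached np.random.random with a tuple nested inside it (for example ((224, 224, 3),)), since NumPy expects every dimension to be an integer. How the value became nested is not visible in the traceback. As a sketch of a defensive fix, a helper such as the hypothetical flatten_shape below turns any int, flat sequence, or nested sequence into a flat tuple of ints before it is used as a shape.

```python
def flatten_shape(shape):
    """Return `shape` as a flat tuple of ints.

    Accepts a single int, a flat sequence, or arbitrarily nested
    lists/tuples such as ((224, 224, 3),).
    """
    if isinstance(shape, int):
        return (shape,)
    flat = []
    for dim in shape:
        if isinstance(dim, (list, tuple)):
            # Recurse into nested containers and splice their dims in order.
            flat.extend(flatten_shape(dim))
        else:
            flat.append(int(dim))
    return tuple(flat)
```

With this in place, flatten_shape(((224, 224, 3),)) and flatten_shape([224, 224, 3]) both give a shape that NumPy accepts.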

ERROR:root:Can not load PyTorch model. Please make surethat model saved like `torch.save(model, PATH)`

I'm trying to convert a YOLOv5 best.pt weights file to a .tflite file so we can deploy the model in a Flutter app.

This is the code:

`import torch
weights_path = '/content/drive/MyDrive/Weights/best.pt'
yolo_path = '/content/yolov5'

model = torch.hub.load(yolo_path, 'custom', weights_path, source='local') # local repo
torch.save(model, '/content/pipeline.pt' )`

You can also go straight to the Colab notebook for all the code.
https://colab.research.google.com/drive/19JeOkrxlGP6KtkfGbrkO87COc4pvSZSE?usp=sharing

I assume I'm not saving the model the way the converter expects, but I can't figure it out.
Can you please explain exactly how the file needs to be saved so the code will work?

Thanks a lot in advance!

error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

I tried this code to convert my custom-trained YOLOv5 model (.pt).
It fails with this error:
cv2.error: OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

2021-06-02 03:17:26.651552: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
Showing result!

Traceback (most recent call last):
  File "converter.py", line 59, in <module>
    main()
  File "converter.py", line 30, in main
    original_image, tf_lite_image, torch_image = get_example_input(args.test_im_path)
  File "/content/drive/My Drive/Univoice/yolov5/torch2tflite/converter/torch_to_tflite.py", line 28, in get_example_input
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
cv2.error: OpenCV(4.1.2) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'

How can I fix this?
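
The assertion !_src.empty() means cvtColor received an empty image. If the image was loaded with cv2.imread (an assumption, since that call is not shown in the traceback), a missing or unreadable file yields None instead of raising, and the failure only surfaces at cvtColor. Checking the sample path up front gives a clearer message. The helper name require_sample_file below is hypothetical.

```python
from pathlib import Path


def require_sample_file(path):
    """Return the path as a Path if it is a non-empty regular file.

    Raises FileNotFoundError or ValueError with the resolved path, so a
    wrong working directory or a mistyped name is easy to spot.
    """
    sample = Path(path).expanduser()
    if not sample.is_file():
        raise FileNotFoundError(f"Sample image not found: {sample.resolve()}")
    if sample.stat().st_size == 0:
        raise ValueError(f"Sample image is empty: {sample.resolve()}")
    return sample
```

Calling this on the sample image path before the image is decoded turns a confusing OpenCV assertion into an explicit "file not found" error.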

Fails to create .tflite

ValueError: Could not open './converter/tf_lite_model.tflite'.

I've loaded my .pt file and a sample test image, but it looks like the conversion is not working.
Can you rectify this?

Thanks

Runtime error: ImportError: dlopen

When I ran:
python -m torch2tflite.converter --torch-path converted.pt --tflite-path siggraph17.tflite --targeet-shape 225 225 3 --seed 10

I have the following error:

  File "/usr/local/Cellar/python@3.8/3.8.12_1/Frameworks/Python.framework/Versions/3.8/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/local/Cellar/python@3.8/3.8.12_1/Frameworks/Python.framework/Versions/3.8/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/XXX/torch2tflite/torch2tflite/converter.py", line 12, in <module>
    import onnx
  File "/XXX/torch2tflite/venv38/lib/python3.8/site-packages/onnx-1.11.0-py3.8-macosx-12-x86_64.egg/onnx/__init__.py", line 10, in <module>
    from .onnx_cpp2py_export import ONNX_ML
ImportError: dlopen(/XXX/torch2tflite/venv38/lib/python3.8/site-packages/onnx-1.11.0-py3.8-macosx-12-x86_64.egg/onnx/onnx_cpp2py_export.cpython-38-darwin.so, 0x0002): symbol not found in flat namespace '__ZN6google8protobuf11MessageLite20ParseFromCodedStreamEPNS0_2io16CodedInputStreamE'


Any suggestions or solutions?

Installation issue

Dear authors, could you please verify the dependencies? It seems tflite-runtime~=2.5 is no longer available.

No local packages or working download links found for tflite-runtime~=2.5
error: Could not find suitable distribution for Requirement.parse('tflite-runtime~=2.5')

Thanks
Maulik
