Giter VIP home page Giter VIP logo

ai-models-serialization's Introduction

开源AI模型序列化总结

目录

模型序列化简介

模型序列化是模型部署的第一步,如何把训练好的模型存储起来,以供后续的模型预测使用,是模型部署的首先要考虑的问题。本文主要罗列当前流行开源模型不同序列化方法,以供查阅参考,欢迎添加和指正(Github)。

序列化分类

  • 跨平台跨语言通用序列化方法,主要使用三种格式:XML,JSON,和Protobuf,前两种是文本格式,人和机器都可以理解,后一种是二进制格式,只有机器能理解,但在存储传输解析上有很大的速度优势。

    • PMML (Predictive Model Markup Language),基于XML格式。由数据挖掘组织DMG(Data Mining Group)开发和维护,是表示传统机器学习模型的实际标准,具有广泛的应用。详细参考文章《使用PMML部署机器学习模型》
    • ONNX (Open Neural Network Exchange),基于Protobuf二进制格式。初始由微软和Facebook推出,后面得到了各大厂商和框架的支持,已成为表示深度神经网络模型的不二标准,通过onnx-ml也已经可以支持传统非深度神经网络模型。详细参考文章《使用ONNX部署深度学习和传统机器学习模型》
    • PFA (Portable Format for Analytics),基于JSON格式。PFA同样由PMML的领导组织DMG开发,最新标准是2015发布的0.8.1,后续再没有发布新版本。OpenDataGroup公司开发了基于PFA的预测库Hadrian,提供Java/Scala/Python/R等多语言接口。
    • MLeap,基于JSON或者Protobuf格式。开源但非标准,由初创公司Combust开发,刚开始主要提供对Spark Pipelines的支持,目前也可以支持Scikit-learn等模型。Combust同时提供了MLeap Runtime来支持MLeap格式模型,基于Scala开发,实现了一个独立的预测运行引擎,不依赖于Spark或者Scikit-learn等库。
    • Core ML,基于Protobuf二进制格式,由苹果公司开发,主要目标为在移动设备上使用AI模型。
  • 模型本身提供的自定义序列化方法

    • 文本或者二进制格式
    • 语言专有或者跨语言跨平台自定义格式
  • 语言级通用序列化方法

    • Python - pickle

    • Python - joblib

    • R - rda

      joblib在序列化大numpy数组时有性能优势,pickle的c实现cpickle速度也很快。

  • 用户自定义序列化方法

    • 以上方法都无法达到要求,用户可以使用自定义序列化格式,以满足自己的特殊部署需求:部署性能、模型大小、环境要求等等。但这种方法在模型升级维护以及版本兼容性上是一个大的挑战。

    如何选择模型序列化方法,可以参考以下顺序,优先使用跨平台跨语言通用序列化方法,最后再考虑使用自定义序列化方法:

    DaaS-login

    在同一类型格式选项中,可以参考以下筛选流程:

    DaaS-login

Scikit-learn模型序列化方法:

XGBoost模型序列化方法:

LightGBM模型序列化方式:

Spark-ML模型序列化方式

  • Spark-ML内部存储格式,PipelineModel提供saveload方法,输入的是一个路径,而不是文件名,因为要存储到多个不同的文件中。Spark在大数据的分布式处理有很大优势,比如适合批量预测和模型评估,但是对于实时预测来说,太重量级了,效率不高。提供Scala,Java和Python接口,可以跨平台和语言读取。
  • PMML:JPMML-SparkML
  • ONNX:ONNXMLTools,还在实验阶段。
  • PFA:Aardpfark,支持还不完全。
  • MLeap

Keras模型序列化方法

  • Keras内部格式

    1. HDF5:
    # Save the model
    model.save('path_to_my_model.h5')
    
    # Recreate the exact same model purely from the file
    new_model = keras.models.load_model('path_to_my_model.h5')
    1. TensorFlow SavedModel 格式,该格式是TensorFlow对象的独立序列化格式,由TensorFlow serving和TensorFlow(而不是Python)支持。
    # Export the model to a SavedModel
    model.save('path_to_saved_model', save_format='tf')
    
    # Recreate the exact same model
    new_model = keras.models.load_model('path_to_saved_model')
    1. 分别存储模型结构和模型权重值(Weights)。模型结构可以存储为JSON:
    json_string = model.to_json()
    model = keras.models.model_from_json(json_string)

    或者YAML:

    yaml_string = model.to_yaml()
    model = keras.models.model_from_yaml(yaml_string)

    模型权重值可以存储为HDF5格式:

    model.save_weights('path_to_my_weights.h5')

    或者TF格式:

    model.save_weights('path_to_my_weights', save_format='tf')

    因为该方法没有存储模型训练配置参数和优化器(Optimizer),所以如果您需要再继续训练模型,必须重新调用compile()函数来设置。但是如果只是用于模型预测,这种序列化方式已经足够了,完整的例子:

    # Save JSON config to disk
    json_config = model.to_json()
    with open('model_config.json', 'w') as json_file:
        json_file.write(json_config)
    
    # Save weights to disk
    model.save_weights('path_to_my_weights.h5')
    
    # Reload the model from the 2 files we saved
    with open('model_config.json') as json_file:
        json_config = json_file.read()
    new_model = keras.models.model_from_json(json_config)
    new_model.load_weights('path_to_my_weights.h5')
    
    # Make prediction against the restored model.
    new_model.predict(x_test)
  • PMML: Nyoka,导出的是扩展的PMML模型,不属于PMML标准。

  • ONNX:keras2onnx

Pytorch模型序列化方法

  • Pytorch内部格式:只存储已训练模型的状态(包括weights and biases),因为仅仅为了模型预测。

    # Saving & Loading Model for Inference
    torch.save(model.state_dict(), PATH)
    
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
  • ONNX:内部支持torch.onnx.export

MXNet模型序列化方法

  • MXNet内部格式

    1. 只存储模型参数,不包含模型结构,加载时需要建立模型结构。
    # Saving model parameters to file
    net = build_net(gluon.nn.Sequential())
    train_model(net)
    net.save_parameters(file_name)
    
    # Loading model parameters from file
    new_net = build_net(gluon.nn.Sequential())
    new_net.load_parameters(file_name, ctx=ctx)
    1. 存储模型参数和结构到JSON文件中,该格式可以跨平台和语言使用,可以在不同的语言中被加载,比如C,C++或者Scala。
    # Saving model parameters AND architecture to file
    net = build_net(gluon.nn.HybridSequential())
    net.hybridize()
    train_model(net)
    # Two files path-symbol.json and path-xxxx.params will be created, where xxxx is the 4 digits epoch number.
    net.export(path)
    
    # Loading model parameters AND architecture from file
    gluon.nn.SymbolBlock.imports(symbol_file, input_names, param_file=None, ctx=None)
  • ONNX:内部支持mxnet.contrib.onnx.export_model

总结

这并不是一个完整的列表,欢迎大家贡献,标星^_^。

Github地址:https://github.com/aipredict/ai-models-serialization

ai-models-serialization's People

Contributors

aipredict avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.