
batched-fn's Introduction

batched-fn

Rust server plugin for deploying deep learning models with batched prediction.



Deep learning models are usually implemented to make efficient use of a GPU by batching inputs together in "mini-batches". However, applications serving these models often receive requests one-by-one. So using a conventional single or multi-threaded server approach will under-utilize the GPU and lead to latency that increases linearly with the volume of requests.

batched-fn is a drop-in solution for deep learning webservers that queues individual requests and provides them as a batch to your model. It can be added to any application with minimal refactoring simply by inserting the batched_fn macro into the function that runs requests through the model.

Features

  • 🚀 Easy to use: drop the batched_fn! macro into existing code.
  • 🔥 Lightweight and fast: queue system implemented on top of the blazingly fast flume crate.
  • 🙌 Easy to tune: simply adjust max_delay and max_batch_size.
  • 🛑 Back pressure mechanism included: just set channel_cap and handle Error::Full by returning a 503 from your webserver.
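To illustrate the back-pressure idea in plain Rust, the sketch below uses std's bounded `sync_channel` as a stand-in for the request queue whose size `channel_cap` limits. This is not batched-fn code: `submit` is a hypothetical front-end function, and the 202/500 status choices are assumptions. When the queue is full, `try_send` fails immediately instead of blocking, and the caller turns that into a 503.

```rust
use std::sync::mpsc::{sync_channel, SyncSender, TrySendError};

/// Hypothetical front-end: try to enqueue a request without blocking and
/// translate the outcome into an HTTP status code.
fn submit(queue: &SyncSender<u32>, request: u32) -> u16 {
    match queue.try_send(request) {
        Ok(()) => 202,                             // accepted for batching
        Err(TrySendError::Full(_)) => 503,         // queue at capacity: shed load
        Err(TrySendError::Disconnected(_)) => 500, // worker thread is gone
    }
}

fn main() {
    // Capacity 2 plays the role of `channel_cap: Some(2)`; nothing is
    // receiving yet, so the third request finds the queue full.
    let (tx, rx) = sync_channel::<u32>(2);
    let statuses: Vec<u16> = (0..3).map(|i| submit(&tx, i)).collect();
    println!("{:?}", statuses); // [202, 202, 503]

    drop(rx); // simulate the worker going away
    println!("{}", submit(&tx, 99)); // 500
}
```

Rejecting early keeps queueing latency bounded: a client that gets a 503 can retry or back off instead of waiting behind an ever-growing queue.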

Examples

Suppose you have a model API that looks like this:

// `Batch` could be anything that implements the `batched_fn::Batch` trait.
type Batch<T> = Vec<T>;

#[derive(Debug)]
struct Input {
    // ...
}

#[derive(Debug)]
struct Output {
    // ...
}

struct Model {
    // ...
}

impl Model {
    fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
        // ...
        todo!()
    }

    fn load() -> Self {
        // ...
        todo!()
    }
}

Without batched-fn, a webserver route would need to call Model::predict on each individual input, creating a bottleneck because the GPU is under-utilized:

use once_cell::sync::Lazy;
static MODEL: Lazy<Model> = Lazy::new(Model::load);

fn predict_for_http_request(input: Input) -> Output {
    let mut batched_input = Batch::with_capacity(1);
    batched_input.push(input);
    MODEL.predict(batched_input).pop().unwrap()
}

But by dropping the batched_fn macro into your code you automatically get batched inference behind the scenes without changing the one-to-one relationship between inputs and outputs:

use batched_fn::batched_fn;

async fn predict_for_http_request(input: Input) -> Output {
    let batch_predict = batched_fn! {
        handler = |batch: Batch<Input>, model: &Model| -> Batch<Output> {
            model.predict(batch)
        };
        config = {
            max_batch_size: 16,
            max_delay: 50,
        };
        context = {
            model: Model::load(),
        };
    };
    batch_predict(input).await.unwrap()
}

❗️ Note that the predict_for_http_request function now has to be async.

Here we set max_batch_size to 16 and max_delay to 50 milliseconds. This means the batched function will wait at most 50 milliseconds after receiving a single input to fill a batch of 16. If 15 more inputs are not received within those 50 milliseconds, the partial batch will be run as-is.
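The batching rule just described can be modeled with std channels. The sketch below is a simplified illustration, not batched-fn's implementation (which, per the feature list, is built on flume); `next_batch` is a hypothetical name. It blocks for the first input, then keeps collecting until either `max_batch_size` inputs are gathered or `max_delay` has passed since that first input.

```rust
use std::sync::mpsc::{channel, Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Block until one input arrives, then gather more until the batch is full or
/// `max_delay` has elapsed since the first input. Returns `None` once all
/// senders are gone and the queue is empty.
fn next_batch<T>(rx: &Receiver<T>, max_batch_size: usize, max_delay: Duration) -> Option<Vec<T>> {
    let first = rx.recv().ok()?; // wait indefinitely for the first input
    let deadline = Instant::now() + max_delay;
    let mut batch = Vec::with_capacity(max_batch_size);
    batch.push(first);
    while batch.len() < max_batch_size {
        // Only wait for whatever is left of the delay window.
        let remaining = deadline.saturating_duration_since(Instant::now());
        match rx.recv_timeout(remaining) {
            Ok(item) => batch.push(item),
            // Window closed or no more senders: run the partial batch as-is.
            Err(RecvTimeoutError::Timeout) | Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    Some(batch)
}

fn main() {
    let (tx, rx) = channel();
    for i in 0..20 {
        tx.send(i).unwrap();
    }
    drop(tx);
    let delay = Duration::from_millis(50);
    let first = next_batch(&rx, 16, delay).unwrap(); // full batch
    let second = next_batch(&rx, 16, delay).unwrap(); // the 4 leftovers
    println!("{} {}", first.len(), second.len()); // 16 4
}
```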

Tuning max batch size and max delay

The optimal batch size and delay depend on the specifics of your use case, such as how large a batch fits in memory (typically on the order of 8, 16, 32, or 64 for a deep learning model) and how much added latency you can afford. In general, set max_batch_size as high as you can, assuming the total processing time for N examples is minimized with a batch size of N, and keep max_delay small relative to the time your handler function takes to process a batch.
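The rule of thumb above can be made concrete as a small selection step over measured timings: among batch sizes that fit in memory, pick the one with the lowest processing time per example. This is a hedged sketch; `best_batch_size` is a hypothetical helper, and the timings it consumes would come from your own measurements of the handler. The values in `main` are placeholders, not benchmarks.

```rust
use std::time::Duration;

/// Given (batch_size, time to process one batch of that size) pairs, return the
/// size with the lowest time per example, ignoring sizes above `memory_limit`.
fn best_batch_size(timings: &[(usize, Duration)], memory_limit: usize) -> Option<usize> {
    timings
        .iter()
        .filter(|(size, _)| *size > 0 && *size <= memory_limit)
        .min_by(|(a_size, a_time), (b_size, b_time)| {
            let a = a_time.as_secs_f64() / *a_size as f64;
            let b = b_time.as_secs_f64() / *b_size as f64;
            a.partial_cmp(&b).unwrap() // no NaN: sizes are > 0
        })
        .map(|(size, _)| *size)
}

fn main() {
    // Placeholder timings for illustration only; replace with real measurements.
    let timings = [
        (8, Duration::from_millis(20)),
        (16, Duration::from_millis(30)),
        (32, Duration::from_millis(70)),
    ];
    println!("{:?}", best_batch_size(&timings, 32)); // Some(16)
}
```

If per-example time keeps falling as the batch grows (the assumption stated above), this simply picks the largest size that fits, which matches the advice to set max_batch_size as high as you can.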

Implementation details

When the batched_fn macro is invoked, it spawns a new thread in which the handler runs. Within that thread, every object specified in the context is initialized and then passed by reference to the handler each time the handler runs.

The object returned by the macro is just a closure that sends a single input and a callback through an asynchronous channel to the handler thread. When the handler finishes running a batch, it invokes the callback for each input with the corresponding output, which wakes the closure up so it can return that output.
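A synchronous std-only analogue of that plumbing is sketched below. It is an illustration under assumptions, not batched-fn's code: a per-request `Sender` plays the role of the callback, the caller blocks on `recv` where the real closure would `.await`, and `spawn_batched` and `handler` are hypothetical names. For brevity it omits max_delay and only batches inputs that are already waiting in the queue.

```rust
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

/// One queued request: the input plus a one-shot "callback" channel for its output.
type Job = (String, Sender<usize>);

/// Hypothetical handler that processes a whole batch at once (here: string lengths).
fn handler(batch: Vec<String>) -> Vec<usize> {
    batch.iter().map(|s| s.len()).collect()
}

/// Spawn the worker thread and return a closure that submits one input and
/// blocks until its output comes back.
fn spawn_batched(max_batch_size: usize) -> impl Fn(String) -> usize {
    let (tx, rx): (Sender<Job>, Receiver<Job>) = channel();
    thread::spawn(move || {
        // The loop ends once every sender (i.e. the returned closure) is dropped.
        while let Ok(first) = rx.recv() {
            let mut jobs = vec![first];
            // Take whatever else is already queued, up to the batch limit.
            while jobs.len() < max_batch_size {
                match rx.try_recv() {
                    Ok(job) => jobs.push(job),
                    Err(_) => break,
                }
            }
            let (inputs, replies): (Vec<String>, Vec<Sender<usize>>) = jobs.into_iter().unzip();
            // Invoke each input's callback with its corresponding output.
            for (output, reply) in handler(inputs).into_iter().zip(replies) {
                let _ = reply.send(output); // caller may have gone away
            }
        }
    });
    move |input: String| {
        let (reply_tx, reply_rx) = channel();
        tx.send((input, reply_tx)).expect("worker thread stopped");
        reply_rx.recv().expect("worker dropped the reply channel")
    }
}

fn main() {
    let predict = spawn_batched(16);
    println!("{}", predict("hello".to_string())); // 5
}
```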

batched-fn's Issues

Loading models in the runtime

Is there any way to make the context of batched_fn! non-static? I want to load the models after the server has been initialized from a configuration that lists the models, so I'm trying to do something like this:

let batched_generate = batched_fn! {
            handler = |batch: Vec<Vec<String>>, model: &HashMap<String, SentenceEmbeddingsModel>| -> Vec<Result<Vec<Vec<f32>>, RustBertError>> {
                let mut batched_result = Vec::with_capacity(batch.len());
                for input in batch {
                    let result = model.encode(&input);
                    batched_result.push(result);
                }
                batched_result
            };
            config = {
                max_batch_size: 1,
                max_delay: 100,
                channel_cap: Some(20),
            };
            context = {
                models: &self.loaded_models,
            };
        };

Where self.loaded_models is created in the struct's constructor, but it looks like context needs to be static. Any thoughts on how to accomplish this?
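The underlying constraint is general Rust rather than anything specific to batched-fn: as described under Implementation details, the macro spawns a new thread and initializes the context there, and a closure passed to `std::thread::spawn` must be `'static`, so it cannot hold a borrow such as `&self.loaded_models`. A common pattern is to share ownership through `Arc` and move a clone into the thread. The sketch below shows that pattern with std threads only; `Server`, `Model`, and `spawn_worker` are hypothetical, and whether batched_fn!'s context block accepts an `Arc` supplied this way has not been verified here.

```rust
use std::collections::HashMap;
use std::sync::Arc;
use std::thread::{self, JoinHandle};

/// Hypothetical stand-in for a loaded embedding model.
struct Model {
    dim: usize,
}

impl Model {
    fn encode(&self, input: &str) -> Vec<f32> {
        vec![input.len() as f32; self.dim]
    }
}

struct Server {
    // Arc lets the models be shared with a 'static worker thread
    // instead of being borrowed from `self`.
    loaded_models: Arc<HashMap<String, Model>>,
}

impl Server {
    fn new(names: &[&str]) -> Self {
        let models: HashMap<String, Model> =
            names.iter().map(|n| (n.to_string(), Model { dim: 2 })).collect();
        Server { loaded_models: Arc::new(models) }
    }

    /// The worker owns its own Arc clone; `&self.loaded_models` could not be
    /// moved into `thread::spawn` because the closure must be 'static.
    fn spawn_worker(&self, model_name: &str, batch: Vec<String>) -> JoinHandle<Option<Vec<Vec<f32>>>> {
        let models = Arc::clone(&self.loaded_models);
        let name = model_name.to_string();
        thread::spawn(move || -> Option<Vec<Vec<f32>>> {
            let model = models.get(&name)?; // None if the model isn't loaded
            Some(batch.iter().map(|s| model.encode(s)).collect())
        })
    }
}

fn main() {
    let server = Server::new(&["model-a"]);
    let handle = server.spawn_worker("model-a", vec!["hello".to_string()]);
    println!("{:?}", handle.join().unwrap()); // Some([[5.0, 5.0]])
}
```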

[Question] How to make `max_batch_size` configurable?

Hi,

I'm trying to avoid writing

        config = {
            max_batch_size: 32,
        };

and use

        config = {
            max_batch_size: max_batch_size,
        };

instead, so that I can adjust the batch size dynamically or read the config from some variable/setting. But I got errors like

error[E0435]: attempt to use a non-constant value in a constant
  --> src/http/predict.rs:76:61
   |
70 |       let handler = batched_fn::batched_fn! {
   |  ___________________-
71 | |         handler = |batch: Vec<Input>, model: &SentimentModel| -> Vec<Output> {
72 | |             let predictions = model.predict(&batch.iter().map(String::as_str).collect::<Vec<&str>>().as_slice());
73 | |             predictions.iter().map(|x| Response { score: x.score }).collect()
...  |
76 | |             max_batch_size: if Cuda::cudnn_is_available() { batch_size } else { 1 },
   | |                                                             ^^^^^^^^^^ non-constant value
...  |
85 | |         };
86 | |     };
   | |_____- help: consider using `let` instead of `static`: `let BATCHED_FN`

For more information about this error, try `rustc --explain E0435`.
error: could not compile `axum-sst` due to previous error

It also seems that the batch size cannot be adjusted once the thread is started. Even so, how can I pass max_batch_size as a variable when initializing? Must the passed max_batch_size be static or const? For example:

struct Config {
    max_batch_size: usize,
}

static PREDICT_CONFIG: Config = Config { max_batch_size: 32 }; // or `Config.from_file`...

...
        config = {
            max_batch_size: if Cuda::cudnn_is_available() { PREDICT_CONFIG.max_batch_size } else { 1 },
        };

Thanks.

[Question] How to benchmark the execution?

Thanks for the great package!

When tuning max_batch_size and max_delay, how can I benchmark or record the GPU execution time for a batch? With my limited knowledge of Rust, I don't know where to put the timing code. It might be easy for a single request, but I have no idea how to do it for a batch.

Thanks.
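One possible approach: the handler closure in batched_fn! receives the whole batch, so timing the model call inside it yields per-batch numbers that can be logged next to `batch.len()`. The sketch below wraps a handler with `std::time::Instant`; `timed_batch` is a hypothetical helper and the doubling closure stands in for `model.predict(batch)`. With a GPU, kernels are often launched asynchronously, so wall-clock time only reflects GPU work if the call waits for its results (for example, by copying outputs back to the CPU).

```rust
use std::time::{Duration, Instant};

/// Run `handler` on `batch` and report how long the whole batch took.
fn timed_batch<I, O>(batch: Vec<I>, handler: impl FnOnce(Vec<I>) -> Vec<O>) -> (Vec<O>, Duration) {
    let start = Instant::now();
    let outputs = handler(batch);
    (outputs, start.elapsed())
}

fn main() {
    let batch = vec![1.0_f32, 2.0, 3.0];
    let n = batch.len();
    // Stand-in for `model.predict(batch)` inside the batched_fn! handler.
    let (outputs, elapsed) =
        timed_batch(batch, |b| b.into_iter().map(|x| x * 2.0).collect::<Vec<_>>());
    println!(
        "batch of {} took {:?} ({:?} per example), outputs {:?}",
        n,
        elapsed,
        elapsed / n as u32,
        outputs
    );
}
```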

How to use in an actix server?

Hi, thanks for this library!

I'm trying to use the library to write a demo using rust-bert and actix. When I tried to put the model in batched_fn!, I got an error like:


error[E0277]: `*mut torch_sys::C_tensor` cannot be shared between threads safely
   --> src/routes.rs:28:25
    |
28  |       let batch_predict = batched_fn! {
    |  _________________________^
29  | |         handler = |batch: Vec<(Tensor, Tensor, Tensor, Tensor, Tensor)>, model: &PredictModel| -> Vec<String> {
30  | |             let output = model.predict(batch.clone());
31  | |             println!("Processed batch {:?} -> {:?}", batch, output);
...   |
40  | |         };
41  | |     };
    | |_____^ `*mut torch_sys::C_tensor` cannot be shared between threads safely
    |
    = help: within `(tch::Tensor, tch::Tensor, tch::Tensor, tch::Tensor, tch::Tensor)`, the trait `Sync` is not implemented for `*mut torch_sys::C_tensor`
    = note: required because it appears within the type `tch::Tensor`
    = note: required because it appears within the type `(tch::Tensor, tch::Tensor, tch::Tensor, tch::Tensor, tch::Tensor)`
note: required by a bound in `BatchedFn`
   --> /Users/user/.asdf/installs/rust/1.59.0/registry/src/github.com-1ecc6299db9ec823/batched-fn-0.2.2/src/lib.rs:227:25
    |
227 |     T: 'static + Send + Sync + std::fmt::Debug,
    |                         ^^^^ required by this bound in `BatchedFn`
    = note: this error originates in the macro `$crate::__batched_fn_internal` (in Nightly builds, run with -Z macro-backtrace for more info)

Does the model need to be Sync and Send? I know there's a rust-dl-webserver project, but I don't quite understand the differences in mechanism between actix and warp, as I'm pretty new to Rust. Could you provide a simple actix example or help me understand how to use batched_fn!? For example, how do context, config, and handler work? Are they all required? Is the code inside context run only once for initialization (like loading the model)? Should one put fields other than the model in context?

Many thanks.
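The error output above shows the bound batched-fn 0.2.2 places on the input type: `T: 'static + Send + Sync + std::fmt::Debug`. Per the same output, `tch::Tensor` contains a raw `*mut torch_sys::C_tensor`, which is not `Sync`, so a tuple of tensors fails that bound. One common workaround is to pass plain host data (for example token ids as `Vec<i64>`) through the batched function and build tensors inside the handler. The sketch below checks that bound at compile time using only std; `EncodedInput` and `assert_batch_input` are hypothetical names.

```rust
use std::fmt::Debug;

/// Compile-time check with the same bound as in the error output:
/// `T: 'static + Send + Sync + std::fmt::Debug`.
fn assert_batch_input<T: 'static + Send + Sync + Debug>() {}

/// Hypothetical plain-data input: tokenized text kept as host vectors rather
/// than `tch::Tensor`s, so it is automatically `Send + Sync`.
#[derive(Debug, Clone)]
struct EncodedInput {
    input_ids: Vec<i64>,
    attention_mask: Vec<i64>,
}

impl EncodedInput {
    fn seq_len(&self) -> usize {
        self.input_ids.len()
    }
}

fn main() {
    assert_batch_input::<EncodedInput>();
    // assert_batch_input::<*mut u8>(); // would not compile: raw pointers are neither Send nor Sync
    let input = EncodedInput { input_ids: vec![101, 2023, 102], attention_mask: vec![1, 1, 1] };
    println!("{} tokens", input.seq_len()); // 3 tokens
}
```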

Channel is disconnected if requests are canceled

I'm having an issue with gRPC: if a request is cancelled prematurely, it kills the batched_fn, and the only fix is restarting the service completely. Is there a way to re-initialize the thread if it disconnects at runtime?

I'm fairly certain this is similar to, if not the same as, the issue reported in another repo:
epwalsh/rust-dl-webserver#60

Steps to reproduce with gRPC:

  1. Send a gRPC request to an endpoint that relies on batched_fn!(...).
  2. Cancel the request before it completes.
  3. Retry the request (it should now fail).
    Note: the timing doesn't matter. Whether the request is cancelled during or after batched_fn! initialization, the batched_fn stays broken for the rest of the program's lifetime.

// Uses rust_bert for SentenceEmbeddingsModel
async fn encode_setence(input: String) -> Result<Vec<f32>, batched_fn::Error> {
    let batch_encode = batched_fn! {
        handler = |batch: Vec<String>, model: &SentenceEmbeddingsModel| -> Vec<Vec<f32>> {
            let span = info_span!("batch_handler");
            let _enter = span.enter();
            debug!("{:?}", batch);
            model.encode(&batch).unwrap()
        };
        config = {
            max_batch_size: 16,
            max_delay: 100,
            channel_cap: Some(20),
        };
        context = {
            model: {
                let span = info_span!("batch_context");
                let _context_enter = span.enter();
                info!("Initializing Model...");
                let span = info_span!("model_load");
                let _load_enter = span.enter();
                info!("Cuda: {}", Cuda::cudnn_is_available());
                info!("Model: {}", LOCAL_MODEL);
                SentenceEmbeddingsBuilder::local(LOCAL_MODEL)
                    .with_device(tch::Device::cuda_if_available())
                    .create_model()
                    .context("Failed to initialize embedding model").unwrap()
            },
        };
    };

    batch_encode(input).await
}

#[tonic::async_trait]
impl Encoder for MyEncoder {
    #[instrument(skip_all)]
    async fn encode_sentence(
        &self,
        request: Request<SentenceRequest>,
    ) -> Result<Response<EmbeddingReply>, Status> {
        info!("EncodeSentence request received...");

        let data = encode_setence(request.into_inner().sentence)
            .await
            .map_err(|err| {
                warn!("{:?}", err);
                Status::internal("batch encoder broken")
            })?;

        let reply = EmbeddingReply { data };

        Ok(Response::new(reply))
    }
}

2024-03-05T17:13:50.133775Z  INFO example: Initializing Tonic...
2024-03-05T17:13:50.134101Z  INFO tonic_startup: example: address: [::1]:50051
2024-03-05T17:13:50.134206Z  INFO tonic_startup: example: Initializing Tonic Reflection...
2024-03-05T17:13:50.135073Z  INFO example: Tonic Initialized...
2024-03-05T17:13:56.507081Z  INFO encode_sentence: example: EncodeSentence request received...
2024-03-05T17:13:56.507693Z  INFO batch_context: example: Initializing Model...
2024-03-05T17:13:56.524687Z  INFO batch_context:model_load: example: Cuda: true
2024-03-05T17:14:00.489850Z  INFO encode_sentence: example: EncodeSentence request received...
2024-03-05T17:14:01.497046Z  INFO encode_sentence: example: EncodeSentence request received...
2024-03-05T17:14:02.012228Z  INFO encode_sentence: example: EncodeSentence request received...
thread '<unnamed>' panicked at example\src\main.rs:85:24:
Channel from calling thread disconnected: "SendError(..)"
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
2024-03-05T17:14:04.710488Z  INFO encode_sentence: example: EncodeSentence request received...
2024-03-05T17:14:04.710735Z  WARN encode_sentence: example: Disconnected
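The panic message suggests (an inference from the log, not a confirmed reading of batched-fn's source) that the handler thread panicked while sending a result back to a caller whose future had already been dropped by the cancellation; once that thread is gone, later calls report `Disconnected`, as in the last log line. The std-only sketch below shows the general fix for this pattern: treat a failed reply send as a cancelled request rather than a fatal error. `spawn_worker` and `Job` are hypothetical names.

```rust
use std::sync::mpsc::{channel, Sender};
use std::thread::{self, JoinHandle};

/// A job is an input plus a one-shot reply channel (the per-request "callback").
type Job = (u32, Sender<u32>);

/// Spawn a worker that answers every job and, once the queue closes, returns
/// how many replies were actually delivered.
fn spawn_worker() -> (Sender<Job>, JoinHandle<usize>) {
    let (job_tx, job_rx) = channel::<Job>();
    let worker = thread::spawn(move || {
        let mut delivered: usize = 0;
        for (input, reply) in job_rx {
            // `send` returns Err if the caller dropped its receiver, e.g. because
            // the request was cancelled. Skipping that case (instead of calling
            // `.expect(...)`) keeps this thread alive for later requests.
            if reply.send(input * 10).is_ok() {
                delivered += 1;
            }
        }
        delivered
    });
    (job_tx, worker)
}

fn main() {
    let (jobs, worker) = spawn_worker();

    // Request 1 is cancelled: its receiver is dropped before the reply is sent.
    let (tx1, rx1) = channel();
    drop(rx1);
    jobs.send((1, tx1)).unwrap();

    // Request 2 is still answered afterwards.
    let (tx2, rx2) = channel();
    jobs.send((2, tx2)).unwrap();
    println!("{}", rx2.recv().unwrap()); // 20

    drop(jobs); // close the queue so the worker exits
    println!("delivered {}", worker.join().unwrap()); // delivered 1
}
```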

Awaiting inside the closure

Hello! Thanks for this library 😄

I'm having trouble executing async tasks from within the closure. From what I can tell, the closure cannot be an async function, and the model inference I'm doing needs to be awaited from within a Tokio runtime.

Here's what I'm trying to do :

batched_fn! {
    handler = |batch: Batch<Input>, model: &Model| -> Batch<Output> {
        let output = model.predict(batch.clone()).await; // note the await
        println!("Processed batch {:?} -> {:?}", batch, output);
        output
    };
    config = {
        max_batch_size: 4,
        max_delay: 50,
    };
    context = {
        model: Model::load(),
    };
};

Any way this could be possible? I'm willing to help if you can point me in the right direction!
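Since the handler runs on its own thread (see Implementation details), it is allowed to block, so one option is to drive the future to completion synchronously inside the handler. The sketch below is a minimal std-only `block_on` executor, following the thread-parking pattern shown in the `std::task::Wake` documentation; `predict` is a hypothetical async stand-in for the model call. It only suits futures that do not need a Tokio reactor: if `model.predict` relies on Tokio I/O or timers, a Tokio runtime's `block_on` (for example on a `tokio::runtime::Runtime` created in context) would be needed instead, which is untested here.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// Wakes the blocked thread by unparking it.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Minimal single-future executor: poll, and park the current thread until woken.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(output) => return output,
            Poll::Pending => thread::park(), // re-poll after a wake (or spurious unpark)
        }
    }
}

/// Hypothetical async model call standing in for `model.predict(batch).await`.
async fn predict(batch: Vec<i32>) -> Vec<i32> {
    batch.into_iter().map(|x| x * 2).collect()
}

/// A synchronous handler body, like the handler closures in the examples above.
fn handler(batch: Vec<i32>) -> Vec<i32> {
    block_on(predict(batch))
}

fn main() {
    println!("{:?}", handler(vec![1, 2, 3])); // [2, 4, 6]
}
```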
