
Comments (8)

Tronic commented on August 20, 2024

Sorry to hijack this a bit, but one interesting question is whether JAX (given that its objects are immutable) can run multiple threads in parallel (thread safety) on the same model pipeline / model weights. That could be a clear advantage for this module.
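
A minimal sketch of what that could look like, assuming a pure jitted forward function over an immutable params pytree; the toy model and shapes here are made up, and whether concurrent dispatch actually overlaps on device is something to verify:

import threading

import jax
import jax.numpy as jnp

# Toy "model": parameters are an immutable pytree, and the forward pass is a
# pure function of (params, inputs) with no internal state to protect.
params = {"w": jnp.ones((512, 512))}

@jax.jit
def forward(params, x):
    return x @ params["w"]

def worker(x):
    # Every thread calls the same compiled function on the same shared
    # weights; nothing is mutated, so no locking around the model is needed.
    print(float(forward(params, x).sum()))

threads = [threading.Thread(target=worker, args=(jnp.ones((4, 512)),)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()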

Note there is also whispercpp (a C++ implementation, and a Python binding by that name), which could be of interest, but it also runs fairly slowly (CPU only).

With the OpenAI whisper module, running a single 30-second segment of large-v2 on an RTX 3090 only takes me about 200 ms, which I consider very fast (it uses roughly 10 GB of the 24 GB of VRAM and apparently frees it when idle). It works fine for audio segments up to 30 seconds; other implementations process overlapping 30 s segments in order to support longer files, and that slows them down a lot even when the audio would fit in a single segment.
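
For reference, a minimal sketch of fitting a shorter clip into one 30 s window with the stock library ("clip.wav" is a placeholder filename):

import whisper

# whisper.pad_or_trim zero-pads (or truncates) to exactly N_SAMPLES,
# i.e. 30 s at 16 kHz, so a single decode call covers the whole clip.
audio = whisper.load_audio("clip.wav")
audio = whisper.pad_or_trim(audio)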


SinanAkkoyun commented on August 20, 2024

@Tronic Thanks for the fast reply!

With the OpenAI whisper module, running a single 30-second segment of large-v2 on an RTX 3090 only takes me about 200 ms, which I consider very fast (it also doesn't use much VRAM). It works fine for audio segments up to 30 seconds; other implementations process overlapping 30 s segments in order to support longer files, and that slows them down a lot even when the audio would fit in a single segment.

Really? Wow, that's blazing fast. May I ask what your setup is and how you installed JAX? (Also, are you running on native Linux?)
Are you really telling me that stock OpenAI Whisper takes you only 200 ms? I once tested it on a Lambda A100 instance and got results similar to mine.


SinanAkkoyun commented on August 20, 2024

@Tronic Did you make any modifications to beam size, precision, etc.? Mine takes up about 20 GB of VRAM.


Tronic commented on August 20, 2024

Using OpenAI Whisper in Python with default settings:

import torch
import whisper

model = whisper.load_model("large", device="cuda")  # This takes a couple of seconds
options = whisper.DecodingOptions(without_timestamps=True, task="translate", language="en")

# Input must be exactly 30 s (N_SAMPLES = 16 kHz * 30 s); shorter audio goes
# at the beginning of a zero-padded buffer.
pcm = torch.zeros(whisper.audio.N_SAMPLES, dtype=torch.float32)
mel = whisper.log_mel_spectrogram(pcm).to(model.device)
result = whisper.decode(model, mel, options)  # ~200 ms

Possibly it gets a bit slower if the buffer contains a full 30 seconds of actual speech, but I am processing short segments, so that is not a problem.

The whisper-jax module uses all of my 24 GB of VRAM even on the small model and gives out-of-memory errors on the medium and large models. I haven't tried tuning any settings yet.
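
(One likely factor, as an aside: by default JAX's XLA GPU client preallocates most of the available VRAM up front, independent of model size. A minimal sketch of capping that via the standard XLA client environment variables, which must be set before jax is imported:)

import os

# Disable JAX's default up-front GPU memory grab...
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# ...or instead cap the fraction of VRAM XLA may preallocate.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax  # import only after the environment variables are set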


oliverwehrens commented on August 20, 2024

I have an RTX 4090, running natively on Linux. I ran this against a 1 h 17 min audio file.

import datetime
import json

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline  # note: the class name is spelled this way upstream


def transcribe_70():
    # Instantiate the pipeline (half precision, batched over 30 s chunks)
    t0 = datetime.datetime.now()
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16, dtype=jnp.bfloat16)
    print(f"Load Model at  {t0}")
    t1 = datetime.datetime.now()
    print(f"Loading took {t1 - t0}")
    print(f"started at {t1}")

    # First call: includes JIT compilation of the JAX computation graph
    outputs = pipeline("episode.mp3", task="transcribe", return_timestamps=True)
    t2 = datetime.datetime.now()
    print(f"ended at {t2}")
    print(f"time elapsed: {t2 - t1}")

    # Second call: reuses the compiled function, so it should be faster
    t1 = datetime.datetime.now()
    outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
    t2 = datetime.datetime.now()
    print(f"ended at {t2}")
    print(f"time elapsed: {t2 - t1}")

    with open("output70.json", "w") as f:
        f.write(json.dumps(outputs))


if __name__ == '__main__':
    transcribe_70()

I get the following results:

Load Model at  2023-04-21 11:19:43.209201
Loading took 0:00:07.193107
started at 2023-04-21 11:19:50.402308
ended at 2023-04-21 11:23:39.018227
time elapsed: 0:03:48.615919
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
ended at 2023-04-21 11:25:43.578735
time elapsed: 0:02:04.560488

When I use native Whisper:

import datetime
import json

import torch
import whisper

# Report the available device (whisper.load_model defaults to CUDA when
# present, so the `device` variable here is informational only)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("GPU")
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CPU")

t0 = datetime.datetime.now()
print(f"Load Model at  {t0}")
model = whisper.load_model('large')
t1 = datetime.datetime.now()
print(f"Loading took {t1 - t0}")
print(f"started at {t1}")

# Do the transcription
output = model.transcribe("audio.mp3")

# Show time elapsed after transcription is complete
t2 = datetime.datetime.now()
print(f"ended at {t2}")
print(f"time elapsed: {t2 - t1}")

with open("xxx.json", "w") as f:
    f.write(json.dumps(output))

It takes about 208 seconds.

So the timings are:

  • native Whisper: 208 s
  • whisper-jax, 1st run (includes compilation): 228 s
  • whisper-jax, 2nd run: 124 s

Anything I can improve further? It's almost a 2x speedup, which is nice, but I was hoping for a bit more. So where is my problem?
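
(The 228 s vs. 124 s gap suggests the first call is dominated by JIT compilation. A sketch of one thing worth trying, so that timed runs measure pure inference: warm the pipeline up on a short dummy clip first; "short_silence.mp3" here is a hypothetical placeholder file.)

import datetime

import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16, dtype=jnp.bfloat16)

# Warm-up call: triggers compilation; the result is discarded
_ = pipeline("short_silence.mp3", task="transcribe", return_timestamps=True)

# Timed call: should now exclude the one-off compilation cost
t1 = datetime.datetime.now()
outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
print(f"time elapsed: {datetime.datetime.now() - t1}")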


ezerhouni commented on August 20, 2024

@SinanAkkoyun I am seeing something similar on my end (up to 4x slower than faster-whisper) for small files.


SinanAkkoyun commented on August 20, 2024

Thank you all for the great responses!
I really appreciate it!


syedmustafa54 commented on August 20, 2024

Hello @oliverwehrens, when I ran your code the following got printed, but it never completed with results; my mp3 file is only about 2 s long:

Load Model at  2023-04-24 03:19:10.960131
Loading took 0:00:16.827978
started at 2023-04-24 03:19:27.788109
