Comments (8)
Sorry to hijack this a bit, but one interesting question is whether JAX (given that its objects are immutable) can run multiple threads in parallel (thread safety) on the same model pipeline / model weights. That could be a clear advantage for this module.
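On the thread-safety question: since jitted JAX functions are pure and their inputs immutable, concurrent calls sharing the same weights are generally fine. A minimal sketch with a toy model (not Whisper; forward and params are made-up names for illustration):

```python
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp

# Toy "model": an immutable pytree of weights shared by all threads.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}

@jax.jit
def forward(params, x):
    # Pure function of immutable inputs; safe to call concurrently.
    return x @ params["w"] + params["b"]

x = jnp.arange(4.0)

# Call the same compiled function from several threads on the same weights.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(lambda _: forward(params, x), range(8)))

all_equal = all(bool(jnp.array_equal(r, results[0])) for r in results)
print(all_equal)
```

Because no thread mutates the weights, there is nothing to lock; each call only reads shared arrays.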
Note there is also whisper.cpp (a C++ implementation, with a Python binding named whispercpp), which could be of interest, but it also runs fairly slowly (CPU only).
With the OpenAI whisper module, running a single 30-second segment on an RTX 3090 with large-v2 only takes me about 200 ms, which I consider very fast (it uses roughly 10 GB of 24 GB VRAM and apparently frees it when idle). It works fine for audio segments up to 30 seconds; other implementations process overlapping 30 s segments to support longer files, which slows them down a lot even when the audio would fit in a single segment.
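For context on the 30-second windows: Whisper always consumes fixed 30 s frames (480 000 samples at 16 kHz), and shorter input is zero-padded, which is what whisper.pad_or_trim does. A numpy sketch of that padding logic (pad_or_trim here is a re-implementation for illustration, not the library function):

```python
import numpy as np

SAMPLE_RATE = 16000
N_SAMPLES = 30 * SAMPLE_RATE  # 480000 samples: one Whisper window

def pad_or_trim(audio: np.ndarray, length: int = N_SAMPLES) -> np.ndarray:
    """Zero-pad or cut a mono float32 signal to exactly `length` samples."""
    if audio.shape[0] >= length:
        return audio[:length]
    return np.pad(audio, (0, length - audio.shape[0]))

short = np.zeros(5 * SAMPLE_RATE, dtype=np.float32)  # a 5 s clip
window = pad_or_trim(short)
print(window.shape[0])  # 480000
```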
from whisper-jax.
@Tronic Thanks for the fast reply!
With the OpenAI whisper module, running a single 30-second segment on an RTX 3090 with large-v2 only takes me about 200 ms, which I consider very fast (it also doesn't use much VRAM). It works fine for audio segments up to 30 seconds; other implementations process overlapping 30 s segments to support longer files, which slows them down a lot even when the audio would fit in a single segment.
Really? Wow, that's blazing fast. May I ask what your setup is and how you installed JAX? (Also, are you running on native Linux?)
Are you really telling me that stock OpenAI Whisper takes you only 200 ms? I once tested it on a Lambda GPU A100 and got results similar to my own.
@Tronic Did you do any modifications to beam size, precision, etc.? Mine takes up about 20 GB of VRAM.
Using OpenAI whisper in Python with default settings:

import torch
import whisper

model = whisper.load_model("large", device="cuda")  # takes a couple of seconds
options = whisper.DecodingOptions(without_timestamps=True, task="translate", language="en")
pcm = torch.zeros(whisper.audio.N_SAMPLES, dtype=torch.float32)  # must be exactly 30 seconds (audio at beginning)
mel = whisper.log_mel_spectrogram(pcm).to(model.device)
result = whisper.decode(model, mel, options)  # ~200 ms

Possibly it gets a bit slower if the buffer contains a full 30 seconds of actual speech, but I am processing short segments, so that is not a problem.
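One caveat when quoting numbers like the 200 ms above: CUDA kernel launches are asynchronous, so naive wall-clock timing can under-count GPU work unless you synchronize first. A small helper sketch (timed and _sync are made-up names; torch is optional here):

```python
import time

# torch is only needed for the GPU sync; fall back gracefully without it.
try:
    import torch

    def _sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()
except ImportError:
    def _sync():
        pass

def timed(fn, warmup=1, iters=5):
    """Average wall time of fn(), syncing the GPU so async kernels are counted."""
    for _ in range(warmup):
        fn()
    _sync()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    _sync()
    return (time.perf_counter() - t0) / iters

avg = timed(lambda: sum(range(100_000)))
print(avg > 0)
```

The warmup iteration also keeps one-time costs (caching, lazy initialization) out of the average.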
The whisper-jax module uses all of my 24 GB of VRAM even on small, and gives out-of-memory errors on the medium and large models. I didn't try tuning any settings yet.
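One likely explanation for the "uses all VRAM" symptom, assuming stock JAX settings: by default JAX preallocates the bulk of GPU memory up front (roughly 75–90% depending on version), regardless of model size. This can be tuned with environment variables (the script name below is a placeholder):

```shell
# Allocate GPU memory on demand instead of grabbing it all at start-up
export XLA_PYTHON_CLIENT_PREALLOCATE=false
# Alternatively, cap the fraction of GPU memory JAX may preallocate
export XLA_PYTHON_CLIENT_MEM_FRACTION=0.5
# python run_whisper_jax.py   # placeholder for your actual script
```

Note that on-demand allocation can be slower and more fragmentation-prone, so the fraction cap is often the safer first try.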
I have an RTX 4090, native on Linux. I ran this against a 1 h 17 min audio file.
import datetime
import json
import jax.numpy as jnp
from whisper_jax import FlaxWhisperPipline

def transcribe_70():
    # instantiate pipeline
    t0 = datetime.datetime.now()
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16, dtype=jnp.bfloat16)
    print(f"Load Model at {t0}")
    t1 = datetime.datetime.now()
    print(f"Loading took {t1 - t0}")
    print(f"started at {t1}")
    # first call: includes JIT compilation
    outputs = pipeline("episode.mp3", task="transcribe", return_timestamps=True)
    t2 = datetime.datetime.now()
    print(f"ended at {t2}")
    print(f"time elapsed: {t2 - t1}")
    # second call: reuses the compiled model
    t1 = datetime.datetime.now()
    outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
    t2 = datetime.datetime.now()
    print(f"ended at {t2}")
    print(f"time elapsed: {t2 - t1}")
    with open("output70.json", "w") as f:
        f.write(json.dumps(outputs))

if __name__ == '__main__':
    transcribe_70()
I get the following results:
Load Model at 2023-04-21 11:19:43.209201
Loading took 0:00:07.193107
started at 2023-04-21 11:19:50.402308
ended at 2023-04-21 11:23:39.018227
time elapsed: 0:03:48.615919
There was an error while processing timestamps, we haven't found a timestamp as last token. Was WhisperTimeStampLogitsProcessor used?
ended at 2023-04-21 11:25:43.578735
time elapsed: 0:02:04.560488
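The gap between the two runs is largely JAX's JIT compilation: the first pipeline call traces and compiles the model, and later calls with the same shapes reuse the compiled version. A toy illustration of that pattern (step is a made-up function, not the Whisper model):

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Arbitrary toy computation standing in for a model forward pass.
    return jnp.tanh(x) @ x.T

x = jnp.ones((256, 256))

t0 = time.perf_counter()
step(x).block_until_ready()  # first call: trace + XLA compile + run
first = time.perf_counter() - t0

t0 = time.perf_counter()
step(x).block_until_ready()  # later calls: just run the compiled kernel
second = time.perf_counter() - t0

print(second < first)
```

block_until_ready() forces the asynchronous computation to finish before the clock is read, so the two timings are comparable.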
When I use native whisper:
import whisper
import torch
import json
import datetime

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print("GPU")
    print(torch.cuda.current_device())
    print(torch.cuda.device(0))
    print(torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("CPU")

t0 = datetime.datetime.now()
print(f"Load Model at {t0}")
model = whisper.load_model('large')
t1 = datetime.datetime.now()
print(f"Loading took {t1-t0}")
print(f"started at {t1}")

# do the transcription
output = model.transcribe("audio.mp3")

# show time elapsed after transcription is complete.
t2 = datetime.datetime.now()
print(f"ended at {t2}")
print(f"time elapsed: {t2 - t1}")

with open("xxx.json", "w") as f:
    f.write(json.dumps(output))
It takes about 208 seconds.
So it is:
- Whisper: 208 s
- whisper-jax, 1st run: 228 s
- whisper-jax, 2nd run: 124 s
Anything I can improve further? It's almost a 2x speedup, which is nice, but I was hoping for a bit more. Where is my problem?
@SinanAkkoyun I am seeing something similar on my end (up to 4x slower than faster-whisper) for small files
Thank you all for the great responses!
I really appreciate it!
Hello @oliverwehrens, when I ran your code the following got printed, but it never completed or produced results. My mp3 file is only about 2 seconds long:
Load Model at 2023-04-24 03:19:10.960131
Loading took 0:00:16.827978
started at 2023-04-24 03:19:27.788109