Nice work on this! I've been thinking about doing something similar with JAX, using either tfjs-tflite or ORT Web.
Do you have any plans to take this further, or did you just want to get a demo/POC working for fun?
Also, have you tried benchmarking against PyTorch's CPU backend? I'm curious how much slower it is. You could probably get a decent speedup in your GitHub Pages demo by including this script in the <head>. It's quite hacky, but AFAIK it's the only way to get cross-origin isolation working on GitHub Pages, since GitHub Pages doesn't yet let you set COOP/COEP headers.
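For context, a page is cross-origin isolated when it is served with `Cross-Origin-Opener-Policy: same-origin` and `Cross-Origin-Embedder-Policy: require-corp`, and only then does the browser expose `SharedArrayBuffer`, which multithreaded WASM backends rely on (presumably that's where the speedup would come from). Here's a rough sketch of how you could check at runtime whether the workaround took effect. `canUseWasmThreads` and `IsolationScope` are just hypothetical names for illustration; `crossOriginIsolated` is the real global property.

```typescript
// Minimal shape of the global scope we care about (window, or `self` in a worker).
interface IsolationScope {
  crossOriginIsolated?: boolean;
}

// Hypothetical helper: true only if the page is cross-origin isolated AND
// SharedArrayBuffer exists, i.e. multithreaded WASM should be usable.
function canUseWasmThreads(scope: IsolationScope): boolean {
  return (
    scope.crossOriginIsolated === true &&
    typeof SharedArrayBuffer !== "undefined"
  );
}

// In the browser you'd pass the real global scope, e.g.:
// console.log(canUseWasmThreads(window));
```

If it returns false after adding the script, the headers probably aren't being applied yet (the service-worker approach typically needs a reload before it kicks in).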
Regarding JAX, it seems like it should be relatively easy at this point (given recent improvements in model conversion tooling) to do something similar to what you've done here. But I've been thinking it'd be cool to get the whole JAX lib (CPU backend) working in the browser, so that even the model-building side of things can be done there. I don't have much experience with C/C++, so the build process is kind of a black box to me, and I'm currently stuck here (more info), in case this is interesting to you. JAX is quite "lean" in terms of dependencies and stuff, so it seems like it shouldn't be toooo hard. I'll try to keep chipping away at this in my spare time.
There is apparently some work going on to get PyTorch working in the browser, which would be quite a feat, since PyTorch seems to have a lot more dependencies/"moving parts" than JAX.