So I wrote a blog post!☀️
"JAX vs Julia (vs PyTorch)"
With a focus on using #JAX / #julialang in scientific computing + ML computing; covering their similarities (speed, functors, homoiconicity) and differences (introspection, documentation, code quality).
@PatrickKidger
Thanks, Patrick! Equinox looks nice. Have to try it out. :) Are there any known performance bottlenecks for ODE integrators compared to JAX's ODE ecosystem (jax.experimental.ode)? Personally, where Julia shines for me is sci-comp + easy Windows installation + CUDA.
@bhaveshshrima11
Both Julia and JAX get great performance. I'm sure they still have small differences, but I now pick between JAX and Julia based on other factors (as in the blog post!)
Also do check out Diffrax as a much more powerful alternative to jax.experimental.ode.
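To make the ODE point concrete: in JAX an integrator is ordinary traced code, so you can differentiate straight through it. The sketch below is a hand-rolled fixed-step RK4 solver built on `jax.lax.scan`. It is not the API of Diffrax or jax.experimental.ode, and `rk4_solve` / `final_value` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def rk4_solve(f, y0, t0, t1, num_steps):
    """Integrate dy/dt = f(t, y) from t0 to t1 with `num_steps` fixed RK4 steps (hypothetical helper)."""
    dt = (t1 - t0) / num_steps
    t0 = jnp.asarray(t0, dtype=y0.dtype)  # keep the scan carry's dtype stable

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y_next), None

    # lax.scan has a static length, so reverse-mode autodiff works through it.
    (_, y1), _ = jax.lax.scan(step, (t0, y0), None, length=num_steps)
    return y1


def final_value(k):
    # Exponential decay dy/dt = -k * y with y(0) = 1; the exact answer at t = 1 is exp(-k).
    return rk4_solve(lambda t, y: -k * y, jnp.array(1.0), 0.0, 1.0, 100)


print(final_value(1.0))            # approximately exp(-1) ≈ 0.3679
print(jax.grad(final_value)(1.0))  # approximately -exp(-1) ≈ -0.3679
```

A fixed step count keeps the sketch simple; adaptive step-size control (which a library like Diffrax provides) is exactly what makes the loop length data-dependent.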
@PatrickKidger
Great write-up! The only thing I'd personally add is the niceness of more programming constructs being differentiable in Julia (e.g. grad w.r.t. custom structs, loops, if statements, crazy closures/splats, etc.; although some of this may be fixed in more recent JAX versions than I've used).
@cosmic_mar
JAX has grad wrt structs, loops, ifs etc. I think the only JAX limitation here is reverse-mode autodiff wrt unbounded while loops. (Which can often be worked around with a bit of magic, e.g. see `diffrax.misc.bounded_while_loop`.)
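A small sketch of those claims, assuming a recent JAX version: a `NamedTuple` plays the role of a custom struct (it is a pytree out of the box), a Python loop and a `lax.cond` branch go through `jax.grad`, and `lax.while_loop` supports forward mode (`jax.jvp`) but raises an error under reverse mode (`jax.grad`). `Params`, `model` and `double_until_ten` are hypothetical names; this does not use the `bounded_while_loop` helper mentioned above.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    # NamedTuples are pytrees, so jax.grad returns a Params holding the gradients.
    w: jax.Array
    b: jax.Array


def model(p, x):
    y = x
    # A plain Python loop is unrolled while tracing, so it differentiates like straight-line code.
    for _ in range(2):
        y = p.w * y + p.b
    # A data-dependent branch via lax.cond, which is differentiable and also works under jit.
    return jax.lax.cond(y > 0, lambda v: v ** 2, lambda v: -v, y)


g = jax.grad(model)(Params(w=jnp.array(2.0), b=jnp.array(1.0)), jnp.array(1.0))
print(float(g.w), float(g.b))  # 70.0 42.0


def double_until_ten(x):
    # The trip count depends on the value of x, so the loop is unbounded from JAX's view.
    return jax.lax.while_loop(lambda v: v < 10.0, lambda v: v * 2.0, x)


# Forward mode works: 1.5 -> 3 -> 6 -> 12, three doublings, so d(out)/dx = 8.
_, tangent = jax.jvp(double_until_ten, (jnp.array(1.5),), (jnp.array(1.0),))

# Reverse mode through lax.while_loop raises an error instead.
try:
    jax.grad(double_until_ten)(jnp.array(1.5))
    reverse_mode_ok = True
except ValueError:
    reverse_mode_ok = False
```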
@PatrickKidger
What's your stance on Julia compilation times?
Compared to Python they're much longer, so a script needs to do a lot of computation before you feel the speed-up.
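JAX's `jax.jit` makes the same "pay up front, then run fast" trade-off, and a small sketch shows when that cost is paid. Assuming standard `jax.jit` caching, the function is traced and compiled on the first call for each new input shape/dtype, and later calls with the same shapes reuse the compiled code. `step` and `trace_count` are hypothetical names; this illustrates JAX, not Julia's compiler, and says nothing about absolute compile times.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(x):
    global trace_count
    # Python side effects run only while JAX traces the function for compilation,
    # not on later calls that hit the compiled-function cache.
    trace_count += 1
    return jnp.sin(x) * 2.0


step(jnp.ones(3))  # first call with shape (3,): trace + compile (the up-front cost)
step(jnp.ones(3))  # same shape and dtype: reuses the cached compiled function
step(jnp.ones(5))  # new shape (5,): traced and compiled again
print(trace_count)  # 2
```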
@valentyn_bez
Compilation is a reasonable price to pay for fast runtime.
IMO the real issue with Julia's JIT compiler is "long time to first error".
Python runtime, JAX tracing, or e.g. static compilation of C++ all tell you about typos pretty quickly.
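As an illustration of the "fast first error" side of JAX tracing: `jax.eval_shape` traces a function on abstract shapes only, so a shape bug is reported without allocating or computing any real arrays. `bad_matmul` is a hypothetical function, and since the exception class JAX raises can differ by operation, the sketch catches both `TypeError` and `ValueError`.

```python
import jax
import jax.numpy as jnp


def bad_matmul(x, y):
    # Shape bug: (3, 4) @ (5,) has mismatched contraction dimensions.
    return jnp.dot(x, y)


# Abstract inputs: only shape and dtype, no data.
x_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
y_spec = jax.ShapeDtypeStruct((5,), jnp.float32)

try:
    jax.eval_shape(bad_matmul, x_spec, y_spec)
    caught_at_trace_time = False
except (TypeError, ValueError) as err:
    # The mismatch is reported during tracing, before any real computation.
    caught_at_trace_time = True
    print(type(err).__name__, err)

# With compatible shapes, eval_shape just reports the output shape and dtype.
good = jax.eval_shape(jnp.dot, x_spec, jax.ShapeDtypeStruct((4,), jnp.float32))
print(good.shape)  # (3,)
```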
@PatrickKidger
Great post!
I'd add that TPU compilation is currently in JAX's favor.
For Julia, you can have fast custom arrays (even on GPU), inlineable structs, very fast scalar code, custom GPU kernels and more. Zygote is made possible by transforms on SSA IR, and Julia also has macros.
@PatrickKidger
I don't think the array syntax is a weakness of Julia. In Python you can't really tell what the return type (or shape) of A[1] is: it can be either a scalar or a tensor. Like syntactic sugar in general, it might be convenient to write, but it's hard to read or trace.
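To illustrate the ambiguity in NumPy-style indexing (shown with `jax.numpy`; plain NumPy behaves the same way): the result of `A[1]` depends on the rank of `A`, which isn't visible at the call site. `v` and `m` are just example arrays.

```python
import jax.numpy as jnp

v = jnp.arange(3.0)                 # rank 1, shape (3,)
m = jnp.arange(12.0).reshape(3, 4)  # rank 2, shape (3, 4)

# Same expression, different result rank: a scalar for v, a whole row for m.
print(v[1].shape)  # ()
print(m[1].shape)  # (4,)

# Spelling out every axis makes the intent visible where the indexing happens.
print(m[1, :].shape)  # (4,)
print(m[1, 2].shape)  # ()
```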
@PatrickKidger
Great article! Since I mainly focus on PyTorch + C, Julia/JAX looks very appealing in terms of speed.
For a simple CNN, roughly how many orders of magnitude faster do you think Julia is than PyTorch? I keep getting mixed information, hence my hesitancy to switch.