@PatrickKidger
Patrick Kidger
2 years
So I wrote a blog post! ☀️ "JAX vs Julia (vs PyTorch)", with a focus on using #JAX / #julialang for scientific computing + ML, covering their similarities (speed, functors, homoiconicity) and differences (introspection, documentation, code quality).
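The "functors" similarity can be made concrete in JAX, where the equivalent idea is the PyTree: any container registered with `jax.tree_util` can be mapped over leaf by leaf and differentiated as a whole, much as a `@functor` struct can with Julia's Functors.jl. A minimal sketch, assuming a hypothetical `Linear` layer class (libraries such as Equinox automate this registration):

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical parameter container, registered as a pytree so that JAX
# transformations see its arrays as leaves.
@jax.tree_util.register_pytree_node_class
@dataclass
class Linear:
    weight: jax.Array
    bias: jax.Array

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def __call__(self, x):
        return self.weight @ x + self.bias


def loss(model, x):
    return jnp.sum(model(x) ** 2)


model = Linear(weight=jnp.ones((2, 3)), bias=jnp.zeros(2))
x = jnp.arange(3.0)

# Map a function over every leaf, as Functors.fmap would in Julia.
doubled = jax.tree_util.tree_map(lambda a: 2 * a, model)

# The gradient comes back as another Linear with the same structure.
grads = jax.grad(loss)(model, x)
print(grads)
```

Because `grads` is itself a `Linear`, a gradient-descent update can be written as a single `tree_map` over `model` and `grads`.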

Replies

@bhaveshshrima11
Bhavesh
2 years
@PatrickKidger Thanks, Patrick! Equinox looks nice. Have to try it out. :) Are there any known performance bottlenecks for ODE integrators compared to JAX's ODE ecosystem (jax.experimental.ode)? Personally, where Julia shines for me is sci-comp + easy Windows installation + CUDA.
@PatrickKidger
Patrick Kidger
2 years
@bhaveshshrima11 Both Julia and JAX get great performance. I'm sure they still have small differences, but I now pick JAX vs Julia based on other factors (as in the blog post!). Also do check out Diffrax as a much more powerful alternative to jax.experimental.ode.
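For readers new to ODE solving in JAX, the sketch below shows why it composes well with the rest of the ecosystem: a hand-rolled fixed-step RK4 written with `jax.lax.scan` is JIT-compilable and differentiable end to end. This is an illustrative assumption, not Diffrax's API and not `jax.experimental.ode`; real libraries add features such as adaptive step-size control and adjoint methods for gradients.

```python
import jax
import jax.numpy as jnp


def rk4_solve(f, y0, t0, t1, num_steps):
    """Hypothetical fixed-step classical RK4 for dy/dt = f(t, y); returns y(t1)."""
    y0 = jnp.asarray(y0)
    # Keep time in the same dtype as the state so the scan carry types match.
    t0 = jnp.asarray(t0, dtype=y0.dtype)
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y_next), None

    # num_steps must be a Python int: it fixes the (static) scan length.
    (_, y1), _ = jax.lax.scan(step, (t0, y0), None, length=num_steps)
    return y1


def decay(t, y):
    # dy/dt = -y, so y(1) = y0 * exp(-1).
    return -y


solve = jax.jit(lambda y0: rk4_solve(decay, y0, 0.0, 1.0, 100))
y1 = solve(jnp.array(1.0))
dy1_dy0 = jax.grad(solve)(jnp.array(1.0))  # also close to exp(-1)
print(y1, dy1_dy0)
```

`rk4_solve` and `decay` are hypothetical names; gradients here come from differentiating straight through the solver's operations.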
@cosmic_mar
Marius Millea
2 years
@PatrickKidger Great write-up! The only thing I'd personally add is how nice it is that more programming constructs are differentiable in Julia (e.g. grad wrt custom structs, loops, if statements, crazy closures/splats, etc.; although some of this may be fixed in more recent JAX versions than I've used).
@PatrickKidger
Patrick Kidger
2 years
@cosmic_mar JAX has grad wrt structs, loops, ifs etc. I think the only JAX limitation here is reverse-mode autodiff wrt unbounded while loops. (Which can often be worked around with a bit of magic, e.g. see `diffrax.misc.bounded_while_loop`.)
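Both points in this exchange can be checked directly. In the sketch below, a Python `for` loop (unrolled during tracing), `jax.lax.cond`, and closures all differentiate fine, while reverse-mode `jax.grad` through `jax.lax.while_loop` raises an error. `bounded_while_loop` is a hypothetical helper written for illustration, in the spirit of the workaround mentioned above but not a copy of `diffrax.misc.bounded_while_loop`: it caps the loop at `max_steps` iterations using `jax.lax.scan` and masks out every step after the condition fails.

```python
import jax
import jax.numpy as jnp


def power_by_loop(x, n=5):
    # A Python for-loop with a fixed trip count is unrolled during tracing.
    result = 1.0
    for _ in range(n):
        result = result * x
    return result


def piecewise(x):
    # lax.cond expresses a data-dependent branch that also works under jit.
    return jax.lax.cond(x > 0, lambda v: v**2, lambda v: -v, x)


def bounded_while_loop(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical helper: like lax.while_loop, but capped at max_steps.

    Built on lax.scan, so reverse-mode autodiff works; the price is that all
    max_steps iterations always execute (finished ones are masked out).
    """

    def step(val, _):
        keep_going = cond_fun(val)
        new_val = body_fun(val)
        # Once the condition fails, carry the old value through unchanged.
        val = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), new_val, val
        )
        return val, None

    final_val, _ = jax.lax.scan(step, init_val, None, length=max_steps)
    return final_val


def _sqrt_cond(a):
    return lambda x: jnp.abs(x * x - a) > 1e-5


def _sqrt_body(a):
    # One Newton step for x**2 = a.
    return lambda x: 0.5 * (x + a / x)


def sqrt_unbounded(a):
    return jax.lax.while_loop(_sqrt_cond(a), _sqrt_body(a), a)


def sqrt_bounded(a):
    return bounded_while_loop(_sqrt_cond(a), _sqrt_body(a), a, max_steps=20)


a = jnp.array(2.0)
print(jax.grad(power_by_loop)(2.0))  # 5 * 2**4
print(jax.grad(piecewise)(3.0))
print(jax.grad(sqrt_bounded)(a))  # close to 1 / (2 * sqrt(2))

try:
    jax.grad(sqrt_unbounded)(a)
    while_loop_grad_failed = False
except ValueError as err:
    # Reverse-mode differentiation of an unbounded while_loop is rejected.
    while_loop_grad_failed = True
    print("while_loop:", err)
```

The trade-off in this naive version is that all `max_steps` iterations always run, even after convergence, so `max_steps` should be a reasonably tight upper bound.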
@valentynbez
Valentyn Bezshapkin 🇺🇦
2 years
@PatrickKidger What's your stance on Julia compilation times? They're much longer than in Python, so a script needs to do a lot of computation before you notice the speed-up.
@PatrickKidger
Patrick Kidger
2 years
@valentyn_bez Compilation is a reasonable price to pay for fast runtime. IMO the real issue with Julia's JIT compiler is "long time to first error". Python runtime, or JAX tracing, or e.g. static compilation of C++, all tell you about typos pretty quickly.
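To illustrate the "time to first error" point: JAX can surface shape bugs by tracing alone, before any XLA compilation or computation happens. `jax.eval_shape` runs a function on abstract `jax.ShapeDtypeStruct` inputs. The `layer` function and its shapes below are hypothetical.

```python
import jax
import jax.numpy as jnp


def layer(x, w):
    # For x of shape (3, 4), w must have shape (4, k).
    return jnp.tanh(x @ w)


x_spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)
w_good = jax.ShapeDtypeStruct((4, 2), jnp.float32)
w_bad = jax.ShapeDtypeStruct((5, 2), jnp.float32)  # deliberate shape bug

# eval_shape only traces abstractly: no real arrays, no XLA compile.
out_spec = jax.eval_shape(layer, x_spec, w_good)
print(out_spec.shape)  # (3, 2)

try:
    jax.eval_shape(layer, x_spec, w_bad)
    shape_error = None
except (TypeError, ValueError) as err:
    # The mismatched contraction dimensions are reported at trace time.
    shape_error = str(err)
    print(shape_error)
```

Under `jax.jit` the same error is raised during tracing, so it is reported before compilation starts.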
@digitalhealthxx
Sami Nas 👨‍⚕️
2 years
@PatrickKidger Great post and comparison👆
@akatzzzzz
R-E
2 years
@PatrickKidger Great post! I'd add that TPU compilation is currently in JAX's favor. For Julia, you can have fast custom arrays (even on GPU), inlineable structs, very fast scalar code, custom GPU kernels and more. Zygote is made possible by transforms on SSA IR, and Julia also has macros.
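Zygote differentiates by transforming Julia's SSA IR. A loose JAX analogue is the jaxpr, the intermediate representation produced by tracing, which `jax.make_jaxpr` lets you print for both a function and its gradient. The function `f` is just an example.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


# The jaxpr is the program JAX actually sees after tracing f.
primal_jaxpr = jax.make_jaxpr(f)(1.0)

# Tracing the gradient yields a new program containing derivative code.
grad_jaxpr = jax.make_jaxpr(jax.grad(f))(1.0)

print(primal_jaxpr)
print(grad_jaxpr)
```

In the gradient's jaxpr you can see `cos` appear: the derivative of `sin` inserted by JAX's differentiation rules.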
@janphigoe
Jan Philip Göpfert
2 years
@PatrickKidger Fun read, thanks for sharing your thoughts.
@Cheng_Ching_Wen
PeterCheng
2 years
@PatrickKidger I don’t think the array syntax is a weakness of Julia. In Python you can’t really tell what the return type (or shape) of A[1] is: it can be either a scalar or a tensor. Like any syntactic sugar, it might be convenient to write, but it’s hard to read or trace.
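The ambiguity described here is easy to see with NumPy-style indexing in JAX: the same expression `A[1]` returns a 0-d array for a vector but a whole row for a matrix. (In Julia, `A[1]` on a matrix is linear indexing and returns a single element; you write `A[1, :]` for a row.)

```python
import jax.numpy as jnp

vector = jnp.arange(3.0)                 # shape (3,)
matrix = jnp.arange(12.0).reshape(3, 4)  # shape (3, 4)

# Same syntax, different result rank depending on the input's rank.
print(vector[1].shape)  # ()
print(matrix[1].shape)  # (4,)

# Spelling out every axis makes the intent visible at the call site.
print(matrix[1, :].shape)  # (4,)
print(matrix[1, 2].shape)  # ()
```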
@eulerdomy
Cringe not Soy
2 years
@PatrickKidger Great article! Since I focus mainly on PyTorch + C, Julia/JAX look very appealing in terms of speed. Roughly, for a simple CNN, how many orders of magnitude faster do you think Julia is than PyTorch? I keep getting mixed information, hence my hesitancy to switch.
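The thread doesn't give a number, and any speed-up depends heavily on the model, the hardware, and how carefully each side is benchmarked. One way to get your own answer is to time the same small model in each framework. A minimal sketch of the JAX side, assuming a hypothetical single-layer `conv_layer` (NCHW input, OIHW kernel): the first call is timed separately because it includes tracing and XLA compilation, and `block_until_ready` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=10):
    """Return (first_call_seconds, mean_steady_state_seconds) for jit(fn)."""
    jitted = jax.jit(fn)

    start = time.perf_counter()
    jitted(*args).block_until_ready()  # includes tracing + XLA compilation
    first_call = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the result; otherwise we would only time the dispatch.
        jitted(*args).block_until_ready()
    steady_state = (time.perf_counter() - start) / repeats
    return first_call, steady_state


def conv_layer(x, kernel):
    # Hypothetical model: one 2D convolution followed by ReLU.
    y = jax.lax.conv(x, kernel, window_strides=(1, 1), padding="SAME")
    return jax.nn.relu(y)


key_x, key_k = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 3, 32, 32))       # batch, channels, H, W
kernel = jax.random.normal(key_k, (16, 3, 3, 3))   # out, in, kH, kW
first, steady = benchmark(conv_layer, x, kernel)
print(f"first call: {first:.4f}s, steady state: {steady:.6f}s")
```

The same care applies on the other side of the comparison: warm up before timing, and synchronize the GPU (for example with `torch.cuda.synchronize()` in PyTorch) before reading the clock.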