⭐️Announcing Diffrax!⭐️
Numerical differential equation solvers in
#JAX
.
Very efficient, and with oodles of fun features!
GitHub:
Docs:
Install: `pip install diffrax`
🧵 1/n
First of all! Here's the obligatory quick example.
Note that Diffrax integrates smoothly with normal JAX. You can safely jit/vmap/grad/etc. to your heart's content.
2/n
Diffrax is very fast. (Check out ) It obtains similar performance to DifferentialEquations.jl, and can be as much as 20x faster than torchdiffeq. (!)
This is largely because `jax.jit` removes the overhead of the Python interpreter.
3/n
Diffrax has an extensive feature set: ODE and SDE solvers (including high-order, implicit, and symplectic solvers); multiple adjoint methods; dense output; and more.
Moreover, it's carefully designed to be very extensible!
4/n
For example if you're building some kind of differentiable simulator then you may want to handle the stepping yourself.
Diffrax makes this easy:
5/n
Moreover, check out the documentation for extensive pre-made examples of training Neural ODEs/CDEs/SDEs, continuous normalising flows, solving stiff equations, and so on:
e.g. here's how to solve an SDE using Euler's method w/ dense output:
6/n
Now you may be wondering: why Diffrax? 🤔 Why does it even exist?
Well, the answer is that I had to procrastinate from writing my PhD thesis somehow... :D
...more seriously, it does some new under-the-hood stuff for unifying the way we solve ODEs/SDEs, that I just had to try!
7/n
I'll leave out the technical details from the announcement, but v. happy to discuss with those curious!
(The result is the AbstractTerm interface in Diffrax.)
We can even handle semi-explicit DAEs (and stochastic DAEs) in the same unified way too!
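Roughly, the unifying idea (my paraphrase of the formulation used in the neural differential equations literature):

```latex
% ODEs, CDEs and SDEs in one formulation: integrate a vector field
% against a "control" x.
y(t) = y(0) + \int_0^t f(s, y(s)) \, \mathrm{d}x(s)
% ODE: x(s) = s
% SDE: x(s) = w(s), a Brownian motion
% CDE: x(s) = some other path
```

A single solve then just needs to know how to contract `f` against increments of `x` — which is what a "term" packages up.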
8/n
Acknowledgements: Diffrax benefits *a lot* from second mover advantage.
#PyTorch
torchdiffeq, torchsde,
#SciML
#julialang
DifferentialEquations.jl all gave a lot of inspiration for how to design clean abstractions, easy-to-use APIs etc.
9/n
On that note, how does Diffrax compare to torchdiffeq, DE.jl, etc.? Diffrax is much better for advanced use cases:
- solving ODEs/SDEs simultaneously
- adding your own custom ops
- etc.
And as above it's also much faster + has more features than torchdiffeq/etc.
10/n
So to wrap this up! If you want to solve differential equations *really easily* and *really fast*, go check out Diffrax!
GitHub:
Docs:
Install: `pip install diffrax`
11/n
⚡️ My PhD thesis is on arXiv! ⚡️
To quote my examiners, it is "the textbook of neural differential equations" - across ordinary/controlled/stochastic diffeqs.
w/ unpublished material:
- generalised adjoint methods
- symbolic regression
- + more!
🧵 1/n
And if you're way-ahead-of-the-curve and think you've heard about Diffrax before: it's because it was made public a week early as a sneak preview --
-- and got a bunch of traction despite only being pre-release and formally unannounced. :D
13/12
@jodemaey
It'll run on the GPU just fine; no specific requirements. Diffrax is written in pure JAX, so it will run on the GPU in the same way that any JAX code does.
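(For reference, this is plain JAX, nothing Diffrax-specific:)

```python
import jax

# Whatever backend JAX was installed against (CPU/GPU/TPU) is what
# Diffrax will run on -- no extra configuration needed.
print(jax.devices())
```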
@PatrickKidger
@NandoDF
I am really grateful for your contributions, Patrick.
torchsde and your explanations are helping me a lot in my master's, and I'm sure Diffrax will do the same.