@PatrickKidger
Thanks for the explanations! You mention the lack of footguns and the ease of use — where does it stand in terms of performance? Have you had time to benchmark it? Thanks =)
@NoisyFrequency
Since everything ultimately boils down to jax.jit producing an XLA computation graph (regardless of which library you use), the performance is the usual excellent performance you can expect from JAX.
When using this in the context of diffeq solves I've seen ~100x speedup over PyTorch.
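To illustrate the point about everything compiling through jax.jit, here is a minimal sketch (not the actual benchmark from the thread): a toy fixed-step Euler solver for a hypothetical vector field, where the whole loop is traced once and compiled into a single XLA graph via `jax.lax.scan`.

```python
import jax
import jax.numpy as jnp

# Hypothetical vector field dy/dt = f(t, y); any pure function works the same way.
def vector_field(t, y):
    return -y + jnp.sin(t)

@jax.jit  # traced once, then compiled to one XLA computation graph
def euler_solve(y0, ts):
    def step(y, t_pair):
        t0, t1 = t_pair
        # One explicit Euler step of size (t1 - t0).
        y_next = y + (t1 - t0) * vector_field(t0, y)
        return y_next, y_next

    # Pair up consecutive timestamps so scan sees one (t0, t1) per step.
    t_pairs = jnp.stack([ts[:-1], ts[1:]], axis=-1)
    y_final, _ = jax.lax.scan(step, y0, t_pairs)
    return y_final

ts = jnp.linspace(0.0, 1.0, 101)
y_final = euler_solve(jnp.array(1.0), ts)
```

Because the solve is a single compiled program rather than a Python-level loop, there is no per-step interpreter overhead — which is where large speedups over eager frameworks tend to come from.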
@PatrickKidger
I've recently been digging into Flax to put together a blog post. There were parts of the Flax effect system that were quite nice:
* `vmap` with different parameters
* caching for RNNs
* avoiding recomputation of parameters
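The first item above — `vmap` with different parameters — can be sketched in plain JAX (avoiding any Flax-specific API, so names like `init_params` and `apply` here are illustrative, not from either library): an ensemble of small models where each member has its own parameters but all share one input.

```python
import jax
import jax.numpy as jnp

# Hypothetical tiny linear model: params is a (w, b) pair.
def init_params(key):
    kw, kb = jax.random.split(key)
    w = jax.random.normal(kw, (3,))
    b = jax.random.normal(kb, ())
    return w, b

def apply(params, x):
    w, b = params
    return x @ w + b

# One set of parameters per ensemble member, created by vmapping the initialiser.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
ensemble_params = jax.vmap(init_params)(keys)

x = jnp.ones((3,))
# Map over the parameter axis only; broadcast the input (in_axes=(0, None)).
preds = jax.vmap(apply, in_axes=(0, None))(ensemble_params, x)
print(preds.shape)  # (8,)
```

The same pattern works with shared parameters and batched inputs by flipping `in_axes` to `(None, 0)` — the point being that which arguments vary under `vmap` is just an axis annotation.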