The creator of Keras confirmed that the new version comes out in a few days. Keras becomes multi-backend again, with support for PyTorch, TensorFlow, and JAX. Personally, I’m excited to be able to try JAX without having to dive deep into the documentation and the entire ecosystem. What about you?
Libraries like PyTorch and JAX are already high-level libraries in my view. The low-level stuff is C++/CUDA/XLA.
I don’t really see the useful extra abstractions in Keras that would lure me to it.
This.
What about not having to write your own training loop? Keras takes away a lot of boilerplate code, it makes your code more readable and less likely to contain bugs. I would compare it to scikit-learn: Sure, you can implement your own Random Forest, but why bother?
The reality is that you nearly always need to break into that training abstraction, and so it is useless.
As the others said, it’s a pain to reimplement common layers in JAX (specifically). PyTorch is much higher level in its nn API, but personally I despise rewriting the amazing training loop for every implementation. That’s why even JAX users reach for Flax for common layers: why use an error-prone operator like jax.lax.conv_general_dilated or whatever and fill in its 10 arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and avoid debugging the same thing a million times. With the PyTorch backend, model.fit() can suffice at first and be customized later.
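To make the point about the raw operator concrete, here is a minimal sketch of a 2D convolution written directly against jax.lax.conv_general_dilated. The conv2d helper, the NHWC/HWIO layout, and the tensor shapes are my own choices for illustration, not anything from Keras or Flax; a layer class hides exactly this kind of bookkeeping.

```python
import jax
import jax.numpy as jnp


def conv2d(x, kernel, stride=1, padding="SAME"):
    """Hypothetical helper: a plain 2D convolution over the raw lax operator."""
    return jax.lax.conv_general_dilated(
        lhs=x,                           # input batch, laid out as NHWC
        rhs=kernel,                      # filters, laid out as HWIO
        window_strides=(stride, stride),
        padding=padding,
        # Layouts of (input, kernel, output); easy to get wrong by hand.
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jnp.ones((1, 8, 8, 3))        # one 8x8 image with 3 channels
kernel = jnp.ones((3, 3, 3, 16))  # 3x3 window, 3 input channels, 16 filters
y = conv2d(x, kernel)
print(y.shape)  # (1, 8, 8, 16)
```

Even this stripped-down call leaves the dilation, grouping, and precision arguments at their defaults, and you still have to create and track the kernel yourself.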
If you use JAX with Keras you are essentially doing: keras->jax->jaxpr->llvm->cuda/xla, with probably many more intermediate levels.
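If you want to look at some of those levels yourself, JAX exposes them. A rough sketch, assuming a toy dense function I made up for illustration: jax.make_jaxpr shows the jaxpr that tracing produces, and lowering the jitted function shows the compiler IR that gets handed to XLA (the exact text format depends on the JAX version).

```python
import jax
import jax.numpy as jnp


def dense(x, w, b):
    """Toy layer (illustrative only): affine map followed by tanh."""
    return jnp.tanh(x @ w + b)


x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.zeros((4,))

# Level 1: the jaxpr that JAX traces out of the Python function.
jaxpr = jax.make_jaxpr(dense)(x, w, b)
print(jaxpr)

# Level 2: the lowered module passed to the XLA compiler.
hlo_text = jax.jit(dense).lower(x, w, b).as_text()
print(hlo_text)
```

Everything below that (the backend-specific code generation) happens inside XLA and is not shown here.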