• underPanther@alien.topB · 10 months ago

    Libraries like PyTorch and JAX are already high-level libraries in my view. The low-level stuff is C++/CUDA/XLA.

    I don’t really see the useful extra abstractions in Keras that would lure me to it.

    • odd1e@alien.topB · 10 months ago

      What about not having to write your own training loop? Keras takes away a lot of boilerplate code; it makes your code more readable and less likely to contain bugs. I would compare it to scikit-learn: sure, you can implement your own random forest, but why bother?
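      A minimal sketch of the boilerplate this removes (dummy data and an arbitrary architecture, just to keep the example self-contained):

```python
import numpy as np
import keras
from keras import layers

# Dummy data so the example runs end to end.
x_train = np.random.rand(1000, 20).astype("float32")
y_train = np.random.randint(0, 10, size=(1000,))

model = keras.Sequential([
    layers.Dense(64, activation="relu"),
    layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# fit() handles batching, shuffling, metric tracking and validation splits,
# i.e. the loop you would otherwise write by hand.
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.1)
```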

      • TheCloudTamer@alien.topB · 10 months ago

        The reality is that you nearly always need to break into that training abstraction, and so it is useless.
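        For reference, "breaking into" the Keras loop usually means overriding train_step; a rough sketch, assuming the TensorFlow backend:

```python
import tensorflow as tf
import keras

class CustomModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(10)

    def call(self, inputs):
        return self.dense(inputs)

    # Whatever fit() cannot express happens here, once per batch.
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"loss": loss}

model = CustomModel()
model.compile(optimizer="adam", loss="mse")  # fit() now uses the custom step
```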

    • Relevant-Yak-9657@alien.topB · 10 months ago

      As the others said, it’s a pain to reimplement common layers in JAX specifically. PyTorch is much higher level in its nn API, but personally I despise rewriting the same training loop for every implementation. That’s why even JAX users rely on Flax for common layers: why use an error-prone operator like jax.lax.conv_general_dilated and fill in its ten-odd arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and save myself debugging a million times. With Keras on top of PyTorch, model.fit() can quickly suffice and be customized later.
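      To illustrate the contrast (a rough sketch; shapes, padding and dimension numbers are just illustrative):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((1, 28, 28, 1))      # NHWC input batch

# Raw JAX: every detail of the convolution is spelled out by hand.
kernel = jnp.ones((3, 3, 1, 32))  # HWIO kernel
y_raw = jax.lax.conv_general_dilated(
    x, kernel,
    window_strides=(1, 1),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)

# Flax: the layer owns its parameters and the conv bookkeeping.
conv = nn.Conv(features=32, kernel_size=(3, 3), padding="SAME")
params = conv.init(jax.random.PRNGKey(0), x)
y_flax = conv.apply(params, x)
```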

    • abio93@alien.topB · 10 months ago

      If you use JAX with Keras you are essentially doing keras -> jax -> jaxpr -> xla -> llvm/ptx -> cuda, with probably many more intermediate levels.
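      You can inspect a couple of those intermediate stages yourself; a small sketch:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0

x = jnp.ones((3,))

# The jaxpr that JAX traces before handing the program to XLA.
print(jax.make_jaxpr(f)(x))

# The lowered (StableHLO) text that goes into the XLA compiler, which in
# turn emits LLVM IR / PTX for the chosen backend.
lowered = jax.jit(f).lower(x)
print(lowered.as_text())

# The optimized HLO after XLA compilation.
print(lowered.compile().as_text())
```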