• Relevant-Yak-9657@alien.topB
    10 months ago

    As the others said, it’s a pain to reimplement common layers in JAX specifically. PyTorch is much higher level in its nn API, but personally I despise rewriting the amazing training loop for every implementation. That’s why even JAX users reach for Flax for common layers: why use an error-prone operator like jax.lax.conv_general_dilated and fill in its 10 or so arguments every time? I would rather use flax.linen.Conv or keras_core.layers.Conv2D in my Sequential model and avoid debugging a million times. Even with PyTorch as the backend, keras_core’s model.fit() can quickly suffice at first and be customized later.