spekk.transformations.for_all

ForAll transforms a function that works on scalar inputs so that it works on arrays instead (vectorization). It can be used with jax.vmap().
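To make the idea of vectorization concrete, here is a minimal plain-Python sketch: a function written for a single scalar is turned into one that accepts a list of values. The helper `vectorize_by_looping` is hypothetical and not part of spekk. jax.vmap() achieves the same effect without an explicit Python loop.

```python
from typing import Callable, List


def vectorize_by_looping(f: Callable[[float], float]) -> Callable[[List[float]], List[float]]:
    """Hypothetical helper: turn a scalar function into an element-wise one."""

    def vectorized(xs: List[float]) -> List[float]:
        # Apply the scalar function to each element and collect the results.
        return [f(x) for x in xs]

    return vectorized


def square_plus_one(x: float) -> float:
    # Written for a single scalar input only.
    return x * x + 1


vectorized = vectorize_by_looping(square_plus_one)
print(vectorized([1.0, 2.0, 3.0]))  # [2.0, 5.0, 10.0]
```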

Functions

python_vmap(f, in_axes)

A simple Python implementation of JAX's jax.vmap() based on for-loops.
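A for-loop-based vmap can be sketched as follows. This is a minimal sketch that assumes nested Python lists in place of arrays and supports only integer or None entries in in_axes (no pytrees, no out_axes). The name `python_vmap_sketch` is hypothetical; spekk's own python_vmap may differ in details such as how it stacks outputs.

```python
from typing import Callable, Optional, Sequence, Union

# Hypothetical sketch: not spekk's python_vmap, just the same idea.
InAxes = Union[int, None, Sequence[Optional[int]]]


def _axis_size(x, axis: int) -> int:
    # Descend `axis` levels into the nested lists and measure that level.
    for _ in range(axis):
        x = x[0]
    return len(x)


def _take(x, axis: int, i: int):
    # Select index i along `axis` of a nested list (like x[..., i, ...]).
    if axis == 0:
        return x[i]
    return [_take(sub, axis - 1, i) for sub in x]


def python_vmap_sketch(f: Callable, in_axes: InAxes = 0) -> Callable:
    """For-loop stand-in for jax.vmap() on nested Python lists.

    in_axes is either one value for every positional argument or a sequence
    with one entry per positional argument; None means "broadcast as-is".
    Results are collected along a new leading axis, like jax.vmap's default.
    """

    def vmapped(*args):
        if in_axes is None or isinstance(in_axes, int):
            axes = [in_axes] * len(args)
        else:
            axes = list(in_axes)
        if len(axes) != len(args):
            raise ValueError("in_axes needs one entry per positional argument")
        sizes = {_axis_size(a, ax) for a, ax in zip(args, axes) if ax is not None}
        if len(sizes) != 1:
            raise ValueError("need at least one mapped argument, all of equal size")
        (n,) = sizes
        results = []
        for i in range(n):
            # Slice every mapped argument at position i; pass the rest unchanged.
            sliced = [a if ax is None else _take(a, ax, i) for a, ax in zip(args, axes)]
            results.append(f(*sliced))
        return results

    return vmapped


def scaled_sum(row, scale):
    return scale * sum(row)


matrix = [[1, 2], [3, 4], [5, 6]]
# Map over the rows (axis 0) of `matrix` and broadcast `scale`.
print(python_vmap_sketch(scaled_sum, in_axes=(0, None))(matrix, 10))  # [30, 70, 110]
# Map over the columns (axis 1) instead.
print(python_vmap_sketch(scaled_sum, in_axes=(1, None))(matrix, 10))  # [90, 120]
```

Because the loop runs in Python, this trades speed for transparency: each call of `f` sees ordinary values, which can make debugging easier than with a traced jax.vmap.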

specced_vmap(f, spec, dimension[, vmap_impl])

Similar to vmap, but flattens (decomposes) the keyword arguments into a list, a form that vmap supports.
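In jax.vmap(), in_axes controls only positional arguments; keyword arguments are always mapped over their leading axis. Flattening the kwargs into a positional list is one way to give each of them its own axis. The sketch below assumes a spec is a plain dict mapping each keyword argument name to its ordered list of dimension names (spekk's real spec object may look different) and uses nested lists with a loop-based vmap so it runs without JAX. `specced_vmap_sketch` and `loop_vmap` are hypothetical names.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical stand-in for a spec: kwarg name -> ordered dimension names.
Spec = Dict[str, List[str]]


def _size(x, axis: int) -> int:
    # Length of a nested list along `axis`.
    for _ in range(axis):
        x = x[0]
    return len(x)


def _take(x, axis: int, i: int):
    # Select index i along `axis` of a nested list.
    return x[i] if axis == 0 else [_take(sub, axis - 1, i) for sub in x]


def loop_vmap(f: Callable, in_axes: List[Optional[int]]) -> Callable:
    # Minimal positional-only, for-loop vmap (None = broadcast the argument).
    def vmapped(*args):
        n = next((_size(a, ax) for a, ax in zip(args, in_axes) if ax is not None), None)
        if n is None:
            raise ValueError("at least one argument must be mapped")
        return [
            f(*[a if ax is None else _take(a, ax, i) for a, ax in zip(args, in_axes)])
            for i in range(n)
        ]

    return vmapped


def specced_vmap_sketch(
    f: Callable, spec: Spec, dimension: str, vmap_impl: Callable = loop_vmap
) -> Callable:
    """Vectorize a keyword-argument function over `dimension`."""

    def vectorized(**kwargs):
        names = list(kwargs)  # fixes the order of the flattened argument list
        # Map each kwarg along the position of `dimension` in its spec entry,
        # or broadcast it (None) if it does not have that dimension.
        in_axes = [
            spec[k].index(dimension) if dimension in spec.get(k, []) else None
            for k in names
        ]

        def positional_f(*args):
            # Rebuild the keyword arguments before calling the original function.
            return f(**dict(zip(names, args)))

        return vmap_impl(positional_f, in_axes)(*[kwargs[k] for k in names])

    return vectorized


def weighted(signal, weight):
    return [weight * s for s in signal]


def total(signal, gain):
    return gain * sum(signal)


by_channel = specced_vmap_sketch(
    weighted, {"signal": ["channels", "samples"], "weight": ["channels"]}, "channels"
)
print(by_channel(signal=[[1, 2], [3, 4]], weight=[10, 100]))  # [[10, 20], [300, 400]]

by_sample = specced_vmap_sketch(
    total, {"signal": ["channels", "samples"], "gain": []}, "samples"
)
print(by_sample(signal=[[1, 2], [3, 4]], gain=2))  # [8, 12]
```

To use JAX as the backend in this sketch, one could pass `vmap_impl=lambda g, axes: jax.vmap(g, in_axes=tuple(axes))` with JAX arrays as inputs; this adapter is an assumption for illustration and is not spekk's vmap_impl convention.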

Classes

ForAll(dimension, *additional_dimensions[, ...])

Vectorize (i.e., "make looped") a function so that it works on arrays instead of scalars.
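As a rough illustration of how a transformation over several named dimensions could be composed, the sketch below nests one loop per dimension, with the first dimension outermost. It mirrors only the (dimension, *additional_dimensions) part of the signature. The class name `ForAllSketch`, the calling convention (function first, then spec), and the dict-based spec are assumptions rather than spekk's API, and the remaining optional parameters of ForAll are not modeled.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a spec: kwarg name -> ordered dimension names.
Spec = Dict[str, List[str]]


def _size(x, axis: int) -> int:
    # Length of a nested list along `axis`.
    for _ in range(axis):
        x = x[0]
    return len(x)


def _take(x, axis: int, i: int):
    # Select index i along `axis` of a nested list.
    return x[i] if axis == 0 else [_take(sub, axis - 1, i) for sub in x]


class ForAllSketch:
    """Loop a keyword-argument function over one or more named dimensions."""

    def __init__(self, dimension: str, *additional_dimensions: str):
        self.dimensions = (dimension, *additional_dimensions)

    def __call__(self, f: Callable, spec: Spec) -> Callable:
        return self._wrap(f, spec, list(self.dimensions))

    def _wrap(self, f: Callable, spec: Spec, dims: List[str]) -> Callable:
        if not dims:
            return f
        dim, rest = dims[0], dims[1:]
        # The inner function sees the arguments with `dim` already sliced away.
        inner_spec = {k: [d for d in ds if d != dim] for k, ds in spec.items()}
        inner = self._wrap(f, inner_spec, rest)

        def looped(**kwargs):
            axes = {k: spec[k].index(dim) for k in kwargs if dim in spec.get(k, [])}
            if not axes:
                raise ValueError(f"no argument has dimension {dim!r}")
            k0, ax0 = next(iter(axes.items()))
            n = _size(kwargs[k0], ax0)
            # Slice mapped kwargs at index i along `dim`; broadcast the others.
            return [
                inner(**{k: _take(v, axes[k], i) if k in axes else v for k, v in kwargs.items()})
                for i in range(n)
            ]

        return looped


def add(a, b):
    return a + b


spec = {"a": ["rows", "cols"], "b": ["cols"]}
x = [[1, 2], [3, 4]]
print(ForAllSketch("rows", "cols")(add, spec)(a=x, b=[10, 20]))  # [[11, 22], [13, 24]]
print(ForAllSketch("cols", "rows")(add, spec)(a=x, b=[10, 20]))  # [[11, 13], [22, 24]]
```

Swapping the order of the dimensions transposes the result, because in this sketch each dimension becomes one nesting level of the output, in the order given.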