""":class:`ForAll` transforms a function that works on scalar inputs such that it works
on arrays instead (vectorization), and can be used with :func:`jax.vmap`."""
from typing import Callable, Optional, Sequence
from spekk import Spec, trees, util
from spekk.transformations import Transformation, common
# One entry per positional argument: the axis of that argument to map over, or None
# for an argument that is not mapped.
T_in_axes = Sequence[Optional[int]]
# Signature of a vmap implementation (e.g. jax.vmap or python_vmap below): it takes the
# function to vectorize plus its in_axes and returns the vectorized function.
# NOTE: typing.Callable is used here; the builtin `callable` is a predicate function,
# not a type, and is not a meaningful argument to Callable[...].
T_vmap = Callable[[Callable, T_in_axes], Callable]
[docs]
class ForAll(Transformation):
    """Loop ("vectorize") a function over one or more named dimensions, so that a
    function written for scalar inputs works on arrays instead.

    The dimensions refer to names used in the :class:`Spec` the transformation is
    built with:

    >>> from spekk import Spec, util
    >>> from spekk.transformations import ForAll, compose
    >>> f = lambda x, y: x + y
    >>> data = {"x": range(2), "y": range(3)}
    >>> spec = Spec({"x": ["dim1"], "y": ["dim2"]})

    Transform f to run on a grid defined by the "dim1" and "dim2" dimensions:

    >>> tf = compose(f, ForAll("dim2"), ForAll("dim1")).build(spec)
    >>> tf.output_spec
    Spec(['dim1', 'dim2'])
    >>> result = tf(**data)
    >>> util.shape(result)
    (2, 3)
    >>> result
    [[0, 1, 2], [1, 2, 3]]

    A single ``ForAll`` can also loop over several dimensions at once. Note that the
    dimensions are listed in the opposite order: ``ForAll("dim2"), ForAll("dim1")``
    corresponds to ``ForAll("dim1", "dim2")``:

    >>> tf = compose(f, ForAll("dim1", "dim2")).build(spec)
    >>> tf(**data)  # This results in the same as in the previous example
    [[0, 1, 2], [1, 2, 3]]

    Args:
        dimension: Name of the (first) dimension to loop over.
        *additional_dimensions: Further dimension names, stored after ``dimension``
            in ``self.dimensions``.
        vmap_impl: Optional vmap implementation (e.g. :func:`jax.vmap`), stored as
            ``self.vmap_impl``. ``None`` by default; presumably a pure-Python
            for-loop fallback is used then (TODO confirm against the base class).
    """

    def __init__(
        self,
        dimension: str,
        *additional_dimensions: str,
        vmap_impl: Optional[T_vmap] = None,
    ):
        self.dimensions = [dimension] + list(additional_dimensions)
        self.vmap_impl = vmap_impl
        # Hook for subclasses; called last so it can rely on the attributes above.
        self.__post_init__()

    def __post_init__(self):
        """Sub-classes may override this method to perform additional initialization."""

    def __repr__(self) -> str:
        return "ForAll(" + ", ".join(map(repr, self.dimensions)) + ")"
[docs]
def specced_vmap(
    f: callable,
    spec: Spec,
    dimension: str,
    vmap_impl: Optional[T_vmap] = None,
):
    """Like ``vmap``, but for functions that take keyword arguments described by a spec.

    vmap implementations work on positional arguments plus a matching ``in_axes``
    sequence. The returned function therefore calls ``util.flatten(kwargs, spec,
    dimension)``. That call yields positional arguments, their ``in_axes`` and an
    ``unflatten`` callback. The returned function then vectorizes a helper that
    rebuilds the keyword arguments with that callback before calling ``f``.

    Args:
        f: The function to vectorize; it is only ever called with keyword arguments.
        spec: Spec describing the keyword arguments.
        dimension: Name of the dimension to map over.
        vmap_impl: vmap implementation to use (e.g. :func:`jax.vmap`). When ``None``,
            :func:`python_vmap` is used.

    Returns:
        A function that accepts keyword arguments only. Calling it with any
        positional argument raises ``ValueError``.
    """
    vmap = python_vmap if vmap_impl is None else vmap_impl

    def wrapped(*_unsupported_positional_args, **kwargs):
        if len(_unsupported_positional_args) > 0:
            raise ValueError(
                "Positional arguments are not supported in specced_vmap. Use keyword arguments instead."
            )

        positional, axes, rebuild_kwargs = util.flatten(kwargs, spec, dimension)

        def call_f_with_kwargs(*flat_args):
            # Undo the flattening so that f receives its original keyword arguments.
            return f(**rebuild_kwargs(flat_args))

        return vmap(call_f_with_kwargs, axes)(*positional)

    return wrapped
[docs]
def python_vmap(f, in_axes):
    """A simple Python implementation of JAX's :func:`jax.vmap` based on for-loops.

    Args:
        f: The function to vectorize. It is called once per index along the mapped
            axis, with the arguments produced by ``common.get_args_for_index``.
        in_axes: One entry per positional argument of the returned function: the
            axis of that argument to loop over, or ``None`` for an argument that is
            not looped over.

    Returns:
        A function taking the positional arguments described by ``in_axes``. It calls
        ``f`` once per index and combines the results into one object with the same
        structure as the first result. Each leaf (a list, or an object without a
        tree structure) becomes the list of that leaf's values across all indices. If
        the mapped axis has length zero, ``f`` is never called and an empty list is
        returned, because the structure of a single result is unknown.

    Raises:
        ValueError: If no argument is mapped (every ``in_axes`` entry is ``None``,
            or there are no arguments), or if the mapped arguments have different
            sizes along their mapped axes.
    """

    def wrapped(*args):
        # Length of every mapped argument along its mapped axis; arguments whose
        # in_axes entry is None are not measured.
        sizes = [util.shape(arg)[a] for arg, a in zip(args, in_axes) if a is not None]
        if not sizes:
            # There is no mapped argument, hence no loop length to iterate over
            # (previously this surfaced as an opaque IndexError).
            raise ValueError(
                "python_vmap needs at least one argument with a non-None in_axes "
                f"entry, got: {in_axes=}"
            )
        size = sizes[0]
        if any(s != size for s in sizes):
            # Implicit string concatenation (instead of a backslash continuation
            # inside the literal) keeps source indentation out of the message.
            raise ValueError(
                "Cannot apply python_vmap to arguments with different sizes over the "
                f"in_axes: {sizes=}, {in_axes=}"
            )
        if size == 0:
            # A for-loop over zero items produces nothing to combine.
            return []

        # The result for each item in the dimension.
        all_results = [
            f(*common.get_args_for_index(args, in_axes, i)) for i in range(size)
        ]
        result0 = all_results[0]

        # Combine the results such that the returned object has the same structure
        # as each individual result, with every leaf replaced by the list of values.
        combined_result = result0
        for leaf in trees.leaves(
            result0, lambda x: isinstance(x, list) or not trees.has_treedef(x)
        ):
            values = [trees.get(_result, leaf.path) for _result in all_results]
            combined_result = trees.set(combined_result, values, leaf.path)
        return combined_result

    return wrapped
if __name__ == "__main__":
import doctest
doctest.testmod()