spekk.transformations.wrap.Wrap

class spekk.transformations.wrap.Wrap(f: callable, *args, **kwargs)

Bases: Transformation

Wraps a function with another function. Because Wrap is itself a Transformation, it keeps information about the spec available throughout a chain of Transformations.

f

A wrapper function, for example jax.jit().

args

Optional extra positional arguments to pass to f.

kwargs

Optional extra keyword arguments to pass to f.
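The three attributes suggest how the wrapper is applied: `f` receives the function being wrapped, followed by `args` and `kwargs`. The sketch below is a hypothetical pure-Python stand-in (`WrapSketch` and `scale_output` are illustrative names, not spekk's API) that assumes exactly this calling convention, `f(to_be_wrapped, *args, **kwargs)`, and checks it against calling the wrapper directly.

```python
from typing import Any, Callable


class WrapSketch:
    """Hypothetical stand-in for Wrap: stores f plus extra arguments for f."""

    def __init__(self, f: Callable, *args: Any, **kwargs: Any):
        self.f = f
        self.args = args
        self.kwargs = kwargs

    def __call__(self, to_be_wrapped: Callable) -> Callable:
        # Assumed calling convention: f(to_be_wrapped, *args, **kwargs).
        return self.f(to_be_wrapped, *self.args, **self.kwargs)


def scale_output(fn: Callable, factor: float, offset: float = 0.0) -> Callable:
    """Hypothetical wrapper that takes an extra positional and keyword argument."""
    def wrapped(*a, **k):
        return factor * fn(*a, **k) + offset
    return wrapped


square = lambda x: x**2
via_wrap = WrapSketch(scale_output, 10, offset=1)(square)
direct = scale_output(square, 10, offset=1)
print(via_wrap(3), direct(3))  # 91 91
```

With a real wrapper such as `jax.jit`, the extra arguments would be that wrapper's own options; which options are meaningful depends entirely on `f`.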

Example

>>> import jax
>>> from spekk.transformations.wrap import Wrap
>>> my_fn = lambda x: x**2
>>> wrapped_fn1 = Wrap(jax.jit)(my_fn)
>>> wrapped_fn2 = jax.jit(my_fn)

wrapped_fn1 and wrapped_fn2 are equivalent, but wrapped_fn1 also propagates information about the spec (if applicable) to nested Transformations.
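To illustrate what propagating the spec to nested Transformations can mean, here is a hypothetical mini-framework, not spekk's API. A spec is modelled as a tuple of dimension names; `VmapStep` removes one dimension on the way in, and `WrapStep` passes the spec through unchanged (an assumption about Wrap's behaviour), so a step nested inside the wrapper still sees the reduced spec. `counting_decorator` stands in for a wrapper such as `jax.jit`.

```python
from typing import Callable, List, Tuple

Spec = Tuple[str, ...]  # hypothetical: dimension names of the single input argument


class Step:
    """Hypothetical base transformation: by default the spec passes through unchanged."""

    def transform_input_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_function(self, fn: Callable, input_spec: Spec) -> Callable:
        return fn


class WrapStep(Step):
    """Wraps fn with f and (by assumption) leaves the spec untouched."""

    def __init__(self, f: Callable):
        self.f = f

    def transform_function(self, fn: Callable, input_spec: Spec) -> Callable:
        return self.f(fn)


class VmapStep(Step):
    """Loops fn over the leading dimension `dim` of a nested list."""

    def __init__(self, dim: str):
        self.dim = dim

    def transform_input_spec(self, spec: Spec) -> Spec:
        if spec[0] != self.dim:
            raise ValueError("this sketch only maps over the leading dimension")
        return spec[1:]

    def transform_function(self, fn: Callable, input_spec: Spec) -> Callable:
        return lambda xs: [fn(x) for x in xs]


class RecordSpecStep(Step):
    """Records the spec it receives, to show what reaches nested steps."""

    def __init__(self):
        self.seen: List[Spec] = []

    def transform_function(self, fn: Callable, input_spec: Spec) -> Callable:
        self.seen.append(input_spec)
        return fn


def apply_chain(fn: Callable, chain: List[Step], spec: Spec) -> Callable:
    """Apply `chain` (outermost first) to fn, giving each step the spec it sees."""
    specs = [spec]
    for step in chain:
        specs.append(step.transform_input_spec(specs[-1]))
    # Build from the innermost step outwards.
    for step, step_spec in reversed(list(zip(chain, specs))):
        fn = step.transform_function(fn, step_spec)
    return fn


calls = []


def counting_decorator(fn: Callable) -> Callable:
    """Stand-in for a wrapper such as jax.jit: counts calls, returns fn's result."""
    def wrapped(*args):
        calls.append(args)
        return fn(*args)
    return wrapped


recorder = RecordSpecStep()
chain = [VmapStep("batch"), WrapStep(counting_decorator), recorder]
row_sum = apply_chain(lambda row: sum(row), chain, ("batch", "x"))

print(row_sum([[1, 2], [3, 4]]))  # [3, 7]
print(recorder.seen)              # [('x',)]
print(len(calls))                 # 2
```

Had the wrapper been applied outside the chain as a plain function, `RecordSpecStep` would have had no way to receive the reduced spec from `VmapStep`; keeping the wrapper as a step is what lets the spec flow through it.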

__init__(f: callable, *args, **kwargs)

Methods

__init__(f, *args, **kwargs)

transform_function(to_be_wrapped, ...)

Transform the wrapped function given the spec of the input arguments and the spec of the returned value of the wrapped function.

transform_input_spec(spec)

Return a new spec that represents the input arguments passed down to the wrapped function after the transformation has been applied.

transform_output_spec(spec)

Return a new spec that represents the returned value of the final transformed function.

transform_function(to_be_wrapped: callable, input_spec: Spec, output_spec: Spec) → callable

Transform the wrapped function given the spec of the input arguments and the spec of the returned value of the wrapped function.
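For Wrap, one plausible reading, which is an assumption and not confirmed by this page, is that `transform_function` does not need the specs and simply returns `f(to_be_wrapped, *args, **kwargs)`. The sketch below encodes that reading as a hypothetical free function `wrap_transform_function`, using the standard-library `functools.partial` as the wrapper `f` so that the forwarding of `args` and `kwargs` is visible.

```python
import functools
from typing import Any, Callable, Dict, Tuple


def wrap_transform_function(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    to_be_wrapped: Callable,
    input_spec: Any,
    output_spec: Any,
) -> Callable:
    """Hypothetical sketch of Wrap.transform_function.

    Assumes wrapping with f changes neither the dimensions of the arguments
    nor those of the return value, so both specs go unused.
    """
    del input_spec, output_spec  # unused under this assumption
    return f(to_be_wrapped, *args, **kwargs)


def power(base, exponent):
    return base**exponent


# Extra positional argument: functools.partial(power, 2) binds base=2.
two_to_the = wrap_transform_function(
    functools.partial, (2,), {}, power, input_spec=None, output_spec=None
)
# Extra keyword argument: functools.partial(power, exponent=3).
cube = wrap_transform_function(
    functools.partial, (), {"exponent": 3}, power, input_spec=None, output_spec=None
)
print(two_to_the(10), cube(4))  # 1024 64
```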

transform_input_spec(spec: Spec) → Spec

Return a new spec that represents the input arguments passed down to the wrapped function after the transformation has been applied.

For example, if the transformation vectorizes the wrapped function over a dimension, the wrapped function only sees one item along that dimension at a time. Therefore, the input spec passed down to the wrapped function has one fewer dimension.
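The vectorization example can be made concrete with a hypothetical spec format: a dict mapping each argument name to a tuple of dimension names (spekk's own `Spec` class is not used here, and `vectorized_input_spec` is an illustrative name). Vectorizing over a dimension removes it from every argument that has it, and the per-item call below shows the wrapped function receiving exactly that reduced shape.

```python
from typing import Dict, Tuple

SpecDict = Dict[str, Tuple[str, ...]]  # hypothetical: argument name -> dimension names


def vectorized_input_spec(spec: SpecDict, dim: str) -> SpecDict:
    """Spec seen by the wrapped function when vectorizing over `dim`.

    Arguments that have `dim` lose it; the others are unchanged because they
    are passed whole to every call.
    """
    return {name: tuple(d for d in dims if d != dim) for name, dims in spec.items()}


outer_spec = {"signals": ("batch", "time"), "weights": ("time",)}
inner_spec = vectorized_input_spec(outer_spec, "batch")
print(inner_spec)  # {'signals': ('time',), 'weights': ('time',)}


def dot(signal, weights):
    # Called once per "batch" item: signal is a 1-D "time" list, matching inner_spec.
    return sum(s * w for s, w in zip(signal, weights))


signals = [[1, 2, 3], [4, 5, 6]]  # dims ("batch", "time")
weights = [1, 0, 1]               # dims ("time",)
print([dot(signal, weights) for signal in signals])  # [4, 10]
```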

transform_output_spec(spec: Spec) → Spec

Return a new spec that represents the returned value of the final transformed function.

For example, if the transformation sums the result of the wrapped function over a dimension, the output spec has one fewer dimension.
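The summing example works the same way on the output side. In this hypothetical sketch a spec is a tuple of dimension names, `summed_output_spec` drops the summed dimension, and `sum_over_leading` (an illustrative name, not spekk's API) performs the matching sum on a nested-list result so the two can be compared.

```python
from typing import Callable, List, Tuple

Spec = Tuple[str, ...]  # hypothetical: dimension names of the returned value


def summed_output_spec(spec: Spec, dim: str) -> Spec:
    """Output spec after summing the wrapped function's result over `dim`."""
    if dim not in spec:
        raise ValueError(f"cannot sum over missing dimension {dim!r}")
    return tuple(d for d in spec if d != dim)


def sum_over_leading(fn: Callable[..., List[List[float]]]) -> Callable[..., List[float]]:
    """Sum fn's 2-D nested-list result over its leading dimension."""
    def wrapped(*args):
        rows = fn(*args)
        return [sum(column) for column in zip(*rows)]
    return wrapped


def outer(a, b):
    # Result has dimensions ("a", "b").
    return [[x * y for y in b] for x in a]


print(summed_output_spec(("a", "b"), "a"))            # ('b',)
print(sum_over_leading(outer)([1, 2], [10, 20, 30]))  # [30, 60, 90]
```

The result has a single `"b"` dimension of length 3, which is what the reduced output spec `('b',)` describes.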