spekk.transformations.wrap.Wrap
- class spekk.transformations.wrap.Wrap(f: callable, *args, **kwargs)
Bases: Transformation
Simply wraps a function with another function, but is useful for keeping information about the spec in a chain of Transformation objects.
- f
A wrapper function, for example jax.jit().
- args
Optional extra positional arguments to pass to f.
- kwargs
Optional extra keyword arguments to pass to f.
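To make the roles of f, args and kwargs concrete, here is a minimal pure-Python sketch of the forwarding behaviour. It assumes that Wrap calls f(to_be_wrapped, *args, **kwargs), which is how the descriptions of args and kwargs read. MiniWrap and repeat_call are hypothetical names used only for illustration (they are not part of spekk), and the spec handling of the real class is left out.

```python
from typing import Any, Callable


class MiniWrap:
    """Hypothetical stand-in for Wrap: stores a wrapper and its extra arguments."""

    def __init__(self, f: Callable, *args: Any, **kwargs: Any):
        self.f = f            # the wrapper function, e.g. jax.jit
        self.args = args      # extra positional arguments for f
        self.kwargs = kwargs  # extra keyword arguments for f

    def __call__(self, to_be_wrapped: Callable) -> Callable:
        # Assumption: the function to wrap is passed first, followed by the extras.
        return self.f(to_be_wrapped, *self.args, **self.kwargs)


def repeat_call(fn: Callable, times: int, offset: int = 0) -> Callable:
    """Hypothetical wrapper: applies fn `times` times, then adds `offset`."""

    def wrapped(x):
        for _ in range(times):
            x = fn(x)
        return x + offset

    return wrapped


square = lambda x: x**2
wrapped = MiniWrap(repeat_call, 2, offset=1)(square)
print(wrapped(3))  # (3**2)**2 + 1 = 82
```

Here MiniWrap(repeat_call, 2, offset=1)(square) behaves exactly like repeat_call(square, 2, offset=1); the object form only becomes useful once the wrapper has to live inside a chain alongside other transformations.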
Example
>>> import jax
>>> from spekk.transformations.wrap import Wrap
>>> my_fn = lambda x: x**2
>>> wrapped_fn1 = Wrap(jax.jit)(my_fn)
>>> wrapped_fn2 = jax.jit(my_fn)
wrapped_fn1 and wrapped_fn2 are equivalent, but wrapped_fn1 will propagate information about the spec (if applicable) to nested Transformation objects.
Methods
- __init__(f, *args, **kwargs)
- transform_function(to_be_wrapped, ...): Transform the wrapped function given the spec of the input arguments and the spec of the returned value of the wrapped function.
- transform_input_spec(spec): Return a new spec that represents the input arguments that are passed down to the wrapped function after the transformation has been applied.
- transform_output_spec(spec): Return a new spec that represents the returned value of the final transformed function.
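The three methods fit together when Wrap is one link in a chain of transformations, which is what lets wrapped_fn1 in the example pass spec information on to nested transformations while the plain jax.jit call cannot. The sketch below is a guess at how such a chain might be evaluated, not spekk's actual implementation: a spec is modelled as a tuple of dimension names, the input spec is passed down from the outermost transformation, and each transform_function is then applied from the innermost outwards while the output spec is passed back up. WrapSketch, RecordSpec and apply_chain are hypothetical, the identity spec methods of WrapSketch are an assumption about Wrap, and functools.lru_cache stands in for jax.jit so that the sketch runs without JAX.

```python
import functools
from typing import Callable, List, Optional, Tuple

Spec = Tuple[str, ...]  # hypothetical: a spec is just a tuple of dimension names


class WrapSketch:
    """Hypothetical Wrap-like transformation: wraps the function, keeps specs."""

    def __init__(self, f: Callable, *args, **kwargs):
        self.f, self.args, self.kwargs = f, args, kwargs

    def transform_input_spec(self, spec: Spec) -> Spec:
        return spec  # assumption: wrapping does not change the inputs

    def transform_output_spec(self, spec: Spec) -> Spec:
        return spec  # assumption: wrapping does not change the output

    def transform_function(self, to_be_wrapped: Callable, input_spec: Spec, output_spec: Spec) -> Callable:
        return self.f(to_be_wrapped, *self.args, **self.kwargs)


class RecordSpec:
    """Hypothetical nested transformation that only records the spec it receives."""

    def __init__(self):
        self.seen_input_spec: Optional[Spec] = None

    def transform_input_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_output_spec(self, spec: Spec) -> Spec:
        return spec

    def transform_function(self, to_be_wrapped: Callable, input_spec: Spec, output_spec: Spec) -> Callable:
        self.seen_input_spec = input_spec
        return to_be_wrapped


def apply_chain(fn: Callable, chain: List, input_spec: Spec, output_spec: Spec):
    """Hypothetical chain runner; `chain` lists transformations outermost first."""
    # Pass the input spec down the chain: specs[i] is what chain[i] receives.
    specs = [input_spec]
    for t in chain:
        specs.append(t.transform_input_spec(specs[-1]))
    # Wrap from the innermost transformation outwards, passing the output spec up.
    for t, spec in zip(reversed(chain), reversed(specs[:-1])):
        fn = t.transform_function(fn, spec, output_spec)
        output_spec = t.transform_output_spec(output_spec)
    return fn, output_spec


# functools.lru_cache stands in for jax.jit so that the sketch runs without JAX.
recorder = RecordSpec()
chain = [WrapSketch(functools.lru_cache), recorder]
fast_square, out_spec = apply_chain(lambda x: x**2, chain, ("batch",), ("batch",))
print(fast_square(4), recorder.seen_input_spec, out_spec)  # 16 ('batch',) ('batch',)
```

Because WrapSketch sits inside the chain, the nested RecordSpec still receives ("batch",); calling the wrapper directly on the function, as wrapped_fn2 does, would bypass the chain and no spec would reach nested transformations.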
- transform_function(to_be_wrapped: callable, input_spec: Spec, output_spec: Spec) → callable
Transform the wrapped function given the spec of the input arguments and the spec of the returned value of the wrapped function.
- transform_input_spec(spec: Spec) → Spec
Return a new spec that represents the input arguments that are passed down to the wrapped function after the transformation has been applied.
For example, if the transformation vectorizes the wrapped function over a dimension, the wrapped function only sees single items of that dimension at a time. Therefore, the input spec passed down to the wrapped function has one fewer dimension.
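The vectorization case can be sketched in plain Python. Here a spec is again a hypothetical tuple of dimension names, and LoopOver is a hypothetical transformation (not part of spekk) that maps the wrapped function over one named dimension with a Python loop, roughly what jax.vmap would do with real arrays. Its transform_input_spec drops that dimension, because the wrapped function only ever sees one item at a time, and its transform_output_spec adds the dimension back to the front of the result.

```python
from typing import Callable, Tuple

Spec = Tuple[str, ...]  # hypothetical: a spec is a tuple of dimension names


class LoopOver:
    """Hypothetical transformation that maps a function over one named dimension."""

    def __init__(self, dimension: str):
        self.dimension = dimension

    def transform_input_spec(self, spec: Spec) -> Spec:
        # The wrapped function only sees single items, so the dimension is dropped.
        return tuple(d for d in spec if d != self.dimension)

    def transform_output_spec(self, spec: Spec) -> Spec:
        # The per-item results are collected along a new leading dimension.
        return (self.dimension, *spec)

    def transform_function(self, to_be_wrapped: Callable, input_spec: Spec, output_spec: Spec) -> Callable:
        # Simplification: the looped dimension must be the outermost list axis.
        if input_spec[0] != self.dimension:
            raise ValueError(f"expected {self.dimension!r} to be the first dimension")

        def looped(xs):
            return [to_be_wrapped(x) for x in xs]

        return looped


t = LoopOver("frames")
in_spec = ("frames", "samples")
inner_spec = t.transform_input_spec(in_spec)  # what the wrapped function sees
f = t.transform_function(sum, in_spec, ())    # sum returns a scalar: spec ()
print(inner_spec, t.transform_output_spec(()), f([[1, 2], [3, 4]]))
# ('samples',) ('frames',) [3, 7]
```

The wrapped sum only ever receives a single row described by ("samples",), while the final function returns one value per frame, described by ("frames",).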