# Source code for spekk.transformations.apply
":class:`Apply` applies function ``f`` to the output of the wrapped function."
from typing import Callable
import spekk.transformations.common as common
from spekk import Spec, trees
from spekk.transformations.axis import Axis, concretize_axes
from spekk.transformations.base import Transformation
# [docs]
class Apply(Transformation):
    """Transformation that runs ``f`` on the output of the wrapped function.

    Attributes:
        f: The function to apply to the result of the wrapped function.
        args: Optional extra positional arguments to pass to ``f``.
        kwargs: Optional extra keyword arguments to pass to ``f``.
    """

    def __init__(self, f: Callable, *args, **kwargs):
        """Store ``f`` together with the extra arguments given for it.

        Args:
            f: The function to apply to the result of the wrapped function.
            *args: Extra positional arguments, kept as a tuple on ``self.args``.
            **kwargs: Extra keyword arguments, kept as a dict on ``self.kwargs``.
        """
        self.f, self.args, self.kwargs = f, args, kwargs

    def with_extra_output_spec_transform(self, t: Callable[[Spec], Spec]):
        """Return a new ``Apply`` whose ``extra_output_spec_transform`` is ``t``.

        The new object is built from this instance's ``f``, ``args`` and
        ``kwargs``; this instance itself is not modified.

        Args:
            t: Callable taking a ``Spec`` and returning a ``Spec``.

        Returns:
            The newly created ``Apply`` instance.
        """
        clone = Apply(self.f, *self.args, **self.kwargs)
        clone.extra_output_spec_transform = t
        return clone

    def __repr__(self) -> str:
        """Return ``Apply(<fn name>, <args>, <kwargs>)``, truncated if long."""
        marker = "… <truncated>"
        limit = 140

        # Render the arguments first, then the function name, and finally
        # glue everything together with ", " in call-signature order.
        pieces = [str(value) for value in self.args]
        pieces.extend(f"{key}={value!s}" for key, value in self.kwargs.items())
        text = ", ".join([f"Apply({common.get_fn_name(self.f)}", *pieces])

        # Make sure the repr string is not too long. The closing parenthesis
        # is appended after truncation, so it is always present.
        if len(text) > limit:
            text = text[: limit - len(marker)] + marker
        return f"{text})"