# Source code for spekk.transformations.common

"Some common utility functions used by :mod:`spekk.transformations`."

from typing import Any, Sequence, Union

import numpy as np

from spekk import trees
from spekk.transformations import common


def compose(x, *wrapping_functions):
    """Thread ``x`` through each of ``wrapping_functions``, left to right.

    With no wrapping functions, ``x`` is returned unchanged.

    Given a few functions:

    >>> f = lambda x: x+1
    >>> g = lambda x: x*2
    >>> h = lambda x: x**2

    :func:`compose` applies them one after another, in the order given:

    >>> compose(1, f, g, h)  # ((1 + 1) * 2) ** 2 = 16
    16

    which is equivalent to the nested call:

    >>> h(g(f(1)))  # ((1 + 1) * 2) ** 2 = 16
    16

    With many nested calls, :func:`compose` may read more easily. Note that
    :func:`compose` evaluates the functions in the order they are passed
    (left-to-right), whereas nested calls are evaluated inside-out
    (right-to-left).

    :func:`compose` also works for building a function out of smaller function
    transformations:

    >>> wrap_f_double = lambda f: (lambda x: 2*f(x))
    >>> wrap_f_square = lambda f: (lambda x: f(x)**2)
    >>> f = compose(
    ...     lambda x: x+1,
    ...     wrap_f_double,
    ...     wrap_f_square,
    ... )
    >>> f(1)  # ((1 + 1) * 2) ** 2 = 16
    16
    """
    result = x
    for fn in wrapping_functions:
        # Each function receives the output of the previous one.
        result = fn(result)
    return result
def identity(x):
    """Give back ``x`` exactly as it was passed in."""
    return x
def get_fn_name(f) -> str:
    """Return a printable name for ``f``.

    The first available of ``f.__qualname__`` and ``f.__name__`` is returned;
    if ``f`` has neither attribute, ``repr(f)`` is used instead.
    """
    # Ordered by preference: the qualified name wins over the plain name.
    for attribute in ("__qualname__", "__name__"):
        if hasattr(f, attribute):
            return getattr(f, attribute)
    return repr(f)
def getitem_along_axis(x, axis: int, i: int):
    """Return the item at index ``i`` along ``axis`` of ``x``.

    For a non-negative ``axis`` this is ``x[:, ..., :, i]`` with ``axis`` leading
    full slices. A negative ``axis`` counts from the last axis, so ``axis=-1``
    is ``x[..., i]`` and ``axis=-2`` is ``x[..., i, :]``.

    If ``x`` rejects the tuple index with a ``TypeError`` (as plain Python
    sequences do), the lookup is retried on ``np.array(x)``.

    Args:
        x: The object to index.
        axis: The axis to index along; may be negative.
        i: The index to take along ``axis``.

    Returns:
        Whatever ``x`` (or ``np.array(x)`` in the fallback case) returns for the
        constructed index.

    Raises:
        ValueError: If the fallback lookup on ``np.array(x)`` also fails.
    """
    if axis >= 0:
        slice_ = (slice(None),) * axis + (i,)
    else:
        # Anchor on the trailing axes. Multiplying a list by a negative number
        # gives an empty list, which would silently index axis 0 instead.
        slice_ = (Ellipsis, i) + (slice(None),) * (-axis - 1)

    try:
        return x.__getitem__(slice_)
    except TypeError:
        # x does not accept a tuple index; try again through a NumPy array.
        try:
            return np.array(x).__getitem__(slice_)
        except Exception as err:
            raise ValueError(
                f"Cannot get item at index {i} along axis {axis} for {x!r}"
            ) from err
def get_args_for_index(
    args: Sequence, in_axes: Sequence[Union[int, None]], i: int
) -> Sequence:
    """Build the argument list for index ``i``.

    ``args`` and ``in_axes`` are paired up positionally. An argument whose axis
    is ``None`` is passed through untouched; any other argument is replaced by
    its item at index ``i`` along the paired axis.

    Returns:
        A new list with one entry per ``(arg, axis)`` pair.
    """
    selected = []
    for value, axis in zip(args, in_axes):
        if axis is None:
            selected.append(value)
        else:
            selected.append(common.getitem_along_axis(value, axis, i))
    return selected
def map_1_flattened(
    map_f: callable,
    flattened_args: Sequence[Any],
    in_axes: Sequence[Union[int, None]],
    unflatten: callable,
    i: int,
):
    """Call ``map_f`` on the arguments selected for index ``i``.

    ``flattened_args`` and ``in_axes`` are paired up positionally. Arguments
    whose axis is ``None`` are used as they are. The others are passed through
    :func:`trees.update_leaves`, which presumably rewrites each node that has no
    treedef (confirm against :mod:`spekk.trees`): nodes that support
    ``__getitem__`` are indexed at ``i`` along the paired axis, the rest are
    left alone.

    The resulting list is handed to ``unflatten``, and what it returns is
    unpacked as keyword arguments to ``map_f``.

    Returns:
        The return value of ``map_f``.
    """

    def _is_leaf(node):
        return not trees.has_treedef(node)

    indexed_args = []
    for value, axis in zip(flattened_args, in_axes):
        if axis is None:
            # Nothing to index; forward the argument unchanged.
            indexed_args.append(value)
            continue

        def _take(leaf):
            if not hasattr(leaf, "__getitem__"):
                return leaf
            return common.getitem_along_axis(leaf, axis, i)

        indexed_args.append(trees.update_leaves(value, _is_leaf, _take))

    return map_f(**unflatten(indexed_args))
if __name__ == "__main__":
    # When executed as a script, run the doctest examples found in this
    # module's docstrings.
    import doctest

    doctest.testmod()