Source code for drytools.decorator


'''
==================================
decorator - Some useful decorators
==================================

'''
from collections import ChainMap
from functools import wraps
from operator import eq, ne, gt, lt, ge, le
import inspect

from drytools.annotation.composition import compose_annotations
from drytools.annotation.functions import check, iterify
from drytools.decorator_factory import decorator_factory

[docs]@decorator_factory @compose_annotations def args2attrs(restrict_to:(iterify, set)=(), exclude:(iterify, set)=(), expand_kw=True): ''' Decorator to copy method arguments to instance attributes that have the same names (eg: in __init__) Args: restrict_to (:class:`str` or iterable): if specified, only include these named arguments exclude (:class:`str` or iterable): names of arguments to exclude from copying (even if they're in *include*) expand_kw (bool): make an individual attribute for each variable keyword argument Returns: func: decorator Example: >>> class my_cls: ... @args2attrs ... def __init__(self, a, b): ... self.total = self.a + self.b >>> inst = my_cls(5,2) >>> inst.a, inst.b, inst.total (5, 2, 7) ''' class to_replace_with_empty_dict: pass def decorator(fun): sig = inspect.signature(fun) params_to_copy = set(list(sig.parameters)[1:]) - exclude if restrict_to: params_to_copy &= restrict_to if not params_to_copy: raise ValueError('No eligible parameters') def name_and_adjusted_default(param): default = param.default if default is inspect._empty: if param.kind is inspect._VAR_POSITIONAL: default = () elif param.kind is inspect._VAR_KEYWORD: default = to_replace_with_empty_dict return param.name, default defaults = {k:v for k, v in map(name_and_adjusted_default, sig.parameters.values()) if (k in params_to_copy) and (v is not inspect._empty)} if expand_kw: var_kw_params = {p.name for p in sig.parameters.values() if p.kind is inspect._VAR_KEYWORD} # size 1 or 0 expand_param = lambda param_name: param_name in var_kw_params else: expand_param = lambda param_name: False @wraps(fun) def wrapped(self, *args, **kwargs): for name, value in ChainMap(sig.bind(self, *args, **kwargs).arguments, defaults).items(): if name in params_to_copy: if value is to_replace_with_empty_dict: value = {} if expand_param(name): for k, v in value.items(): setattr(self, k, v) else: setattr(self, name, value) return fun(self, *args, **kwargs) return wrapped return decorator
[docs]@compose_annotations def ordered_by(*attrs: check(isinstance, str, raises=TypeError)): ''' Class decorator factory for adding comparison methods based on one or more attributes Args: attrs (str): Name(s) of attribute(s) to use for ordering instances Returns: func: Function to add comparison methods to the class Example: >>> @ordered_by('name') ... class my_cls: ... def __init__(self, name): ... self.name = name ... def __repr__(self): ... return '{}({})'.format(type(self).__name__, repr(self.name)) >>> sorted([my_cls('foo'), my_cls('bar'), my_cls('bax')]) [my_cls('bar'), my_cls('bax'), my_cls('foo')] ''' if not attrs: raise TypeError('No attrs') def comp_val(instance): return tuple(getattr(instance, attr) for attr in attrs) def decorator(cls): def add_comparison_method(comparison): method_name = '__{comparison.__name__}__'.format(**locals()) fun = lambda self, other: comparison(comp_val(self), comp_val(other)) setattr(cls, method_name, fun) for comparison in [eq, ne, gt, lt, ge, le]: add_comparison_method(comparison) return cls return decorator
if __name__ == '__main__': import doctest doctest.testmod()