# Source code for nutils.util

# Copyright (c) 2014 Evalf
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

"""
The util module provides a collection of general purpose methods. Most
importantly it provides the :func:`run` method which is the preferred entry
point of a nutils application, taking care of command line parsing, output dir
creation and initiation of a log file.
"""

from . import numeric, config
import sys, os, numpy, collections.abc, inspect, functools, operator, numbers, pathlib

# True when this platform lets os.open take a ``dir_fd`` argument and lets
# os.listdir operate on an open file descriptor instead of a path.
supports_outdirfd = os.open in os.supports_dir_fd and os.listdir in os.supports_fd

# Reduction helpers: fold a sequence with ``+`` respectively ``*``.  Note that
# ``sum`` shadows the builtin of the same name inside this module.  Unlike the
# builtin there is no implicit start value: functools.reduce raises TypeError
# for an empty sequence unless an initial value is passed as second argument.
sum = functools.partial(functools.reduce, operator.add)
product = functools.partial(functools.reduce, operator.mul)

def cumsum(seq):
  '''Yield the running total *preceding* every item of ``seq``.

  The generator produces exactly one value per input item: ``0`` first, then
  the sum of the first item, the first two items, and so on.  The grand total
  is never yielded.
  '''
  running = 0
  for increment in seq:
    # Emit the total accumulated so far, then fold in the current item.
    yield running
    running += increment

def gather(items):
  '''Group the values of ``(key, value)`` pairs by key.

  Returns a list of ``(key, values)`` tuples in order of first appearance of
  each key, where ``values`` is the list of all values seen for that key in
  their original order.  Keys must be hashable.
  '''
  buckets = {}
  ordered = []
  for key, value in items:
    bucket = buckets.get(key)
    if bucket is None:
      # First occurrence of this key: open a new bucket and record its position.
      bucket = buckets[key] = []
      ordered.append((key, bucket))
    bucket.append(value)
  return ordered

def allequal(seq1, seq2):
  '''Return True if both iterables yield equal items and have equal length.

  Both arguments may be arbitrary iterables, including one-shot iterators.
  Comparison stops at the first mismatch.
  '''
  iter1 = iter(seq1)
  iter2 = iter(seq2)
  exhausted = object() # unique marker that cannot compare equal to any item by identity
  while True:
    # Advance both iterators in lockstep.  Using zip here would silently drop
    # one item of seq1 whenever seq2 runs out first, making a sequence that is
    # exactly one item longer look equal.
    item1 = next(iter1, exhausted)
    item2 = next(iter2, exhausted)
    if item1 is exhausted or item2 is exhausted:
      # Equal only if both ran out at the same time.
      return item1 is item2
    if item1 != item2:
      return False

class NanVec(numpy.ndarray):
  '''nan-initialized vector

  A float array that starts out filled with NaN, where NaN marks an entry as
  "not set".  The bitwise operators are reinterpreted in terms of that mask:

  * ``vec & other`` overwrites only the entries that are already set;
  * ``vec | other`` fills only the entries that are not yet set;
  * ``~vec`` returns a new vector that is 0 where ``vec`` is NaN and NaN
    elsewhere.

  ``other`` may be a scalar or an array of the same shape.  For any dtype
  other than float (e.g. the boolean result of a comparison) the operators
  fall back to the regular numpy.ndarray behaviour.
  '''

  def __new__(cls, length):
    vec = numpy.empty(length, dtype=float).view(cls)
    vec[:] = numpy.nan
    return vec

  @property
  def where(self):
    '''Boolean mask (plain ndarray) that is True for entries that are set.'''
    return ~numpy.isnan(self.view(numpy.ndarray))

  def __iand__(self, other):
    if self.dtype != float:
      return self.view(numpy.ndarray).__iand__(other)
    where = self.where
    if numpy.isscalar(other):
      self[where] = other
    else:
      assert numeric.isarray(other) and other.shape == self.shape
      self[where] = other[where]
    return self

  def __and__(self, other):
    if self.dtype != float:
      return self.view(numpy.ndarray).__and__(other)
    return self.copy().__iand__(other)

  def __ior__(self, other):
    if self.dtype != float:
      return self.view(numpy.ndarray).__ior__(other)
    wherenot = ~self.where
    self[wherenot] = other if numpy.isscalar(other) else other[wherenot]
    return self

  def __or__(self, other):
    if self.dtype != float:
      return self.view(numpy.ndarray).__or__(other)
    return self.copy().__ior__(other)

  def __invert__(self):
    if self.dtype != float:
      return self.view(numpy.ndarray).__invert__()
    nanvec = NanVec(len(self))
    nanvec[numpy.isnan(self)] = 0
    return nanvec
def regularize(bbox, spacing, xy=numpy.empty((0,2))):
  '''Thin out and complement a 2D point set on a grid of width ``spacing``.

  Points of ``xy`` are rounded to grid indices; points that fall outside the
  index range derived from ``bbox``, or that land on an index already taken
  by an earlier point, are dropped.  The grid is then scanned in 3x3 windows
  taken at stride 2, and every window that contains no retained point
  receives a new point at its centre.

  Args
  ----
  bbox:
      Indexed as ``bbox[:,0]`` and ``bbox[:,1]``; presumably the lower and
      upper bounds per dimension, shape (2,2) -- TODO confirm.
  spacing:
      Grid width.
  xy:
      Existing points, one row per point, two columns.

  Returns
  -------
  Array with the newly generated points first, followed by the retained rows
  of ``xy``.
  '''
  xy = numpy.asarray(xy)
  # NOTE(review): numeric.floor/ceil/round are project helpers; they are used
  # as array shape and index below, so presumably they return integers.
  index0 = numeric.floor(bbox[:,0] / (2*spacing)) * 2 - 1
  shape = numeric.ceil(bbox[:,1] / (2*spacing)) * 2 + 2 - index0
  index = numeric.round(xy / spacing) - index0
  keep = numpy.logical_and(numpy.greater_equal(index, 0), numpy.less(index, shape)).all(axis=1)
  mask = numpy.zeros(shape, dtype=bool)
  for i, ind in enumerate(index):
    if keep[i]:
      if not mask[tuple(ind)]:
        mask[tuple(ind)] = True
      else:
        keep[i] = False # grid index already occupied by an earlier point
  # Coarsen: OR together windows of three consecutive cells, first along axis
  # 0 and then along axis 1, advancing two cells per window.
  coarsex = mask[0:-2:2] | mask[1:-1:2] | mask[2::2]
  coarsexy = coarsex[:,0:-2:2] | coarsex[:,1:-1:2] | coarsex[:,2::2]
  vacant, = (~coarsexy).ravel().nonzero()
  # Map the vacant coarse cells back to the fine index of the window centre.
  newindex = numpy.array(numpy.unravel_index(vacant, coarsexy.shape)).T * 2 + index0 + 1
  return numpy.concatenate([newindex * spacing, xy[keep]], axis=0)

def run(*functions):
  '''Deprecated command line entry point; use cli.run instead.

  Parses ``sys.argv``: an optional first argument selects one of
  ``functions`` by name (default: the first one), and every remaining
  ``--name=value`` argument sets either a parameter of that function or a
  configuration property.  An argument without ``=value`` is set to True.
  The selected function is invoked via cli.call and the process exits with
  the returned status.
  '''
  print('WARNING util.run is deprecated, please use cli.run instead')
  assert functions

  import datetime, inspect, contextlib
  from . import cli, config

  with contextlib.ExitStack() as stack:
    stack.enter_context(userconfig())
    properties = {k: v for k, v in vars(config).items() if not k.startswith('_')}
    # The configuration calls it 'pdb', the command line calls it 'tbexplore'.
    properties['tbexplore'] = properties.pop('pdb')

    if '-h' in sys.argv[1:] or '--help' in sys.argv[1:]:
      print('Usage: %s [FUNC] [ARGS]' % sys.argv[0])
      print('''
 --help                  Display this help
 --nprocs=%(nprocs)-14s Select number of processors
 --outrootdir=%(outrootdir)-10s Define the root directory for output
 --outdir=               Define custom directory for output
 --verbose=%(verbose)-13s Set verbosity level, 9=all
 --richoutput=%(richoutput)-10s Use rich output (colors, unicode)
 --htmloutput=%(htmloutput)-10s Generate an HTML log
 --tbexplore=%(tbexplore)-11s Start traceback explorer on error
 --imagetype=%(imagetype)-11s Set image type
 --symlink=%(symlink)-13s Create symlink to latest results
 --recache=%(recache)-13s Overwrite existing cache
 --dot=%(dot)-17s Set graphviz executable
 --profile=%(profile)-13s Show profile summary at exit''' % properties)
      for i, func in enumerate(functions):
        print()
        print('Arguments for %s%s' % (func.__name__, '' if i else ' (default)'))
        print()
        print('\n'.join(' --{}={}'.format(parameter.name, parameter.default) for parameter in inspect.signature(func).parameters.values() if parameter.kind not in (parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD)))
      sys.exit(0)

    func = functions[0]
    argv = sys.argv[1:]
    funcbyname = {func.__name__: func for func in functions}
    if argv and argv[0] in funcbyname:
      func = funcbyname[argv[0]]
      argv = argv[1:]
    kwargs = {parameter.name: parameter.default for parameter in inspect.signature(func).parameters.values() if parameter.kind not in (parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD)}
    for arg in argv:
      arg = arg.lstrip('-')
      try:
        arg, val = arg.split('=', 1)
        # NOTE(security): the value is evaluated as a Python expression in the
        # calling script's globals.  Acceptable only because it comes from the
        # user's own command line; never feed untrusted input through here.
        val = eval(val, sys._getframe(1).f_globals)
      except ValueError: # split failed
        val = True
      except (SyntaxError,NameError): # eval failed
        pass # keep the raw string value
      arg = arg.replace('-', '_')
      if arg in kwargs:
        kwargs[arg] = val
      else:
        assert arg in properties, 'invalid argument %r' % arg
        properties[arg] = val
    missing = [arg for arg, val in kwargs.items() if val is inspect.Parameter.empty]
    assert not missing, 'missing mandatory arguments: {}'.format(', '.join(missing))
    properties['pdb'] = properties.pop('tbexplore')
    stack.enter_context(config(**properties))

    status = cli.call(func, kwargs, scriptname=os.path.basename(sys.argv[0]), funcname=func.__name__)

  sys.exit(status)

class hashlessdict(collections.abc.MutableMapping):
  '''Mutable mapping that supports unhashable keys.

  Keys and values are kept in two parallel lists and keys are located by
  linear search using ``==``, so every lookup costs O(n).  Insertion order is
  preserved.  Two instances are equal only if they hold equal keys and values
  in the same order.
  '''

  __slots__ = '__keys', '__values'

  def __new__(cls, *args, **kwargs):
    self = object.__new__(cls)
    self.__keys = []
    self.__values = []
    return self

  def __init__(self, init=()):
    if isinstance(init, collections.abc.Mapping):
      # Keys of a mapping are unique already: append without searching.
      for key, value in init.items():
        self.__keys.append(key)
        self.__values.append(value)
    else:
      # An arbitrary iterable of pairs may repeat keys; route every pair
      # through __setitem__ so that, as with dict, the last value wins and no
      # key is stored twice.
      for key, value in init:
        self[key] = value

  def __getitem__(self, key):
    try:
      index = self.__keys.index(key)
    except ValueError as e:
      raise KeyError(key) from e
    else:
      return self.__values[index]

  def __setitem__(self, key, value):
    try:
      index = self.__keys.index(key)
    except ValueError:
      self.__keys.append(key)
      self.__values.append(value)
    else:
      self.__values[index] = value

  def __delitem__(self, key):
    try:
      index = self.__keys.index(key)
    except ValueError as e:
      raise KeyError(key) from e
    else:
      del self.__keys[index]
      del self.__values[index]

  def __iter__(self):
    return iter(self.__keys)

  def __len__(self):
    return len(self.__keys)

  def __bool__(self):
    return len(self.__keys) > 0

  def __contains__(self, key):
    return key in self.__keys

  def __eq__(self, other):
    return isinstance(other, hashlessdict) and self.__keys == other.__keys and self.__values == other.__values

  def get(self, key, value=None):
    try:
      index = self.__keys.index(key)
    except ValueError:
      return value
    else:
      return self.__values[index]

  def keys(self):
    # Returns a snapshot tuple rather than a live view.
    return tuple(self.__keys)

  def values(self):
    # Returns a snapshot tuple rather than a live view.
    return tuple(self.__values)

  def items(self):
    # Returns a one-shot iterator over (key, value) pairs.
    return zip(self.__keys, self.__values)

  def copy(self):
    return hashlessdict(self)

class frozendict(collections.abc.Mapping):
  '''Immutable, hashable mapping.

  The constructor copies ``base`` into a private dict and precomputes the
  hash, which raises TypeError if any key or value is unhashable.  Passing an
  existing frozendict returns that same object.
  '''

  __slots__ = '__base', '__hash'

  def __new__(cls, base):
    if isinstance(base, frozendict):
      return base
    self = object.__new__(cls)
    self.__base = dict(base)
    self.__hash = hash(frozenset(self.__base.items())) # check immutability and precompute hash
    return self

  def __reduce__(self):
    return frozendict, (self.__base,)

  def __eq__(self, other):
    if self is other:
      return True
    if not isinstance(other, frozendict):
      return False
    if self.__base is other.__base:
      return True
    if self.__hash != other.__hash or self.__base != other.__base:
      return False
    # deduplicate: share a single dict between equal instances, which also
    # makes the next comparison hit the identity shortcut above
    self.__base = other.__base
    return True

  def __getitem__(self, item):
    return self.__base.__getitem__(item)

  def __iter__(self):
    return self.__base.__iter__()

  def __len__(self):
    return self.__base.__len__()

  def __hash__(self):
    return self.__hash

  def __contains__(self, key):
    return self.__base.__contains__(key)

  def copy(self):
    # Returns a regular, mutable dict.
    return self.__base.copy()

class frozenmultiset(collections.abc.Container):
  '''Immutable, hashable multiset.

  Items are stored in a tuple in the order given, so iteration order is
  preserved; equality and hash depend only on which items occur and how
  often.  Items must be hashable.  Passing an existing frozenmultiset to the
  constructor returns that same object.
  '''

  __slots__ = '__items', '__key'

  def __new__(cls, items):
    if isinstance(items, frozenmultiset):
      return items
    self = object.__new__(cls)
    self.__items = tuple(items)
    # Order-independent identity: the set of (item, multiplicity) pairs.
    # Counter tallies in one pass instead of one tuple.count scan per item.
    self.__key = frozenset(collections.Counter(self.__items).items())
    return self

  def __and__(self, other):
    # Multiset intersection: every item of ``other`` consumes at most one
    # matching item of self.
    items = list(self.__items)
    isect = []
    for item in other:
      try:
        items.remove(item)
      except ValueError:
        pass
      else:
        isect.append(item)
    return frozenmultiset(isect)

  def __sub__(self, other):
    # Multiset difference; raises ValueError if ``other`` contains an item
    # that is not (or no longer) present in self.
    items = list(self.__items)
    for item in other:
      items.remove(item)
    return frozenmultiset(items)

  def __reduce__(self):
    return frozenmultiset, (self.__items,)

  def __hash__(self):
    return hash(self.__key)

  def __eq__(self, other):
    return isinstance(other, frozenmultiset) and self.__key == other.__key

  def __contains__(self, item):
    return item in self.__items

  def __iter__(self):
    return iter(self.__items)

  def __len__(self):
    return len(self.__items)

  def __bool__(self):
    return bool(self.__items)

  def __add__(self, other):
    return frozenmultiset(self.__items + tuple(other))

  def isdisjoint(self, other):
    return not any(item in self.__items for item in other)

def enforcetypes(f, signature=None):
  '''Decorator that passes every annotated argument through its annotation.

  For each parameter of ``signature`` (default: the signature of ``f``) that
  carries an annotation, the wrapper replaces the bound argument, default
  values included, by ``annotation(argument)`` before calling ``f``.  If no
  parameter is annotated, ``f`` itself is returned.
  '''
  if signature is None:
    signature = inspect.signature(f)
  # Identity test against the sentinel: ``!=`` would invoke the annotation's
  # own comparison operator.
  annotations = [(param.name, param.annotation) for param in signature.parameters.values() if param.annotation is not param.empty]
  if not annotations:
    return f
  @functools.wraps(f)
  def wrapped(*args, **kwargs):
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    for name, op in annotations:
      bound.arguments[name] = op(bound.arguments[name])
    return f(*bound.args, **bound.kwargs)
  return wrapped
def obj2str(obj):
  '''compact, lossy string representation of arbitrary object

  * strings are returned unchanged;
  * any other iterable becomes ``[item,item,...]``, applied recursively;
  * real numbers are stripped of leading and trailing zeros and of a trailing
    decimal point (``0.5`` -> ``.5``, ``10.0`` -> ``10``); because integers
    are stripped too the result is lossy (``100`` -> ``1``);
  * anything else is converted using ``str``.
  '''
  if isinstance(obj, str):
    # A string is an iterable of strings; descending into it would recurse
    # without end, so treat it as an atom.
    return str(obj)
  if isinstance(obj, collections.abc.Iterable):
    return '[' + ','.join(obj2str(item) for item in obj) + ']'
  if isinstance(obj, numbers.Real):
    return str(obj).strip('0').rstrip('.') or '0'
  return str(obj)
def single_or_multiple(f):
  """
  Method wrapper, converts first positional argument to tuple: tuples/lists
  are passed on as tuples, other objects are turned into tuple singleton.
  Return values should match the length of the argument list, and are unpacked
  if the original argument was not a tuple/list.

  >>> class Test:
  ...   @single_or_multiple
  ...   def square(self, args):
  ...     return [v**2 for v in args]
  ...
  >>> T = Test()
  >>> T.square(2)
  4
  >>> T.square([2,3])
  [4, 9]

  Args
  ----
  f: method
      Method that expects a tuple as first positional argument, and that
      returns a list/tuple of the same length.

  Returns
  -------
  Wrapped method.
  """

  @functools.wraps(f)
  def wrapped(self, arg0, *args, **kwargs):
    ismultiple = isinstance(arg0, (list,tuple))
    arg0mod = tuple(arg0) if ismultiple else (arg0,)
    retvals = f(self, arg0mod, *args, **kwargs)
    if not ismultiple:
      # Unpack the singleton; raises ValueError if f did not return exactly
      # one value.
      retvals, = retvals
    return retvals
  return wrapped
# vim:shiftwidth=2:softtabstop=2:expandtab:foldmethod=indent:foldnestmax=2