# Copyright (c) 2014 Evalf
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
'''
This module defines the function :func:`parse`, which parses a tensor
expression.
'''
import re, collections, functools
def _(arg):
  '''Return ``(None, arg)``, a constant leaf of an ExpressionAST.

  Used throughout this module to embed plain values (ints, floats, tuples,
  :class:`_Length` objects, variables) in an ExpressionAST. See the docstring
  of :func:`parse` for details on the ExpressionAST format.
  '''
  return None, arg
def _sp(count, singular, plural):
  '''Return ``count`` followed by ``singular`` if ``count`` equals one, else by ``plural``.

  For example, ``_sp(1, 'index', 'indices')`` returns ``'1 index'`` and
  ``_sp(2, 'index', 'indices')`` returns ``'2 indices'``.
  '''
  if count == 1:
    noun = singular
  else:
    noun = plural
  return '{} {}'.format(count, noun)
class ExpressionSyntaxError(Exception):
  '''Raised when an expression string cannot be parsed.

  The message consists of the error description, the expression string and a
  line of carets marking the offending part of the expression.
  '''
class AmbiguousAlignmentError(Exception):
  '''Error signalling an ambiguous alignment.'''
  # NOTE(review): presumably raised when the order of the indices of a parsed
  # expression cannot be determined unambiguously -- confirm at its raise site.
class _IntermediateError(Exception):
  '''Intermediate parse error, to be caught and converted into an :class:`ExpressionSyntaxError`.

  Args
  ----
  msg : :class:`str`
      The error message.
  at : :class:`int`, optional
      Position in the expression string to highlight. If ``None``, the
      position is derived from the tokens consumed by the failing parse step.
  count : :class:`int`, optional
      Number of characters to highlight at ``at``; one if omitted.
  '''

  def __init__(self, msg, at=None, count=None):
    super().__init__(msg)
    self.msg, self.at, self.count = msg, at, count
# A token produced by `_ExpressionParser.tokenize`.
_Token = collections.namedtuple('_Token', 'type data pos')
_Token.__doc__ = 'An indivisible part of an expression string.'
for _field, _doc in (
    ('type', 'The type of this token.'),
    ('data', 'Substring of the expression string that belongs to this token.'),
    ('pos', 'The start position of the token in the expression string.')):
  getattr(_Token, _field).__doc__ = _doc
del _field, _doc
# Placeholder for an axis length that is not known (yet) while parsing.
_Length = collections.namedtuple('_Length', ('pos',))
_Length.__doc__, _Length.pos.__doc__ = (
  'Yet unknown length, introduced at ``pos`` in the expression string.',
  'The position where this :class:`_Length` is introduced.')
class _Array:
  '''ExpressionAST with shape, indices.

  The :class:`_Array` class combines an ExpressionAST with shape and indices
  and maintains a set of summed indices in the expression string resulting in
  this :class:`_Array`.

  Attributes
  ----------
  ast : :class:`tuple`
      The ExpressionAST (see :func:`parse`).
  indices : :class:`str`
      The indices of the array represented by the :attr:`ast`.
  shape : :class:`tuple` of :class:`int`\\s or :class:`_Length`\\s
      The shape of the array represented by the :attr:`ast`.
  summed : :class:`frozenset` of indices (:class:`str`)
      A set of indices that are summed in the expression string resulting in
      this :class:`_Array`. The indices are not allowed in expressions
      involving this :class:`_Array`. For example, index ``i`` in expression
      string ``'a_ij b_i'`` is summed and cannot be used in an expression like
      ``'(a_ij b_i) c_i'``.
  linked_lengths : :class:`frozenset` of :class:`frozenset`\\s of :class:`_Length`\\s and :class:`int`\\s
      A set of sets of :class:`_Length`\\s and :class:`int`\\s. A
      :class:`_Length` is introduced if an axis of an :class:`_Array` has an
      unknown length, e.g. a dirac has two axes of equal, but unknown length.
      All :class:`_Length`\\s in a set have the same length. If a set contains
      an :class:`int` the :class:`_Length`\\s are resolved.
  ndim : :class:`int`
      The number of dimensions of this :class:`_Array`.

  Args
  ----
  ast : :class:`tuple`
      See :attr:`_Array.ast`.
  indices : :class:`str`
      See :attr:`_Array.indices`.
  shape : :class:`tuple` of :class:`int`\\s or :class:`_Length`\\s
      See :attr:`_Array.shape`.
  summed : :class:`frozenset` of indices (:class:`str`)
      See :attr:`_Array.summed`.
  linked_lengths : :class:`frozenset` of :class:`frozenset`\\s of :class:`_Length`\\s and :class:`int`\\s
      See :attr:`_Array.linked_lengths`.
  '''

  @classmethod
  def wrap(cls, ast, indices, shape):
    '''Create an :class:`_Array` by wrapping ``ast``.

    The ``ast`` should be a constant or variable. Duplicate indices are summed
    and numeric indices are replaced by a getitem.

    Raises
    ------
    _IntermediateError
        If the number of ``indices`` differs from the length of ``shape``, or
        if applying the indices fails (see :meth:`_apply_indices`).
    '''
    if len(indices) != len(shape):
      raise _IntermediateError('Expected {}, got {}.'.format(_sp(len(shape), 'index', 'indices'), len(indices)))
    return cls._apply_indices(ast, 0, indices, shape, frozenset(), {})

  @classmethod
  def _apply_indices(cls, ast, offset, indices, shape, summed, linked_lengths):
    '''Wrap ``ast`` in an :class:`_Array`, summing repeated indices and applying numeric indices.

    When wrapping a variable or gradient, an index may appear twice,
    indicating summation (a trace), or may be numeric, indicating a getitem.
    This method wraps ``ast`` and applies the trace and getitem where needed.

    Args
    ----
    ast : :class:`tuple`
        See :attr:`_Array.ast`.
    offset : :class:`int`
        Start at index ``offset`` when looking for repeated indices (matched
        against the entire string ``indices``, not only ``indices[offset:]``)
        or numeric indices. The indices ``indices[:offset]`` are assumed to be
        already processed.
    indices : :class:`str`
        See :attr:`_Array.indices`.
    shape : :class:`tuple` of :class:`int`\\s or :class:`_Length`\\s
        See :attr:`_Array.shape`. Should have the same length as ``indices``.
    summed : :class:`frozenset` of indices (:class:`str`)
        See :attr:`_Array.summed`.
    linked_lengths : :class:`frozenset` of :class:`frozenset`\\s of :class:`_Length`\\s and :class:`int`\\s
        See :attr:`_Array.linked_lengths`.

    Returns
    -------
    wrapped : :class:`_Array`
        The wrapped ``ast`` with the remaining (non-summed, non-numeric)
        indices.

    Raises
    ------
    _IntermediateError
        If a numeric index is out of range, if an index occurs more than
        twice, or if the lengths of two summed axes differ.
    '''
    summed = set(summed)
    linked_lengths = set(linked_lengths)
    i = offset
    # Original axis numbers of the remaining axes, for error messages.
    dims = tuple(range(len(indices)))
    while i < len(indices):
      index = indices[i]
      j = indices.index(index)
      if '0' <= index <= '9':
        index = int(index)
        if isinstance(shape[i], int) and index >= shape[i]:
          raise _IntermediateError('Index of dimension {} with length {} out of range.'.format(dims[i], shape[i]))
        ast = 'getitem', ast, _(i), _(index)
        # Axis i is removed; the next index moves into position i, so `i` is
        # not incremented.
        indices = indices[:i] + indices[i+1:]
        shape = shape[:i] + shape[i+1:]
        dims = dims[:i] + dims[i+1:]
      elif index in summed:
        raise _IntermediateError('Index {!r} occurs more than twice.'.format(index))
      elif j < i:
        linked_lengths = set(cls._update_lengths(linked_lengths, index, shape[j], shape[i]))
        ast = 'trace', ast, _(j), _(i)
        indices = indices[:j] + indices[j+1:i] + indices[i+1:]
        shape = shape[:j] + shape[j+1:i] + shape[i+1:]
        dims = dims[:j] + dims[j+1:i] + dims[i+1:]
        summed.add(index)
        # Axes j < i and i are both removed: the index following i now sits
        # at position i - 1.
        i -= 1
      else:
        if isinstance(shape[i], _Length) and not any(shape[i] in g for g in linked_lengths):
          linked_lengths.add(frozenset([shape[i]]))
        i += 1
    return cls(ast, indices, shape, summed, linked_lengths)

  @classmethod
  def stack(cls, arrays, index):
    '''Stack ``arrays`` along axis ``index``.

    The arrays are stacked in given order. All arrays should have matching
    shapes, except for the axis labeled ``index``. If an array does not have
    the supplied ``index``, the array is expanded with an axis of length one
    before stacking. For example, stacking a scalar and an array with shape
    ``{i: 2}`` along ``i`` gives an array with shape ``{i: 3}``.

    Args
    ----
    arrays : a :class:`~collections.abc.Sequence` of :class:`_Array` objects
        The arrays to stack.
    index : :class:`str`
        The index along which to stack the ``arrays``.

    Returns
    -------
    array : :class:`_Array`
        The stacked array.

    Raises
    ------
    _IntermediateError
        If there are no arrays, if the indices (other than ``index``) do not
        match, if the shapes do not match, or if the length of the stack axis
        cannot be determined.
    '''
    # TODO: assert is_valid_lhs_indices(index)
    if len(arrays) == 0:
      raise _IntermediateError('Cannot stack 0 arrays.')
    if len(set(frozenset(array.indices) - {index} for array in arrays)) != 1:
      raise _IntermediateError(
        'Cannot stack arrays with unmatched indices (excluding the stack index {!r}): {}.'
        .format(index, ', '.join(array.indices for array in arrays)))
    # The stack axis goes first, followed by the remaining indices of the
    # first array in their original order.
    indices = index + ''.join(i for i in arrays[0].indices if i != index)
    arrays = [(array.append_axis(index, 1) if index not in array.indices else array).transpose(indices) for array in arrays]
    if len(arrays) == 1:
      return arrays[0]
    # `helper` accumulates the joined shape (excluding the stack axis),
    # linked lengths and summed indices of all arrays.
    helper = arrays[0].replace(indices=arrays[0].indices[1:], shape=arrays[0].shape[1:])
    for other in arrays[1:]:
      other = other.replace(indices=other.indices[1:], shape=other.shape[1:])
      shape, linked_lengths = helper._join_shapes(other)
      helper = helper.replace(shape=shape, linked_lengths=linked_lengths, summed=helper.summed | other.summed)
    # Apply `helper.linked_lengths` to all `arrays`. If the lengths at
    # `index` is not known at this point, we won't be able to resolve this
    # ever, so raise an exception here.
    length = 0
    for array in arrays:
      shape = array._simplify_shape(helper.linked_lengths)
      if isinstance(shape[0], _Length):
        raise _IntermediateError('Cannot determine the length of the stack axis, because the length at {} is unknown.'.format(shape[0].pos), at=shape[0].pos)
      length += shape[0]
    ast = ('concatenate',) + tuple(array.ast for array in arrays)
    return helper.replace(ast=ast, indices=indices, shape=(length,)+helper.shape)

  @staticmethod
  def align(*arrays):
    '''Align ``arrays`` to the first array.

    Args
    ----
    arrays : :class:`_Array`
        The arrays to align.

    Returns
    -------
    aligned_arrays : :class:`tuple` of :class:`_Array` objects
        The aligned arrays, all transposed to the indices of the first array
        and sharing the joined shape, linked lengths and summed indices.

    Raises
    ------
    _IntermediateError
        If the arrays have different (sets of) indices or mismatching shapes.
    '''
    assert len(arrays) > 0
    if len(set(frozenset(array.indices) for array in arrays)) != 1:
      raise _IntermediateError(
        'Cannot align arrays with unmatched indices: {}.'
        .format(', '.join(array.indices for array in arrays)))
    arrays = [array.transpose(arrays[0].indices) for array in arrays]
    helper = arrays[0]
    for other in arrays[1:]:
      shape, linked_lengths = helper._join_shapes(other)
      helper = helper.replace(shape=shape, linked_lengths=linked_lengths, summed=helper.summed | other.summed)
    return tuple(array.replace(shape=helper.shape, linked_lengths=helper.linked_lengths, summed=helper.summed) for array in arrays)

  def __init__(self, ast, indices, shape, summed, linked_lengths):
    assert isinstance(indices, str)
    self.ast = tuple(ast)
    self.indices = indices
    self.shape = tuple(shape)
    self.summed = frozenset(summed)
    self.linked_lengths = frozenset(linked_lengths)
    self.ndim = len(self.indices)

  def _join_shapes(self, other):
    '''Verify ``self + other`` is valid and return the resulting shape and linked lengths.

    Args
    ----
    other : :class:`_Array`
        Should have the same (order of) indices as this array.

    Returns
    -------
    shape : :class:`list`
        The simplified shape of ``self + other``.
    linked_lengths : :class:`frozenset` of :class:`frozenset`\\s of :class:`_Length`\\s and :class:`int`\\s
        See :attr:`_Array.linked_lengths`. Updated with links resulting from
        applying ``self + other``.

    Raises
    ------
    _IntermediateError
        If two known lengths at the same index differ.
    '''
    assert self.indices == other.indices, 'unaligned'
    groups = set(self.linked_lengths | other.linked_lengths)
    for index, a, b in zip(self.indices, self.shape, other.shape):
      if a == b:
        continue
      if not isinstance(a, _Length) and not isinstance(b, _Length):
        raise _IntermediateError('Shapes at index {!r} differ: {}, {}.'.format(index, a, b))
      groups.add(frozenset({a, b}))
    linked_lengths = self._join_lengths(other, groups)
    return self._simplify_shape(linked_lengths), linked_lengths

  def _simplify_shape(self, linked_lengths):
    '''Return simplified shape by replacing :class:`_Length`\\s with :class:`int`\\s according to the ``linked_lengths``.

    Every :class:`_Length` in :attr:`shape` must be a member of one of the
    groups in ``linked_lengths``; unresolved lengths are kept as is.
    '''
    shape = []
    cache = {k: v for v in linked_lengths for k in v}
    for length in self.shape:
      if isinstance(length, _Length):
        for l in cache[length]:
          if not isinstance(l, _Length):
            length = l
            break
      shape.append(length)
    return shape

  def _join_lengths(*args):
    '''Return the linked lengths resulting from merging all ``args``.

    Called as a method, so the first argument is an :class:`_Array`. Each
    argument is either an :class:`_Array`, whose :attr:`linked_lengths` are
    used, or an iterable of groups (iterables of :class:`_Length`\\s and
    :class:`int`\\s). Groups sharing a member are merged transitively.

    Raises
    ------
    _IntermediateError
        If a merged group contains two different known lengths.
    '''
    groups = set()
    for arg in args:
      groups |= arg.linked_lengths if isinstance(arg, _Array) else arg
    # Union of overlapping groups: `cache` maps every member to the merged
    # group it currently belongs to.
    cache = {}
    for g in groups:
      new_g = set()
      for k in g:
        new_g |= cache.get(k, frozenset([k]))
      new_g = frozenset(new_g)
      cache.update((k, new_g) for k in new_g)
    linked_lengths = frozenset(cache.values())
    # Verify.
    for g in linked_lengths:
      known = tuple(sorted(set(k for k in g if not isinstance(k, _Length))))
      if len(known) > 1:
        raise _IntermediateError('Axes have different lengths: {}.'.format(', '.join(map(str, known))))
    return linked_lengths

  @staticmethod
  def _update_lengths(linked_lengths, index, a, b):
    '''Add link ``a``, ``b`` to ``linked_lengths`` and return the result.

    Raises
    ------
    _IntermediateError
        If ``a`` and ``b`` (or the groups they belong to) have different
        known lengths.
    '''
    cache = {l: g for g in linked_lengths for l in g}
    if a != b:
      if not isinstance(a, _Length) and not isinstance(b, _Length):
        raise _IntermediateError('Shapes at index {!r} differ: {}, {}.'.format(index, a, b))
      g = cache.get(a, frozenset([a])) | cache.get(b, frozenset([b]))
      cache.update((k, g) for k in g)
      # Verify.
      known = tuple(sorted(set(k for k in g if not isinstance(k, _Length))))
      if len(known) > 1:
        raise _IntermediateError('Shapes at index {!r} differ: {}.'.format(index, ', '.join(map(str, known))))
    elif isinstance(a, _Length):
      cache.setdefault(a, frozenset([a]))
    return frozenset(cache.values())

  def __neg__(self):
    '''Return -self.'''
    return self.replace(ast=('neg', self.ast))

  def _add_sub(self, other, op, name):
    '''Return op(self, other).

    The summed indices of both operands are kept, so an index summed in
    either term cannot be reused in an expression involving the result.
    '''
    if frozenset(self.indices) != frozenset(other.indices):
      raise _IntermediateError('Cannot {} arrays with unmatched indices: {!r}, {!r}.'.format(name, self.indices, other.indices))
    other = other.transpose(self.indices)
    shape, linked_lengths = self._join_shapes(other)
    return _Array((op, self.ast, other.ast), self.indices, shape, self.summed | other.summed, linked_lengths)

  def __add__(self, other):
    '''Return self+other.'''
    return self._add_sub(other, 'add', 'add')

  def __sub__(self, other):
    '''Return self-other.'''
    return self._add_sub(other, 'sub', 'subtract')

  def __mul__(self, other):
    '''Return self*other, summing over the indices common to both operands.'''
    for a, b in ((self, other), (other, self)):
      for index in sorted(frozenset(a.indices) | a.summed):
        if index in b.summed:
          raise _IntermediateError('Index {!r} occurs more than twice.'.format(index))
    # Broadcast both operands to the union of their indices.
    common = []
    for index, length in zip(self.indices, self.shape):
      if index in other.indices:
        common.append(index)
      else:
        other = other.append_axis(index, length)
    for index, length in zip(other.indices, other.shape):
      if index not in self.indices:
        self = self.append_axis(index, length)
    indices = self.indices
    other = other.transpose(indices)
    shape, linked_lengths = self._join_shapes(other)
    ast = 'mul', self.ast, other.ast
    # Sum the common axes, highest position first so lower positions remain
    # valid.
    for index in reversed(common):
      i = self.indices.index(index)
      ast = 'sum', ast, _(i)
      indices = indices[:i] + indices[i+1:]
      shape = shape[:i] + shape[i+1:]
    return _Array(ast, indices, shape, self.summed | other.summed | frozenset(common), linked_lengths)

  def __truediv__(self, other):
    '''Return self/value.'''
    if other.ndim > 0:
      raise _IntermediateError('A denominator must have dimension 0.')
    for index in sorted((self.summed | set(self.indices)) & other.summed):
      raise _IntermediateError('Index {!r} occurs more than twice.'.format(index))
    return _Array(('truediv', self.ast, other.ast), self.indices, self.shape, self.summed | other.summed, self._join_lengths(other))

  def __pow__(self, other):
    '''Return self**value.'''
    if other.ndim > 0:
      raise _IntermediateError('An exponent must have dimension 0.')
    for index in sorted((self.summed | set(self.indices)) & other.summed):
      raise _IntermediateError('Index {!r} occurs more than twice.'.format(index))
    return _Array(('pow', self.ast, other.ast), self.indices, self.shape, self.summed | other.summed, self._join_lengths(other))

  def grad(self, index, geom, type):
    '''Return the gradient w.r.t. ``geom``, labeling the new axis ``index``.'''
    assert geom.ndim == 1
    assert not isinstance(geom.shape[0], _Length)
    assert type in ('grad','surfgrad')
    ast = type, self.ast, _(geom)
    return _Array._apply_indices(ast, self.ndim, self.indices+index, self.shape+geom.shape, self.summed, self.linked_lengths)

  def derivative(self, arg, indices):
    '''Return the derivative to ``arg``, labeling the new axes with ``indices``.

    ``indices`` should contain one index per axis of ``arg``.
    '''
    return _Array._apply_indices(('derivative', self.ast, arg.ast), self.ndim, self.indices+indices, self.shape+arg.shape, self.summed, self.linked_lengths)

  def append_axis(self, index, length):
    '''Return an :class:`_Array` with one additional axis.'''
    if index in self.indices or index in self.summed:
      raise _IntermediateError('Duplicate index: {!r}.'.format(index))
    linked_lengths = self.linked_lengths
    if isinstance(length, _Length):
      for group in linked_lengths:
        if length in group:
          break
      else:
        linked_lengths |= frozenset({frozenset({length})})
    return _Array(('append_axis', self.ast, _(length)), self.indices+index, self.shape+(length,), self.summed, linked_lengths)

  def transpose(self, indices):
    '''Return an :class:`_Array` transposed according to ``indices``.'''
    if len(indices) != len(set(indices)):
      raise _IntermediateError('Cannot transpose from {!r} to {!r}: duplicate indices.'.format(self.indices, indices))
    elif set(self.indices) != set(indices):
      raise _IntermediateError('Cannot transpose from {!r} to {!r}: indices differ.'.format(self.indices, indices))
    if self.indices == indices:
      return self
    else:
      transpose = tuple(map(self.indices.index, indices))
      shape = tuple(map(self.shape.__getitem__, transpose))
      return _Array(('transpose', self.ast, _(transpose)), indices, shape, self.summed, self.linked_lengths)

  def replace(self, **updates):
    '''Return a copy of this :class:`_Array` with attributes replaced by ``updates``.'''
    kwargs = dict(ast=self.ast, indices=self.indices, shape=self.shape, summed=self.summed, linked_lengths=self.linked_lengths)
    kwargs.update(updates)
    return _Array(**kwargs)
class _ExpressionParser:
  '''Expression parser

  Args
  ----
  expression : :class:`str`
      See argument ``expression`` of :func:`parse`.
  variables : :class:`dict` of :class:`str` and :class:`nutils.function.Array` pairs
      See argument ``variables`` of :func:`parse`.
  functions : :class:`dict` of :class:`str` and :class:`int` pairs
      See argument ``functions`` of :func:`parse`.
  arg_shapes : :class:`dict` of :class:`str` and :class:`tuple` of :class:`int`\\s pairs
      See argument ``arg_shapes`` of :func:`parse`. The dictionary is copied;
      the shapes of newly encountered arguments are added to the copy.
  default_geometry_name : :class:`str`
      See argument ``default_geometry_name`` of :func:`parse`.
  '''

  # Symbols that denote an identity ('eye') array with two axes of equal
  # length, e.g. 'δ_ij' or '$_ij'.
  eye_symbols = '$', 'δ'
  # Symbols that denote the normal of a geometry, e.g. 'n_i' or 'n_x_i'.
  normal_symbols = 'n',

  def __init__(self, expression, variables, functions, arg_shapes, default_geometry_name):
    self.expression = expression
    self.variables = variables
    self.functions = functions
    # Copied, because `_get_arg` records the shapes of newly seen arguments.
    self.arg_shapes = dict(arg_shapes)
    self.default_geometry_name = default_geometry_name

  def highlight(f):
    '''Wrap ``f`` in a function that converts ``_IntermediateError`` objects.

    An :class:`_IntermediateError` raised by ``f`` is converted into an
    :class:`ExpressionSyntaxError` whose message contains the expression and a
    line of carets marking the offending part. If the error does not specify
    a position, the tokens consumed by ``f`` are marked, or the next token if
    ``f`` consumed nothing.
    '''
    @functools.wraps(f)
    def wrapper(self, *args, **kwargs):
      # `_tokens` only exists after `tokenize`; errors raised by `tokenize`
      # itself always carry an explicit position.
      if hasattr(self, '_tokens'):
        pos = self._next.pos
      else:
        pos = 0
      try:
        return f(self, *args, **kwargs)
      except _IntermediateError as e:
        if e.at is None:
          at = pos
          count = self._next.pos - pos if self._next.pos > pos else len(self._next.data)
        else:
          at = e.at
          count = 1 if e.count is None else e.count
        raise ExpressionSyntaxError(e.msg + '\n' + self.expression + '\n' + ' '*at + '^'*count) from e
    return wrapper

  def _consume(self):
    'advance to the next token and return it'
    self._index += 1
    if self._index >= len(self._tokens):
      raise _IntermediateError('Unexpected end of expression.', at=len(self.expression))
    return self._current

  def _consume_if_whitespace(self):
    'advance to next token if it is a whitespace'
    if self._next.type == 'whitespace':
      self._consume()

  @highlight
  def _consume_assert_whitespace(self):
    'consume the next token and assert it is whitespace'
    if self._consume().type != 'whitespace':
      raise _IntermediateError('Missing whitespace.', at=self._current.pos)

  @highlight
  def _consume_assert_equal(self, value, msg=None):
    'consume the next token, assert its type equals ``value`` and return it'
    token = self._consume()
    if token.type != value:
      if msg is None:
        msg = 'Expected {!r}.'.format(value)
      raise _IntermediateError(msg, at=token.pos)
    return token

  @property
  def _current(self):
    'the current token'
    return self._tokens[self._index]

  @property
  def _next(self):
    'the next token, or the last token (EOF) if the end has been reached'
    return self._tokens[min(len(self._tokens)-1, self._index+1)]

  @property
  def _next_non_whitespace(self):
    'the next non-whitespace token'
    return self._tokens[self._index+2] if self._next.type == 'whitespace' else self._next

  def _get_variable(self, name):
    'get variable by ``name`` or raise an error'
    value = self.variables.get(name, None)
    if value is None:
      raise _IntermediateError('Unknown variable: {!r}.'.format(name))
    return value

  def _get_geometry(self, name):
    'get geometry by ``name`` or raise an error'
    geom = self._get_variable(name)
    if geom.ndim != 1:
      raise _IntermediateError('Invalid geometry: expected 1 dimension, but {!r} has {}.'.format(name, geom.ndim))
    return geom

  def _get_arg(self, name, indices, indices_start):
    '''get arg by ``name`` or raise an error

    A new argument gets one :class:`_Length` per index, positioned at
    ``indices_start`` onwards; its shape is recorded in :attr:`arg_shapes`.
    '''
    if name in self.arg_shapes:
      shape = self.arg_shapes[name]
      if len(shape) != len(indices):
        raise _IntermediateError('Argument {!r} previously defined with {} instead of {}.'.format(name, _sp(len(shape), 'axis', 'axes'), len(indices)))
    else:
      shape = tuple(_Length(indices_start+i) for i in range(len(indices)))
      self.arg_shapes[name] = shape
    return _Array.wrap(('arg', _(name)) + tuple(map(_, shape)), indices, shape)

  @highlight
  def parse_lhs_arg(self, seen_lhs):
    'parse lhs arg, e.g. the "x_ij" in "?x_kk(x_ij=a_ij)"'
    token = self._consume()
    if token.type != 'variable':
      raise _IntermediateError("Expected an argument, e.g. 'argname'.")
    if token.data.startswith('?'):
      raise _IntermediateError("The argument name at the left hand side of a substitution must not be prefixed by a '?'.")
    name = token.data
    if name in seen_lhs:
      raise _IntermediateError("Argument {!r} occurs more than once.".format(name))
    seen_lhs[name] = token
    indices = self._consume().data if self._next.type == 'indices' else ''
    for i, index in enumerate(indices):
      if index in indices[i+1:]:
        raise _IntermediateError('Repeated indices are not allowed on the left hand side.')
      elif '0' <= index <= '9':
        raise _IntermediateError('Numeric indices are not allowed on the left hand side.')
    return self._get_arg(name, indices, self._current.pos)

  @highlight
  def parse_var(self):
    'parse a component of a term, e.g. "1", "a_i", "(2 a_i)", "a_i^2", "abs(x)"'
    if self._next.type == '(':
      # Group, e.g. '(a_i + b_i)'.
      self._consume()
      value = self.parse_subexpression()
      self._consume_assert_equal(')')
      value = value.replace(ast=('group', value.ast))
    elif self._next.type == '[':
      # Jump, e.g. '[a]', optionally multiplied by the normal: '[a]_i' or
      # '[a]_x_i'.
      self._consume()
      value = self.parse_subexpression()
      self._consume_assert_equal(']')
      value = value.replace(ast=('jump', value.ast))
      if self._next.type == 'geometry':
        geometry_name = self._consume().data
      else:
        geometry_name = self.default_geometry_name
      # NOTE(review): the geometry is looked up even when no normal indices
      # follow, so '[a]' requires the default geometry to be defined.
      geom = self._get_geometry(geometry_name)
      if self._next.type == 'indices':
        indices = self._consume().data
        value *= _Array.wrap(('normal', _(geom)), indices, geom.shape)
    elif self._next.type == '{':
      # Mean, e.g. '{a}'.
      self._consume()
      value = self.parse_subexpression()
      self._consume_assert_equal('}')
      value = value.replace(ast=('mean', value.ast))
    elif self._next.type == '<':
      # Stack, e.g. '<a, b>_i'.
      self._consume()
      args = self.parse_comma_separated(end='>', parse_item=self.parse_subexpression)
      indices = self._consume()
      if indices.type != 'indices':
        raise _IntermediateError('Expected 1 index.', at=indices.pos, count=len(indices.data))
      if len(indices.data) != 1:
        raise _IntermediateError('Expected 1 index, got {}.'.format(len(indices.data)), at=indices.pos, count=len(indices.data))
      if '0' <= indices.data <= '9':
        raise _IntermediateError('Expected a non-numeric index, got {!r}.'.format(indices.data), at=indices.pos, count=len(indices.data))
      value = _Array.stack(args, indices.data)
    elif self._next.type == 'eye':
      self._consume()
      if self._next.type == 'indices':
        indices = self._consume().data
      else:
        indices = ''
      length = _Length(self._current.pos)
      value = _Array.wrap(('eye', _(length)), indices, (length, length))
    elif self._next.type == 'normal':
      self._consume()
      if self._next.type == 'geometry':
        geometry_name = self._consume().data
      else:
        geometry_name = self.default_geometry_name
      geom = self._get_geometry(geometry_name)
      if self._next.type == 'indices':
        indices = self._consume().data
      else:
        indices = ''
      value = _Array.wrap(('normal', _(geom)), indices, geom.shape)
    elif self._next.type == 'variable':
      token = self._consume()
      name = token.data
      if name in self.functions and name not in self.variables: # function (and not overriden as variable)
        self._consume_assert_equal('(', msg="Expected '(' for function {}.".format(name))
        args = self.parse_comma_separated(end=')', parse_item=self.parse_subexpression)
        nargs = self.functions[name]
        if len(args) != nargs:
          raise _IntermediateError('Function {!r} takes {}, got {}.'.format(name, _sp(nargs, 'argument', 'arguments'), len(args)))
        args = _Array.align(*args)
        value = args[0].replace(ast=('call', _(name))+tuple(arg.ast for arg in args))
      elif name.startswith('?'):
        indices = self._consume().data if self._next.type == 'indices' else ''
        value = self._get_arg(name[1:], indices, self._current.pos)
      else:
        raw = self._get_variable(name)
        indices = self._consume().data if self._next.type == 'indices' else ''
        value = _Array.wrap(_(raw), indices, raw.shape)
    else:
      raise _IntermediateError('Expected a variable, group or function call.')

    if self._next.type == 'gradient':
      token = self._consume()
      if token.data.startswith(',?'):
        # Derivative to an argument, e.g. ',?x_i' or ',?x'.
        name = token.data[2:]
        if '_' in name:
          name, indices = name.split('_', 1)
          # Skip ',', '?', the name and '_'.
          indices_start = token.pos+3+len(name)
        else:
          indices = ''
          indices_start = 0
        # `_Array.derivative` appends `indices` verbatim, one per axis of the
        # argument. Repeated or numeric indices would be traced or applied as
        # a getitem by `_get_arg`, leaving fewer axes than indices, so they
        # are rejected here.
        for i, index in enumerate(indices):
          if index in indices[i+1:]:
            raise _IntermediateError('Repeated indices are not allowed in the argument of a derivative.', at=indices_start+i)
          if '0' <= index <= '9':
            raise _IntermediateError('Numeric indices are not allowed in the argument of a derivative.', at=indices_start+i)
        arg = self._get_arg(name, indices, indices_start)
        value = value.derivative(arg, indices)
      else:
        gradtype = {',': 'grad', ';': 'surfgrad'}[token.data[0]]
        if '_' in token.data[1:]:
          geometry_name, indices = token.data[1:].split('_', 1)
        else:
          geometry_name = self.default_geometry_name
          indices = token.data[1:]
        geom = self._get_geometry(geometry_name)
        for index in indices:
          value = value.grad(index, geom, gradtype)
    elif self._next.type == 'indices':
      raise _IntermediateError("Indices can only be specified for variables, e.g. 'a_ij', not for groups, e.g. '(a+b)_ij'.", at=self._next.pos, count=len(self._next.data))

    if self._next.type == '(':
      # Substitution, e.g. '?x_i(x_i=a_i)'.
      self._consume()
      subs = self.parse_comma_separated(end=')', parse_item=functools.partial(self.parse_substitution, seen_lhs={}))
      if not subs:
        raise _IntermediateError("Zero substitutions are not allowed.")
      ast = ['substitute', value.ast]
      links = []
      for lhs, rhs in subs:
        ast += [lhs.ast, rhs.ast]
        links += [rhs.linked_lengths, frozenset(zip(lhs.shape, rhs.shape))]
      value = value.replace(ast=ast, linked_lengths=value._join_lengths(*links))

    if self._next.type == '^':
      self._consume()
      if self._next.type == '(':
        self._consume()
        exponent = self.parse_subexpression()
        self._consume_assert_equal(')')
      else:
        if self._next.type == '-':
          self._consume()
          negate = True
        else:
          negate = False
        exponent = self.parse_const_scalar()
        if negate:
          exponent = -exponent
      value = value**exponent
    return value

  @highlight
  def parse_const_scalar(self):
    'parse a constant scalar, e.g. "1", "1.0", "0.1"'
    token = self._consume()
    if token.type == 'int':
      value = _Array.wrap(_(int(token.data)), '', [])
    elif token.type == 'float':
      value = _Array.wrap(_(float(token.data)), '', [])
    else:
      raise _IntermediateError('Expected a number.')
    if self._next.type == 'gradient':
      self._consume()
      raise _IntermediateError('Taking a derivative of a constant is not allowed.')
    return value

  @highlight
  def parse_const(self):
    'parse a const, possibly with indices, e.g. "1_j"'
    value = self.parse_const_scalar()
    if self._next.type == 'indices':
      token = self._consume()
      indices = token.data
      for i, index in enumerate(indices):
        if '0' <= index <= '9':
          raise _IntermediateError('Numeric indices are not allowed on constant values.')
        if index in indices[1+i:]:
          raise _IntermediateError('Indices of a constant value may not be repeated.')
        value = value.append_axis(index, _Length(pos=token.pos+i))
    if self._next.type == 'gradient':
      self._consume()
      raise _IntermediateError('Taking a derivative of a constant is not allowed.')
    return value

  @highlight
  def parse_numerator(self):
    'parse the numerator part of a fraction: whitespace-separated factors'
    if self._next.type in ('int', 'float'):
      value = self.parse_const()
    else:
      value = self.parse_var()
    while True:
      if self._next_non_whitespace.type in (')', ']', '}', '>', 'EOF', '+', '-', '/', '|', ','):
        break
      self._consume_assert_whitespace()
      value *= self.parse_var()
    return value

  @highlight
  def parse_denominator(self):
    'parse the denominator part of a fraction'
    value = self.parse_numerator()
    if value.ndim > 0:
      raise _IntermediateError('A denominator must have dimension 0.')
    return value

  def parse_comma_separated(self, end, parse_item):
    'parse comma separated values until end token, e.g. "1, 2 (a_ij b_j + 3))" with end token ")"'
    items = []
    self._consume_if_whitespace()
    if self._next.type != end:
      while True:
        items.append(parse_item())
        self._consume_if_whitespace()
        if self._next.type != ',':
          break
        self._consume_assert_equal(',')
        self._consume_assert_whitespace()
    self._consume_assert_equal(end)
    return items

  @highlight
  def parse_substitution(self, seen_lhs):
    'parse a substitution, e.g. "x_ij=a_ij" in "?x_kk(x_ij=a_ij)"'
    lhs = self.parse_lhs_arg(seen_lhs)
    self._consume_if_whitespace()
    self._consume_assert_equal('=')
    self._consume_if_whitespace()
    rhs = self.parse_subexpression()
    if set(lhs.indices) != set(rhs.indices):
      raise _IntermediateError('Left and right hand side should have the same indices, got {!r} and {!r}.'.format(lhs.indices, rhs.indices))
    rhs = rhs.transpose(lhs.indices)
    return lhs, rhs

  @highlight
  def parse_term(self):
    'parse a term, e.g. "a b_i (2 c_i + 1)"'
    value = self.parse_numerator()
    if self._next_non_whitespace.type == '/':
      self._consume_assert_whitespace()
      token = self._consume()
      assert token.type == '/'
      self._consume_assert_whitespace()
      denominator = self.parse_denominator()
      value /= denominator
    return value

  @highlight
  def parse_subexpression(self):
    'parse a scope: the entire expression or a subexpression between parentheses'
    self._consume_if_whitespace()
    negate = self._next.type == '-'
    if negate:
      self._consume()
      self._consume_if_whitespace()
    value = self.parse_term()
    if negate:
      value = -value
    while self._next_non_whitespace.type not in ('|', 'EOF', '_', ')', ']', '}', '>', ','):
      self._consume_assert_whitespace()
      op_token = self._consume()
      if op_token.type not in ('+', '-'):
        raise _IntermediateError('Expected {!r} or {!r}.'.format('+', '-'), at=op_token.pos, count=len(op_token.data))
      self._consume_assert_whitespace()
      r_value = self.parse_term()
      value = {'+': value.__add__, '-': value.__sub__}[op_token.type](r_value)
    self._consume_if_whitespace()
    return value

  @highlight
  def tokenize(self):
    'subdivide :attr:`expression` into indivisible tokens'
    pos = 0
    tokens = [_Token('BOF', '', pos)]
    while pos < len(self.expression):
      m = re.match(r'\s+', self.expression[pos:])
      if m:
        tokens.append(_Token('whitespace', m.group(0), pos))
        pos += m.end()
        continue
      if self.expression[pos] in '+-^/|=[]{}()<>,':
        tokens.append(_Token(self.expression[pos], self.expression[pos], pos))
        pos += 1
        continue
      # Either the single character '$' (listed in `eye_symbols`, but not a
      # letter) or a name, optionally prefixed with '?' to denote an argument.
      m = re.match(r'[$]|[?]?[a-zA-Zα-ωΑ-Ω][a-zA-Zα-ωΑ-Ω0-9]*', self.expression[pos:])
      if m:
        if m.group(0) in self.eye_symbols:
          tokens.append(_Token('eye', m.group(0), pos))
        elif m.group(0) in self.normal_symbols:
          tokens.append(_Token('normal', m.group(0), pos))
        else:
          tokens.append(_Token('variable', m.group(0), pos))
        pos += m.end()
        continue
      # A decimal number with at least one digit, e.g. '1.2', '1.' or '.2'.
      # A lone '.' is not a number.
      m = re.match(r'[0-9]+[.][0-9]*|[.][0-9]+', self.expression[pos:])
      if m:
        if m.group(0).startswith('0') and not m.group(0).startswith('0.'):
          raise _IntermediateError('Leading zeros are forbidden.', at=pos, count=len(m.group(0)))
        tokens.append(_Token('float', m.group(0), pos))
        pos += m.end()
        continue
      m = re.match(r'[0-9]+', self.expression[pos:])
      if m:
        if m.group(0).startswith('0') and not m.group(0) == '0':
          raise _IntermediateError('Leading zeros are forbidden.', at=pos, count=len(m.group(0)))
        tokens.append(_Token('int', m.group(0), pos))
        pos += m.end()
        continue
      if self.expression[pos] == '_':
        # An underscore is followed by an optional geometry ('x_'), then
        # indices and/or a gradient; at least one of the latter is required.
        pos += 1
        parts = 0
        m = re.match(r'[a-zA-Zα-ωΑ-Ω][a-zA-Zα-ωΑ-Ω0-9]*_', self.expression[pos:])
        if m:
          withgeom = m.group(0)[:-1]
          tokens.append(_Token('geometry', m.group(0)[:-1], pos))
          pos += m.end()
        else:
          withgeom = None
        m = re.match(r'[a-zA-Z0-9]+', self.expression[pos:])
        if m:
          tokens.append(_Token('indices', m.group(0), pos))
          pos += m.end()
          parts += 1
        m_arg = re.match(r',[?][a-zA-Zα-ωΑ-Ω][a-zA-Zα-ωΑ-Ω0-9]*(_[a-zA-Z0-9]+)?', self.expression[pos:])
        m_geom = re.match(r'[,;]([a-zA-Zα-ωΑ-Ω][a-zA-Zα-ωΑ-Ω0-9]*_)?([a-zA-Z0-9]+)', self.expression[pos:])
        if m_arg:
          tokens.append(_Token('gradient', m_arg.group(0), pos))
          pos += m_arg.end()
          parts += 1
        elif m_geom:
          if withgeom is not None and not m_geom.group(1):
            variant_geom = m_geom.group(0)[0] + withgeom + '_' + m_geom.group(2)
            variant_default = m_geom.group(0)[0] + self.default_geometry_name + '_' + m_geom.group(2)
            raise _IntermediateError('Missing geometry, e.g. {!r} or {!r}.'.format(variant_geom, variant_default), at=pos)
          tokens.append(_Token('gradient', m_geom.group(0), pos))
          pos += m_geom.end()
          parts += 1
        if parts == 0:
          raise _IntermediateError('Missing indices.', at=pos)
        continue
      raise _IntermediateError('Unknown symbol: {!r}.'.format(self.expression[pos]), at=pos)
    tokens.append(_Token('EOF', '', pos))
    self._tokens = tokens
    self._index = 0
def _replace_lengths(ast, lengths):
  '''Substitute concrete lengths for :class:`_Length` placeholders in ``ast``.

  The tree is walked recursively. A node whose opcode is not ``None`` is
  rebuilt with each of its operands processed in turn. A constant node
  ``(None, value)`` whose value is a :class:`_Length` is replaced by
  ``_(lengths[value])``. Any other constant node is returned unchanged.

  Args
  ----
  ast : :class:`tuple`
      Abstract syntax tree in the format produced by :func:`parse`.
  lengths : :class:`dict`
      Mapping from :class:`_Length` placeholders to concrete lengths. A
      placeholder that is missing from this mapping raises :class:`KeyError`.

  Returns
  -------
  :class:`tuple`
      The AST with all placeholders replaced. The input is not modified.
  '''
  opcode = ast[0]
  if opcode is None:
    # Leaf: a verbatim constant, possibly a length placeholder.
    constant = ast[1]
    return _(lengths[constant]) if isinstance(constant, _Length) else ast
  # Interior node: keep the opcode and rebuild every operand.
  return (opcode, *(_replace_lengths(operand, lengths) for operand in ast[1:]))
def parse(expression, variables, functions, indices, arg_shapes=None, default_geometry_name='x'):
  '''Parse ``expression`` and return AST.

  This function parses a tensor expression with `Einstein Summation
  Convention`_ stored in a :class:`str` and returns an Abstract Syntax Tree
  (AST). The syntax of ``expression`` is as follows:

  * **Integers** or **decimal numbers** are denoted in the usual way.
    Examples: ``1``, ``1.2``, ``.2``. A number may not start with a zero
    unless that zero is the entire number or is directly followed by a dot:
    ``0`` and ``0.1`` are valid, but ``01`` is not.

  * **Variables** are denoted with a string of alphanumeric characters. The
    first character may not be a numeral. Unlike Python variables,
    underscores are not allowed, as they have a special meaning. If the
    variable is an array with one or more axes, all those axes should be
    labeled with a latin character, the index, and appended to the variable
    with an underscore. For example, an array ``a`` with two axes can be
    denoted with ``a_ij``. Optionally, a single numeral may be used to
    select an item at the concerning axis. Example: in ``a_i0`` the first
    axis of ``a`` is labeled ``i`` and the first element of the second axis
    is selected. If the same index occurs twice, the trace is taken along
    the concerning axes. Example: the trace of the first and third axes of
    ``b`` is denoted by ``b_iji``. It is invalid to specify an index more
    than twice. The following names cannot be used as variables: ``n``,
    ``δ``, ``$``. The variable named ``x``, or the value of argument
    ``default_geometry_name``, has a special meaning, detailed below.

  * A term, the **product** of two or more arrays or scalars, is denoted by
    space-separated variables, constants or compound expressions. Example:
    ``a b c`` denotes the product of the scalars ``a``, ``b`` and ``c``. A
    term may start with a number, but a number is not allowed in other parts
    of the term. Example: ``2 a`` denotes two times ``a``; ``2 2 a`` and
    ``2 a 2`` are invalid. When two arrays in a term have the same index,
    this index is summed. Example: ``a_i b_i`` denotes the inner product of
    ``a`` and ``b``, and ``A_ij b_j`` a matrix-vector product. It is not
    allowed to use an index more than twice in a term.

  * The operator ``/`` denotes a **fraction**. Example: in ``a b / c d``,
    ``a b`` is the numerator and ``c d`` the denominator. Both the numerator
    and the denominator may start with a number. Example: ``2 a / 3 b``. The
    denominator must be a scalar. Example: ``2 / a_i b_i`` is valid, but
    ``2 a_i / b_i`` is not.

    .. warning::

      This syntax is different from the Python syntax. In Python,
      ``a*b / c*d`` is mathematically equivalent to ``a*b*d/c``.

  * The operators ``+`` and ``-`` denote **add** and **subtract**. Both
    operators should be surrounded by whitespace, e.g. ``a + b``. Both
    operands should have the same shape. Example: ``a_ij + b_i c_j`` is
    valid, provided that the lengths of the axes with the same indices
    match, but ``a_ij + b_i`` is invalid. At the beginning of an expression
    or a compound, ``-`` may be used to negate the following term. Example:
    in ``-a b + c`` the term ``a b`` is negated before adding ``c``. It is
    not allowed to negate other terms: ``a + -b`` is invalid, and so is
    ``a -b``.

  * An expression surrounded by parentheses is a **compound expression** and
    can be used as a single entity in a term. Example: ``(a_i + b_i) c_i``
    denotes the inner product of ``a_i + b_i`` with ``c_i``.

  * **Exponentiation** is denoted by a ``^``, where the left and right
    operands should be a number, variable or compound expression and the
    right operand should be a scalar. Example: ``a^2`` denotes the square of
    ``a``, ``a^-2`` denotes ``a`` to the power ``-2`` and ``a^(1 / 2)`` the
    square root of ``a``.

  * An **argument** is denoted by a name — following the same rules as a
    variable name — prefixed with a question mark. An argument is a scalar
    or array with a yet unknown value. Example: ``basis_i ?coeffs_i``
    denotes the inner product of a basis with unknown coefficient vector
    ``?coeffs``. If possible, the shape of the argument is deduced from the
    expression. In the previous example the shape of ``?coeffs`` is equal to
    the shape of ``basis``. If the shape cannot be deduced from the
    expression, the shape should be defined manually via ``arg_shapes``.
    Arguments and variables live in separate namespaces: ``?x`` and ``x``
    are different entities.

  * An argument may be **substituted** by appending ``(arg = value)``
    without whitespace to a variable or compound expression, where ``arg``
    is an argument and ``value`` the substitution. The substitution applies
    to that variable or compound expression only. The value may be an
    expression. Example: ``2 ?x(x = 3 + y)`` is equivalent to ``2 (3 + y)``
    and ``2 ?x(x=y) + 3`` is equivalent to ``2 (y) + 3``. It is possible to
    apply multiple substitutions. Example: ``(?x + ?y)(x = 1, y = 2)`` is
    equivalent to ``1 + 2``.

  * The **gradient** of a variable to the default geometry — the default
    geometry is variable ``x`` unless overridden by the argument
    ``default_geometry_name`` — is denoted by an underscore, a comma and an
    index. If the variable is an array with one or more axes, the comma
    directly follows its indices and no extra underscore is written.
    Example: ``a_,i`` denotes the gradient of the scalar ``a`` to the
    geometry and ``b_i,j`` the gradient of vector ``b``. The gradient of a
    compound expression is denoted by an underscore, a comma and an index.
    Example: ``(a_i + b_j)_,k`` denotes the gradient of ``a_i + b_j``. The
    usual summation rules apply and it is allowed to use a numeral as index.
    The **surface gradient** is denoted with a semicolon instead of a comma,
    but otherwise follows the same rules as the gradient. Example:
    ``a_i;j`` is the surface gradient of ``a_i`` to the geometry. It is also
    possible to take the gradient to another geometry by appending the name
    of the geometry, which should exist as a variable, and an underscore
    directly after the comma or semicolon. Example: ``a_i,altgeom_j``
    denotes the gradient of ``a_i`` to ``altgeom``, and the gradient axis
    has index ``j``. Furthermore, it is possible to take the **derivative**
    to an argument by adding the argument with appropriate indices after the
    comma. Example: ``(?x^2)_,?x`` denotes the derivative of ``?x^2`` to
    ``?x``, which is equivalent to ``2 ?x``, and ``(?y_i ?y_i)_,?y_j`` is the
    derivative of ``?y_i ?y_i`` to ``?y_j``, which is equivalent to
    ``2 ?y_j``.

  * The **normal** of the default geometry is denoted by ``n_i``, where the
    index ``i`` may be replaced with an index of choice. The normal with
    respect to a different geometry is denoted by appending an underscore
    and the name of the geometry right after ``n``. Example: ``n_altgeom_j``
    is the normal with respect to geometry ``altgeom``.

  * A **dirac** is denoted by ``δ`` or ``$`` and takes two indices. The
    shape of the dirac is deduced from the expression. Example: let ``A`` be
    a square matrix with three rows and columns, then ``δ_ij`` in
    ``(A_ij - λ δ_ij) x_j`` has three rows and columns as well.

  * An expression surrounded by square brackets or curly braces denotes the
    **jump** or **mean**, respectively, of the enclosed expression. Example:
    ``[ a_i ]`` denotes the jump of ``a_i`` and ``{ a_i + b_i }`` denotes the
    mean of ``a_i + b_i``.

  * A **function call** is denoted by a name — following the same rules as
    for a variable name — directly followed by the left parenthesis ``(``,
    without a space. The arguments to the function are separated by a comma
    and at least one space. The function is applied pointwise to the
    arguments and all arguments should have the same shape. Example:
    ``f(x_i, y_i)`` denotes the call to function ``f`` with arguments
    ``x_i`` and ``y_i``. Functions and variables share a namespace: defining
    a variable with the same name as a function renders the function
    inaccessible.

  * A **stack** of two or more arrays along an axis is denoted by a ``<``,
    followed by comma-and-space separated arrays, followed by ``>`` and an
    index. If an argument does not have an axis with the specified stack
    index, the argument is expanded with an axis of length one. Besides the
    stack axis, all arguments should have the same shape. Example:
    ``<1, x_i>_i``, with ``x`` a vector of length three, creates an array
    with components ``1``, ``x_0``, ``x_1``, ``x_2``.

  .. _`Einstein Summation Convention`: https://en.wikipedia.org/wiki/Einstein_notation

  Args
  ----
  expression : :class:`str`
      The expression to parse. See :mod:`~nutils.expression` for the
      expression syntax.
  variables : :class:`dict` of :class:`str` and :class:`nutils.function.Array` pairs
      A :class:`dict` of variable names and array pairs. All variables used
      in the ``expression`` should exist in ``variables``.
  functions : :class:`dict` of :class:`str` and :class:`int` pairs
      A :class:`dict` of function names and number of arguments pairs. All
      functions used in the ``expression`` should exist in ``functions``.
  indices : :class:`str` or ``None``
      The indices used for aligning the resulting array. For example, let
      ``expression`` be ``'a_ij'``. If ``indices`` is ``'ij'``, then the
      returned array is simply ``variables['a']``, but if ``indices`` is
      ``'ji'`` the transpose of ``variables['a']`` is returned. All indices
      of the ``expression`` should be listed precisely once. If ``None``, no
      alignment is performed. This is only allowed when the resulting array
      has at most one axis; otherwise :class:`AmbiguousAlignmentError` is
      raised.
  arg_shapes : :class:`dict` of :class:`str` and :class:`tuple` of :class:`int`\\s pairs, optional
      A :class:`dict` of argument names and shapes. If ``expression``
      contains an argument not present in ``arg_shapes``, the shape will be
      deduced from the expression and added to a copy of ``arg_shapes``.
      The dict passed in is not modified by this function. Default: empty.
  default_geometry_name : :class:`str`
      The name of the default geometry variable. When computing a gradient
      or the normal, e.g. ``'f_,i'`` or ``'n_i'``, this variable is used as
      the geometry, unless the geometry is explicitly mentioned in the
      expression. Default: ``'x'``.

  Returns
  -------
  ast : :class:`tuple`
      The parsed ``expression`` as an abstract syntax tree (AST). The AST is
      a :class:`tuple` of an opcode and arguments. The special opcode
      ``None`` indicates that the single argument is used verbatim. All
      other opcodes have AST as arguments. The following opcodes exist::

          (None, const)
          ('group', group)
          ('arg', name, *shape)
          ('substitute', array, arg, value)
          ('call', func, arg)
          ('eye', length)
          ('normal', geom)
          ('getitem', array, dim, index)
          ('trace', array, n1, n2)
          ('sum', array, axis)
          ('concatenate', *args)
          ('grad', array, geom)
          ('surfgrad', array, geom)
          ('derivative', func, target)
          ('append_axis', array, length)
          ('transpose', array, trans)
          ('jump', array)
          ('mean', array)
          ('neg', array)
          ('add', left, right)
          ('sub', left, right)
          ('mul', left, right)
          ('truediv', left, right)
          ('pow', left, right)

  arg_shapes : :class:`dict` of :class:`str` and :class:`tuple` of :class:`int`\\s pairs
      A copy of ``arg_shapes`` updated with shapes of arguments present in
      this ``expression``.

  Raises
  ------
  AmbiguousAlignmentError
      If ``indices`` is ``None`` and the resulting array has more than one
      axis.
  ExpressionSyntaxError
      If aligning the result with ``indices`` fails, or if the length of
      some axis cannot be determined from the expression.
  '''

  if arg_shapes is None:
    # Fresh dict per call. A shared ``{}`` default would be one object for
    # all calls, and anything stored into it would leak into later calls.
    arg_shapes = {}

  parser = _ExpressionParser(expression, variables, functions, arg_shapes, default_geometry_name)
  parser.tokenize()
  value = parser.parse_subexpression()
  parser._consume_assert_equal('EOF', msg='Unexpected symbol at end of expression.')

  # Align the result. Without explicit indices only arrays with at most one
  # axis are accepted, since any other axis order would be a guess.
  if indices is None:
    if value.ndim > 1:
      raise AmbiguousAlignmentError(
        'Cannot unambiguously align the array because the array has more than one dimension.\n'
        + expression + '\n'
        + '^'*len(expression))
    ast = value.ast
  else:
    try:
      ast = value.transpose(indices).ast
    except _IntermediateError as e:
      # The error has no location within the expression, so the whole
      # expression is highlighted.
      raise ExpressionSyntaxError(e.msg + '\n' + expression + '\n' + '^'*len(expression)) from e

  # Each group in ``value.linked_lengths`` holds :class:`_Length`
  # placeholders and at most one concrete length. Every member of a group
  # resolves to that concrete length. Groups without one remain
  # undetermined and are reported at their placeholders' positions.
  lengths = {}
  undetermined = set()
  for group in value.linked_lengths:
    val = None
    for item in group:
      if not isinstance(item, _Length):
        # Internal invariant: a group carries at most one concrete length.
        assert val is None
        val = item
    if val is None:
      undetermined.update(length.pos for length in group)
    else:
      lengths.update((length, val) for length in group)
  if undetermined:
    # Point at the left-most axis whose length could not be determined.
    pos = min(undetermined)
    raise ExpressionSyntaxError('Length of axis cannot be determined from the expression.' + '\n' + expression + '\n' + ' '*pos + '^')

  # Return a copy of the caller's shapes, extended with the resolved shapes
  # of every argument the parser encountered.
  arg_shapes = dict(arg_shapes)
  for arg, shape in parser.arg_shapes.items():
    arg_shapes[arg] = tuple(lengths.get(length, length) for length in shape)
  return _replace_lengths(ast, lengths), arg_shapes
# vim:shiftwidth=2:softtabstop=2:expandtab:foldmethod=indent:foldnestmax=2