Source code for rofunc.utils.robolab.formatter.mjcf_parser.element

# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Classes to represent MJCF elements in the object model."""

import collections
import copy
import functools
import os
import sys

import numpy as np
from lxml import etree

from rofunc.utils.robolab.formatter.mjcf_parser import attribute as attribute_types
from rofunc.utils.robolab.formatter.mjcf_parser import base
from rofunc.utils.robolab.formatter.mjcf_parser import constants
from rofunc.utils.robolab.formatter.mjcf_parser import copier
from rofunc.utils.robolab.formatter.mjcf_parser import debugging
from rofunc.utils.robolab.formatter.mjcf_parser import namescope
from rofunc.utils.robolab.formatter.mjcf_parser import schema
from rofunc.utils.robolab.formatter.mjcf_parser import util

# Keep a handle on the builtin `property`: this module rebinds the name
# `property` further down to an instrumented decorator that is built on top of
# the builtin one.
_raw_property = property  # pylint: disable=invalid-name

# Binary unit suffixes accepted by `_to_bytes`, in increasing order of
# magnitude: the suffix at index `i` scales a value by `2 ** (10 * (i + 1))`.
_UNITS = ('K', 'M', 'G', 'T', 'P', 'E')


def _to_bytes(value_str):
    """Converts a `str` value representing a size in bytes to `int`.

    A string made up only of digits is taken as a plain byte count. Otherwise
    the last character is read (case-insensitively) as a binary unit suffix from
    `_UNITS` and everything before it as an integer count of that unit, so
    `'2K'` is 2048 bytes.

    Args:
      value_str: `str` value to be converted.

    Returns:
      `int` corresponding size in bytes.

    Raises:
      ValueError: if the `str` value passed has an unsupported unit, or if the
        part preceding the unit is not an integer (this includes the empty
        string and a bare unit such as `'M'`).
    """
    has_unit = not value_str.isdigit()
    number_str = value_str[:-1] if has_unit else value_str
    try:
        value_int = int(number_str)
    except ValueError as err:
        # Without this, the caller only sees "invalid literal for int()", which
        # gives no hint about which setting was malformed.
        raise ValueError(
            f'`size.memory` should be an integer optionally followed by one of '
            f'[{", ".join(_UNITS)}], got {value_str!r}') from err
    if not has_unit:
        return value_int

    # `value_str` cannot be empty here: an empty numeric part was rejected above.
    unit = value_str[-1].upper()
    if unit not in _UNITS:
        raise ValueError(
            f'unit of `size.memory` should be one of [{", ".join(_UNITS)}], got'
            f' {unit}')
    power = 10 * (_UNITS.index(unit) + 1)
    return value_int * 2 ** power


def _max_bytes(first, second):
    """Returns the larger of two size strings as a string of plain bytes.

    Both arguments are converted with `_to_bytes`; the result is the decimal
    string of the larger byte count, so any unit suffix is not preserved.
    """
    sizes_in_bytes = (_to_bytes(first), _to_bytes(second))
    return str(max(sizes_in_bytes))


# Maps the name of a conflict behavior to the callable that picks the winning
# value out of two candidates.
# NOTE(review): presumably keyed by an attribute's `conflict_behavior`; the
# lookup is not in this part of the file -- confirm against its users.
_CONFLICT_BEHAVIOR_FUNC = {'min': min, 'max': max, 'max_bytes': _max_bytes}


def property(method):  # pylint: disable=redefined-builtin
    """Modifies `@property` to keep track of any `AttributeError` raised.

    Our `Element` implementations overrides the `__getattr__` method. This does
    not interact well with `@property`: if a `property`'s code is buggy so as to
    raise an `AttributeError`, then Python would silently discard it and redirect
    to our `__getattr__` instead, leading to an uninformative stack trace. This
    makes it very difficult to debug issues that involve properties.

    To remedy this, we modify `@property` within this module to store any
    `AttributeError` raised within the respective `Element` object. Then, in our
    `__getattr__` logic, we could re-raise it to preserve the original stack
    trace.

    The reason that this is not implemented as a different decorator is that we
    could accidentally use @property on a new method. This would work fine until
    someone triggers a subtle bug. This is when a proper trace would be most
    useful, but we would still end up with a strange undebuggable stack trace
    anyway.

    Note that at the end of this module, we have a `del property` to prevent this
    override from being broadcasted externally.

    The wrapper carries over the decorated method's metadata, so the resulting
    property keeps the method's docstring.

    Args:
      method: The method that is being decorated.

    Returns:
      A `property` corresponding to the decorated method.
    """

    @functools.wraps(method)
    def _mjcf_property(self):
        try:
            return method(self)
        except BaseException as err:  # Always re-raised below.
            # Skip this wrapper's own frame so that the traceback starts at the
            # decorated method.
            err_with_next_tb = err.with_traceback(err.__traceback__.tb_next)
            if isinstance(err, AttributeError):
                # Python falls back to `__getattr__` after this; stash the error
                # so that `__getattr__` can re-raise it unchanged.
                self._last_attribute_error = err_with_next_tb  # pylint: disable=protected-access
            raise err_with_next_tb  # pylint: disable=raise-missing-from

    return _raw_property(_mjcf_property)


def _make_element(spec, parent, attributes=None):
    """Instantiates the element class that matches a spec and its parent.

    The rules are tried in order and the first match wins:
      1. A spec named `constants.WORLDBODY`, or one named `constants.SITE` whose
         parent's tag is `constants.BODY` or `constants.WORLDBODY`, yields an
         `_AttachableElement`.
      2. Any spec whose parent is an `_AttachmentFrame` yields an
         `_AttachmentFrameChild`.
      3. A spec named `constants.DEFAULT` yields a `_DefaultElement`.
      4. A spec named `constants.ACTUATOR` yields an `_ActuatorElement`.
      5. Anything else yields a plain `_ElementImpl`.

    Args:
      spec: The specification of the element to create.
      parent: The element that the new element belongs to.
      attributes: (optional) Initial attributes, passed on to the constructor.

    Returns:
      The newly constructed element.
    """
    tag = spec.name
    is_site_on_a_body = (
        tag == constants.SITE
        and parent.tag in (constants.BODY, constants.WORLDBODY))

    if tag == constants.WORLDBODY or is_site_on_a_body:
        element_class = _AttachableElement
    elif isinstance(parent, _AttachmentFrame):
        element_class = _AttachmentFrameChild
    elif tag == constants.DEFAULT:
        element_class = _DefaultElement
    elif tag == constants.ACTUATOR:
        element_class = _ActuatorElement
    else:
        element_class = _ElementImpl
    return element_class(spec, parent, attributes)


# Tags of the elements whose `name` falls back to the base name of their `file`
# attribute (directory and extension stripped) when no `name` is given; the
# fallback itself is applied in `_ElementImpl.__init__`.
_DEFAULT_NAME_FROM_FILENAME = frozenset(['mesh', 'hfield', 'texture'])


class _ElementImpl(base.Element):
    """Actual implementation of a generic MJCF element object."""
    __slots__ = ['__weakref__', '_spec', '_parent', '_attributes', '_children',
                 '_own_attributes', '_attachments', '_is_removed', '_init_stack',
                 '_is_worldbody', '_cached_namescope', '_cached_root',
                 '_cached_full_identifier', '_cached_revision',
                 '_last_attribute_error']

    def __init__(self, spec, parent, attributes=None):
        """Initializes the element, its pre-created children and its attributes.

        Args:
          spec: The specification of this element. Its `name`, `children` and
            `attributes` are read here.
          parent: The parent element, or a falsy value if there is none. When
            given, the namescope and the root are taken from it.
          attributes: (optional) A dict mapping attribute names to their initial
            values. Note that the dict is modified in place: a default `name`
            may be added to it, and a `constants.DCLASS` entry is moved to
            `constants.CLASS`.

        Raises:
          AttributeError: If `attributes` contains a name that is not part of
            the spec's attributes.
          Exception: An error raised while one of the attribute objects is being
            created is re-raised as the same exception type, with the names of
            the attribute and of the element prepended to the message.
        """
        attributes = attributes or {}

        # For certain `asset` elements the `name` attribute can be omitted, in which
        # case the name will be the filename without the leading path and extension.
        # See http://www.mujoco.org/book/XMLreference.html#asset.
        if ('name' not in attributes
                and 'file' in attributes
                and spec.name in _DEFAULT_NAME_FROM_FILENAME):
            _, filename = os.path.split(attributes['file'])
            basename, _ = os.path.splitext(filename)
            attributes['name'] = basename

        self._spec = spec
        self._parent = parent
        self._attributes = collections.OrderedDict()
        # Only set while `_attributes` is aliased to another dict, see
        # `_alias_attributes_dict` and `_restore_attributes_dict`.
        self._own_attributes = None
        self._children = []
        self._attachments = collections.OrderedDict()
        self._is_removed = False
        self._is_worldbody = (self.tag == 'worldbody')

        # Without a parent, `_cached_namescope` and `_cached_root` are left unset
        # by this constructor.
        if self._parent:
            self._cached_namescope = self._parent.namescope
            self._cached_root = self._parent.root
        # A revision of -1 forces `full_identifier` to be computed on first use.
        self._cached_full_identifier = ''
        self._cached_revision = -1

        # Filled in by this module's `property` decorator, consumed by
        # `__getattr__`.
        self._last_attribute_error = None

        # `_init_stack` is only recorded in debug mode; otherwise it stays unset.
        if debugging.debug_mode():
            self._init_stack = debugging.get_current_stack_trace()

        with debugging.freeze_current_stack_trace():
            # Children that are neither repeated nor on-demand always exist, so
            # they are created up front.
            for child_spec in self._spec.children.values():
                if not (child_spec.repeated or child_spec.on_demand):
                    self._children.append(_make_element(spec=child_spec, parent=self))

            # Callers use `constants.DCLASS`; internally the attribute is stored
            # under `constants.CLASS`.
            if constants.DCLASS in attributes:
                attributes[constants.CLASS] = attributes[constants.DCLASS]
                del attributes[constants.DCLASS]

            # Reject unknown attribute names before any attribute is created.
            for attribute_name in attributes.keys():
                self._check_valid_attribute(attribute_name)

            for attribute_spec in self._spec.attributes.values():
                value = None
                # Some Reference attributes refer to a namespace that is specified
                # via another attribute. We therefore have to set things up for
                # the additional indirection.
                if attribute_spec.type is attribute_types.Reference:
                    reference_namespace = (
                        attribute_spec.other_kwargs['reference_namespace'])
                    if reference_namespace.startswith(
                            constants.INDIRECT_REFERENCE_NAMESPACE_PREFIX):
                        # Work on a copy so that the spec itself is left untouched.
                        attribute_spec = copy.deepcopy(attribute_spec)
                        namespace_attrib_name = reference_namespace[
                                                len(constants.INDIRECT_REFERENCE_NAMESPACE_PREFIX):]
                        # The attribute object created earlier in this loop for
                        # `namespace_attrib_name` stands in for the namespace.
                        attribute_spec.other_kwargs['reference_namespace'] = (
                            self._attributes[namespace_attrib_name])
                if attribute_spec.name in attributes:
                    value = attributes[attribute_spec.name]
                try:
                    self._attributes[attribute_spec.name] = attribute_spec.type(
                        name=attribute_spec.name,
                        required=attribute_spec.required,
                        conflict_allowed=attribute_spec.conflict_allowed,
                        conflict_behavior=attribute_spec.conflict_behavior,
                        parent=self, value=value, **attribute_spec.other_kwargs)
                except:
                    # On failure, clear attributes already created
                    for attribute_obj in self._attributes.values():
                        attribute_obj._force_clear()  # pylint: disable=protected-access
                    # Then raise a meaningful error
                    err_type, err, tb = sys.exc_info()
                    raise err_type(  # pylint: disable=raise-missing-from
                        f'during initialization of attribute {attribute_spec.name!r} of '
                        f'element <{self._spec.name}>: {err}').with_traceback(tb)

    def get_init_stack(self):
        """Gets the stack trace where this element was first initialized.

        Returns:
          The stack trace recorded by `__init__`, or `None` if debug mode is
          currently off or no stack trace was recorded for this element.
        """
        if not debugging.debug_mode():
            return None
        # `_init_stack` is a slot that `__init__` only fills in when debug mode
        # is on, so it is unset for an element created while debug mode was off.
        # Reading it directly would then raise an `AttributeError`.
        return getattr(self, '_init_stack', None)

    def get_last_modified_stacks_for_all_attributes(self):
        """Collects, per attribute, the stack trace of its last modification.

        Returns:
          An `OrderedDict` that maps each attribute name in the spec, in spec
          order, to the `last_modified_stack` of the matching attribute object.
        """
        stacks = collections.OrderedDict()
        for attribute_name in self._spec.attributes:
            attribute = self._attributes[attribute_name]
            stacks[attribute_name] = attribute.last_modified_stack
        return stacks

    def is_same_as(self, other):
        """Checks whether another element is semantically equivalent to this one.

        Two elements are considered equivalent if they have the same
        specification (i.e. same tag appearing in the same context), the same
        attribute values, and all of their children are equivalent. The ordering
        of non-repeated children is not important for this comparison, while
        the ordering of repeated children are important only amongst the same
        type* of children. In other words, for two bodies to be considered
        equivalent, their child sites must appear in the same order, and their
        child geoms must appear in the same order, but permutations between sites
        and geoms are disregarded. (The only exception is in tendon definition,
        where strict ordering of all children is necessary for equivalence.)

        *Note that the notion of "same type" in this function is very loose:
        for example different actuator element subtypes are treated as separate
        types when children ordering is considered. Therefore, two <actuator>
        elements might be considered equivalent even though they result in different
        orderings of `mjData.ctrl` when compiled. As it stands, this function
        is designed primarily as a testing aid and should not be used to guarantee
        that models are actually identical.

        Attributes whose value on this element is itself an element are compared
        through their `full_identifier`; if the other element has no value for
        such an attribute, the two elements are not equivalent.

        Args:
          other: An `mjcf.Element`

        Returns:
          `True` if `other` element is semantically equivalent to this one.
        """
        if other is None or other.spec != self._spec:
            return False

        for attribute_name in self._spec.attributes.keys():
            value = self._attributes[attribute_name].value
            other_value = getattr(other, attribute_name)
            if isinstance(value, base.Element):
                # We refer to an element but `other` refers to nothing: report a
                # mismatch rather than failing on `None.full_identifier`.
                if other_value is None:
                    return False
                if value.full_identifier != other_value.full_identifier:
                    return False
            elif not np.all(value == other_value):
                # `np.all` collapses the element-wise result of array-valued
                # attributes into a single boolean.
                return False

        # Direct children of the top-level tendon element are order-sensitive.
        if (self._parent and
                self._parent.tag == constants.TENDON and
                self._parent.parent == self.root):
            return self._tendon_has_same_children_as(other)
        else:
            return self._has_same_children_as(other)

    def _has_same_children_as(self, other):
        """Checks whether `other` has children equivalent to this element's.

        Children are compared kind by kind, following the child specs of this
        element. A non-repeated child must be equivalent to its counterpart (or
        both must be absent); repeated children must come in equal numbers and
        be pairwise equivalent in order. See `is_same_as` for the rationale
        behind this treatment of ordering.

        Args:
          other: An `mjcf.Element`

        Returns:
          A boolean
        """
        for child_name, child_spec in self._spec.children.items():
            mine = self.get_children(child_name)
            theirs = getattr(other, child_name)
            if child_spec.repeated:
                if len(mine) != len(theirs):
                    return False
                if not all(own.is_same_as(counterpart)
                           for own, counterpart in zip(mine, theirs)):
                    return False
            elif mine is None:
                if theirs is not None:
                    return False
            elif not mine.is_same_as(theirs):
                return False
        return True

    def _tendon_has_same_children_as(self, other):
        """Checks that `other` has the same children in exactly the same order.

        Unlike `_has_same_children_as`, all children are compared as one strictly
        ordered sequence, whatever their kind.

        Args:
          other: An `mjcf.Element`

        Returns:
          `True` if both elements have the same number of children and the
          children are pairwise equivalent, position by position.
        """
        own_children = self.all_children()
        other_children = other.all_children()
        # `zip` stops at the shorter sequence, so without this check an element
        # would compare equal to one that has extra trailing children.
        if len(own_children) != len(other_children):
            return False
        return all(child.is_same_as(other_child)
                   for child, other_child
                   in zip(own_children, other_children))

    def _alias_attributes_dict(self, other):
        """Makes this element use `other` as its attributes dict.

        The element's own dict is remembered the first time this is called, so
        that `_restore_attributes_dict` can bring it back; aliasing again while
        already aliased keeps the dict that was remembered first.

        Args:
          other: The attributes dict to use from now on.
        """
        already_aliased = self._own_attributes is not None
        if not already_aliased:
            self._own_attributes = self._attributes
        self._attributes = other

    def _restore_attributes_dict(self):
        """Undoes `_alias_attributes_dict`, keeping the aliased values.

        Every value in the currently aliased dict is first copied into the
        attribute of the same name in the element's own dict, which then becomes
        the active dict again. Does nothing if no dict is aliased.
        """
        own_attributes = self._own_attributes
        if own_attributes is None:
            return
        for name, aliased_attribute in self._attributes.items():
            own_attributes[name].value = aliased_attribute.value
        self._attributes = own_attributes
        self._own_attributes = None

    @property
    def tag(self):
        """The name of this element's spec, used as its XML tag by `to_xml`."""
        return self._spec.name

    @property
    def spec(self):
        """The specification object that this element was created from."""
        return self._spec

    @property
    def parent(self):
        """The parent element; `remove` resets it to `None` for detached elements."""
        return self._parent

    @property
    def namescope(self):
        """The namescope of this element, cached from the parent in `__init__`."""
        return self._cached_namescope

    @property
    def root(self):
        """The root element, cached from the parent in `__init__`."""
        return self._cached_root

    def prefixed_identifier(self, prefix_root):
        """Computes this element's identifier relative to `prefix_root`.

        Args:
          prefix_root: The namescope to be treated as root when the prefix is
            worked out.

        Returns:
          * For the worldbody: the namescope's full prefix, or `'world'` if that
            prefix is empty.
          * `None` for any other element whose spec defines no identifier.
          * Otherwise the XML string of the identifier attribute if it is
            non-empty, else the namescope's full prefix (or
            `constants.PREFIX_SEPARATOR` if the prefix is empty) followed by
            the element's default identifier.
        """
        if self._is_worldbody:
            world_prefix = self.namescope.full_prefix(prefix_root=prefix_root)
            return world_prefix or 'world'

        identifier_name = self._spec.identifier
        if not identifier_name:
            return None

        identifier_attribute = self._attributes[identifier_name]
        explicit_identifier = identifier_attribute.to_xml_string(
            prefix_root=prefix_root)
        if explicit_identifier:
            return explicit_identifier

        # Unnamed element: fall back to a generated identifier.
        scope_prefix = self.namescope.full_prefix(prefix_root=prefix_root)
        scope_prefix = scope_prefix or constants.PREFIX_SEPARATOR
        return scope_prefix + self._default_identifier

    @property
    def full_identifier(self):
        """Fully-qualified identifier used for this element in the generated XML.

        The value is cached and only recomputed, relative to the root of the
        namescope, when the namescope's revision is newer than the revision at
        which the cache was last filled.
        """
        scope = self.namescope
        cache_is_stale = scope.revision > self._cached_revision
        if cache_is_stale:
            self._cached_full_identifier = self.prefixed_identifier(
                prefix_root=scope.root)
            self._cached_revision = scope.revision
        return self._cached_full_identifier

    @property
    def _default_identifier(self):
        """The default identifier used if this element is not named by the user.

        It is built from the namespace of the element's spec and the element's
        position among the elements of that namespace found under the root,
        attachments excluded. `None` if the spec defines no identifier.
        """
        if not self._spec.identifier:
            return None
        namespace = self._spec.namespace
        same_namespace_elements = self.root.find_all(
            namespace, exclude_attachments=True)
        return '{separator}unnamed_{namespace}_{index}'.format(
            separator=constants.PREFIX_SEPARATOR,
            namespace=namespace,
            index=same_namespace_elements.index(self))

    def __dir__(self):
        """Lists regular class members plus the children and attributes of the spec.

        The attribute stored under `constants.CLASS` is listed under
        `constants.DCLASS`, which is the name it is accessed by.

        Returns:
          A sorted list of names.
        """
        names = set()
        # The MRO holds this class and every one of its ancestors.
        for klass in type(self).__mro__:
            names.update(vars(klass))
        names.update(self._spec.children)
        names.update(self._spec.attributes)
        if constants.CLASS in names:
            names.discard(constants.CLASS)
            names.add(constants.DCLASS)
        return sorted(names)

    def find(self, namespace, identifier):
        """Finds an element with a particular identifier.

        This gives direct access to an arbitrarily deeply nested child element
        by name, without walking the object tree by hand. The `namespace`
        argument states the kind of element to look for. In most cases it is the
        element's XML tag name, but for elements with several specialized tags
        it is the tag name of the most general element of that kind. For
        example, `namespace='joint'` would search for `<joint>` and
        `<freejoint>`, while `namespace='actuator'` would search for `<general>`,
        `<motor>`, `<position>`, `<velocity>`, and `<cylinder>`.

        An identifier containing `constants.PREFIX_SEPARATOR` is looked up in
        the attached model named by its leading part. Otherwise it is looked up
        in this element's namescope and, if this element has a parent, only
        accepted when the element found is a descendant of this one.

        Args:
          namespace: A string specifying the namespace being searched. See the
            explanation above.
          identifier: The identifier string of the desired element.

        Returns:
          An `mjcf.Element` object, or `None` if an element with the specified
          identifier is not found.

        Raises:
          ValueError: if either `namespace` or `identifier` is not a string, or if
            `namespace` is not a valid namespace.
        """
        if not isinstance(namespace, str):
            raise ValueError(
                '`namespace` should be a string: got {!r}'.format(namespace))
        if not isinstance(identifier, str):
            raise ValueError(
                '`identifier` should be a string: got {!r}'.format(identifier))
        if namespace not in schema.FINDABLE_NAMESPACES:
            raise ValueError('{!r} is not a valid namespace. Available: {}.'.format(
                namespace, schema.FINDABLE_NAMESPACES))

        separator = constants.PREFIX_SEPARATOR
        if separator in identifier:
            # Scoped identifier: delegate to the attached model of that scope.
            scope_name = identifier.split(separator)[0]
            try:
                attachment = self.namescope.get('attached_model', scope_name)
                return attachment.find(
                    namespace, identifier[(len(scope_name) + 1):])
            except (KeyError, ValueError):
                return None

        try:
            found_element = self.namescope.get(namespace, identifier)
        except KeyError:
            return None
        if found_element and self._parent:
            # Walk up from the match; running out of parents before reaching
            # `self` means that the match lives outside of this subtree.
            ancestor = found_element.parent
            while ancestor and ancestor != self:
                ancestor = ancestor.parent
            if not ancestor:
                return None
        return found_element

    def find_all(self, namespace,
                 immediate_children_only=False, exclude_attachments=False):
        """Finds all elements of a particular kind.

        The `namespace` argument states the kind of element to look for. In most
        cases it is the element's XML tag name, but for elements with several
        specialized tags it is the tag name of the most general element of that
        kind. For example, `namespace='joint'` would search for `<joint>` and
        `<freejoint>`, while `namespace='actuator'` would search for `<general>`,
        `<motor>`, `<position>`, `<velocity>`, and `<cylinder>`.

        Args:
          namespace: A string specifying the namespace being searched. See the
            explanation above.
          immediate_children_only: (optional) A boolean, if `True` then only
            the immediate children of this element are returned.
          exclude_attachments: (optional) A boolean, if `True` then elements
            belonging to attached models are excluded.

        Returns:
          A list of `mjcf.Element`, each match being followed by the matches
          found underneath it.

        Raises:
          ValueError: if `namespace` is not a valid namespace.
        """
        if namespace not in schema.FINDABLE_NAMESPACES:
            raise ValueError('{!r} is not a valid namespace. Available: {}'.format(
                namespace, schema.FINDABLE_NAMESPACES))

        if exclude_attachments:
            candidates = self._children
        else:
            candidates = self.all_children()

        matches = []
        for child in candidates:
            child_namespace = child.spec.namespace
            if namespace == child_namespace:
                is_match = True
            elif child_namespace and child_namespace.startswith(
                    namespace + constants.NAMESPACE_SEPARATOR):
                # Direct children of attachment frames have custom namespaces of
                # the form "joint@attachment_frame_<id>".
                is_match = True
            else:
                # Attachment frames are considered part of the "body" namespace.
                is_match = (namespace == constants.BODY
                            and isinstance(child, _AttachmentFrame))
            if is_match:
                matches.append(child)
            if not immediate_children_only:
                matches.extend(child.find_all(
                    namespace, exclude_attachments=exclude_attachments))
        return matches

    def enter_scope(self, scope_identifier):
        """Finds the root element of the given scope and returns it.

        The scope is looked up among the attached models of this element's
        namescope. A path made of several names joined by
        `constants.PREFIX_SEPARATOR` is resolved one name at a time, each name
        being looked up inside the scope found for the previous one.

        Args:
          scope_identifier: The path of the desired scope element.

        Returns:
          An `mjcf.Element` object, or `None` if a scope element with the
          specified path is not found.
        """
        separator = constants.PREFIX_SEPARATOR
        if separator in scope_identifier:
            scope_name = scope_identifier.split(separator)[0]
            remaining_path = scope_identifier[(len(scope_name) + 1):]
        else:
            scope_name = scope_identifier
            remaining_path = ''

        try:
            attachment = self.namescope.get('attached_model', scope_name)
        except KeyError:
            return None

        if remaining_path:
            return attachment.enter_scope(remaining_path)
        return attachment

    def _check_valid_attribute(self, attribute_name):
        """Ensures that `attribute_name` is an attribute defined by the spec.

        Raises:
          AttributeError: If the spec has no attribute of that name.
        """
        if attribute_name in self._spec.attributes:
            return
        message = '{!r} is not a valid attribute for <{}>'.format(
            attribute_name, self._spec.name)
        raise AttributeError(message)

    def _get_attribute(self, attribute_name):
        """Returns the value of an attribute after validating its name.

        Raises:
          AttributeError: If the spec has no attribute of that name.
        """
        self._check_valid_attribute(attribute_name)
        attribute = self._attributes[attribute_name]
        return attribute.value

    def get_attribute_xml_string(self,
                                 attribute_name,
                                 prefix_root=None,
                                 *,
                                 precision=constants.XML_DEFAULT_PRECISION,
                                 zero_threshold=0):
        """Returns the XML string representation of one attribute.

        Args:
          attribute_name: The name of the attribute.
          prefix_root: (optional) Passed on to the attribute's `to_xml_string`.
          precision: (optional) Passed on to the attribute's `to_xml_string`.
          zero_threshold: (optional) Passed on to the attribute's
            `to_xml_string`.

        Returns:
          Whatever the attribute's `to_xml_string` returns.

        Raises:
          AttributeError: If the spec has no attribute of that name.
        """
        self._check_valid_attribute(attribute_name)
        attribute = self._attributes[attribute_name]
        return attribute.to_xml_string(
            prefix_root, precision=precision, zero_threshold=zero_threshold)

    def get_attributes(self):
        """Returns the attributes of this element that currently have a value.

        Returns:
          An `OrderedDict` mapping attribute names to values, in spec order.
          Attributes whose value is `None` are left out, and the attribute
          stored under `constants.CLASS` is reported as `constants.DCLASS`.
        """
        attributes = collections.OrderedDict()
        for name in self._spec.attributes:
            # Read each value once: it serves both the `None` filter and the
            # result.
            value = self._get_attribute(name)
            if value is None:
                continue
            public_name = constants.DCLASS if name == constants.CLASS else name
            attributes[public_name] = value
        return attributes

    def _set_attribute(self, attribute_name, value):
        """Assigns a value to an attribute and bumps the namescope revision.

        Raises:
          AttributeError: If the spec has no attribute of that name.
        """
        self._check_valid_attribute(attribute_name)
        attribute = self._attributes[attribute_name]
        attribute.value = value
        # Invalidates identifiers cached against the previous revision.
        self.namescope.increment_revision()

    def set_attributes(self, **kwargs):
        """Assigns several attributes at once, all or nothing.

        If anything goes wrong partway through -- an invalid attribute name or a
        rejected value -- the attributes that were already assigned by this call
        are set back to their previous values before the error is raised.

        Args:
          **kwargs: Attribute names and their new values. `constants.DCLASS`
            may be used in place of `constants.CLASS`.

        Raises:
          AttributeError: If a name is not a valid attribute of this element.
          Exception: An error raised while a value is being assigned is
            re-raised as the same exception type, with the names of the
            attribute and of the element prepended to the message.
        """
        if constants.DCLASS in kwargs:
            kwargs[constants.CLASS] = kwargs.pop(constants.DCLASS)
        old_values = []

        def restore_old_values():
            # Puts back the previous value of every attribute assigned so far.
            for name, previous_value in old_values:
                self._set_attribute(name, previous_value)

        with debugging.freeze_current_stack_trace():
            for attribute_name, new_value in kwargs.items():
                try:
                    old_value = self._get_attribute(attribute_name)
                except BaseException:
                    # An invalid name must not leave earlier assignments behind.
                    restore_old_values()
                    raise
                try:
                    self._set_attribute(attribute_name, new_value)
                    old_values.append((attribute_name, old_value))
                except BaseException as err:
                    # On failure, restore old attribute values for those already
                    # set, then raise a meaningful error.
                    restore_old_values()
                    raise type(err)(  # pylint: disable=raise-missing-from
                        f'during assignment to attribute {attribute_name!r} of '
                        f'element <{self._spec.name}>: {err}'
                    ).with_traceback(err.__traceback__)

    def _remove_attribute(self, attribute_name):
        """Clears an attribute and bumps the namescope revision.

        Raises:
          AttributeError: If the spec has no attribute of that name.
        """
        self._check_valid_attribute(attribute_name)
        attribute = self._attributes[attribute_name]
        attribute.clear()
        self.namescope.increment_revision()

    def _check_valid_child(self, element_name):
        """Looks up the spec of a child element by its name.

        Returns:
          The child spec registered under `element_name`.

        Raises:
          AttributeError: If the spec allows no child of that name.
        """
        child_specs = self._spec.children
        try:
            child_spec = child_specs[element_name]
        except KeyError:
            raise AttributeError(  # pylint: disable=raise-missing-from
                '<{}> is not a valid child of <{}>'
                .format(element_name, self._spec.name))
        else:
            return child_spec

    def get_children(self, element_name):
        """Returns the child, or the children, with the given tag.

        Args:
          element_name: The tag of the requested child element(s).

        Returns:
          For a repeated child, an `_ElementListView` over the children of that
          kind. Otherwise the first child with that tag, or `None` if there is
          none and the child is created on demand.

        Raises:
          AttributeError: If `element_name` is not a valid child.
          RuntimeError: If a child that should have been created in `__init__`
            is missing.
        """
        child_spec = self._check_valid_child(element_name)
        if child_spec.repeated:
            return _ElementListView(spec=child_spec, parent=self)

        existing_child = next(
            (child for child in self._children if child.tag == element_name),
            None)
        if existing_child is not None:
            return existing_child
        if child_spec.on_demand:
            return None
        raise RuntimeError(
            'Cannot find the non-repeated child <{}> of <{}>. '
            'This should never happen, as we pre-create these in __init__. '
            'Please file an bug report. Thank you.'
            .format(element_name, self._spec.name))

    def add(self, element_name, **kwargs):
        """Add a new child element to this element, after the existing children.

        This is `insert` with `position=None`.

        Args:
          element_name: The tag of the element to add.
          **kwargs: Attributes of the new element being created.

        Raises:
          AttributeError: If the 'element_name' is not a valid child. An invalid
            attribute in `kwargs` is reported the same way by
            `_ElementImpl.__init__`.
          ValueError: If the child is not repeated and already exists.

        Returns:
          An `mjcf.Element` corresponding to the newly created child element.
        """
        return self.insert(element_name, position=None, **kwargs)

    def insert(self, element_name, position, **kwargs):
        """Creates a new child element at a given position.

        Args:
          element_name: The tag of the element to add.
          position: The index in this element's list of children at which the
            new element is inserted, or `None` to append it at the end.
          **kwargs: Attributes of the new element being created.

        Raises:
          AttributeError: If the 'element_name' is not a valid child. An invalid
            attribute in `kwargs` is reported the same way by
            `_ElementImpl.__init__`.
          ValueError: If the child is not repeated and already exists.

        Returns:
          An `mjcf.Element` corresponding to the newly created child element.
        """
        child_spec = self._check_valid_child(element_name)
        # An on-demand child may be created only while it does not exist yet.
        missing_on_demand = (
            child_spec.on_demand and self.get_children(element_name) is None)
        if not (child_spec.repeated or missing_on_demand):
            raise ValueError('A <{}> child already exists, please access it directly.'
                             .format(element_name))

        new_element = _make_element(child_spec, self, attributes=kwargs)
        if position is None:
            self._children.append(new_element)
        else:
            self._children.insert(position, new_element)
        self.namescope.increment_revision()
        return new_element

    def __getattr__(self, name):
        """Resolves names that are not regular Python attributes.

        Tried in order: a pending `AttributeError` recorded by this module's
        `property` decorator (re-raised as is), a child of the spec, an attribute
        of the spec, and `constants.DCLASS` as an alias for `constants.CLASS`.

        Raises:
          AttributeError: If nothing matches, or if a property recorded one.
        """
        pending_error = self._last_attribute_error
        if pending_error:
            # We got here through a @property raising AttributeError; hand that
            # very error back to the user, and only once.
            self._last_attribute_error = None
            raise pending_error  # pylint: disable=raising-bad-type

        spec = self._spec
        if name in spec.children:
            return self.get_children(name)
        if name in spec.attributes:
            return self._get_attribute(name)
        if name == constants.DCLASS and constants.CLASS in spec.attributes:
            return self._get_attribute(constants.CLASS)
        raise AttributeError('object has no attribute: {}'.format(name))

    def __setattr__(self, name, value):
        """Sets a slot, a settable property, or an attribute of the spec.

        A descriptor found on the class under `name` gets the first try. If
        there is none, or if it refuses the assignment with an `AttributeError`,
        the name is treated as an attribute of the spec, with `constants.DCLASS`
        standing for `constants.CLASS`.

        Raises:
          AttributeError: If the name is neither settable through a descriptor
            nor an attribute of the spec.
        """
        descriptor = getattr(type(self), name, None)
        if descriptor is not None:
            try:
                return descriptor.__set__(self, value)
            except AttributeError:
                # Not settable this way; fall through to the spec's attributes.
                pass

        if name == constants.DCLASS:
            attribute_name = constants.CLASS
        else:
            attribute_name = name
        if attribute_name not in self._spec.attributes:
            raise AttributeError("can't set attribute: {}".format(name))
        self._set_attribute(attribute_name, value)

    def __delattr__(self, name):
        """Removes a non-repeated child, or clears an attribute of the spec.

        Raises:
          AttributeError: If `name` refers to a repeated child (a collection,
            which cannot be deleted), or if it is neither a child nor an
            attribute of the spec.
        """
        child_specs = self._spec.children
        if name in child_specs:
            if child_specs[name].repeated:
                raise AttributeError(
                    '`{0}` is a collection of child elements, '
                    'which cannot be deleted. Did you mean to call `{0}.clear()`?'
                    .format(name))
            return self.get_children(name).remove()
        if name in self._spec.attributes:
            return self._remove_attribute(name)
        raise AttributeError('object has no attribute: {}'.format(name))

    def _check_attachments_on_remove(self, affect_attachments):
        """Verifies that a removal would not touch attachments unexpectedly.

        The check covers this element and, recursively, all of its children.

        Args:
          affect_attachments: Whether the caller agreed to affect attachments.

        Raises:
          ValueError: If an element in this subtree has attachments while
            `affect_attachments` is false.
        """
        has_attachments = bool(self._attachments)
        if has_attachments and not affect_attachments:
            raise ValueError(
                'please use remove(affect_attachments=True) as this will affect some '
                'attributes and/or children belonging to an attached model')
        for child in self._children:
            child._check_attachments_on_remove(affect_attachments)  # pylint: disable=protected-access

    def remove(self, affect_attachments=False):
        """Removes this element from the model.

        Attachments (only if `affect_attachments` is true) and children are
        removed first. A repeated or on-demand element is then taken out of its
        parent's children, loses its parent and is flagged as removed; any other
        element stays in place and merely has its attributes cleared.

        Args:
          affect_attachments: (optional) Whether elements belonging to attached
            models may be affected by the removal.

        Raises:
          ValueError: If an element in this subtree has attachments while
            `affect_attachments` is false.
        """
        self._check_attachments_on_remove(affect_attachments)
        if affect_attachments:
            for attachment in self._attachments.values():
                attachment.remove(affect_attachments=True)
        # Iterate over a snapshot: children may take themselves out of
        # `_children` while they are being removed.
        for child in tuple(self._children):
            child.remove(affect_attachments)

        detach_from_parent = self._spec.repeated or self._spec.on_demand
        if detach_from_parent:
            self._parent._children.remove(self)  # pylint: disable=protected-access
        for attribute in self._attributes.values():
            attribute._force_clear()  # pylint: disable=protected-access
        if detach_from_parent:
            self._parent = None
            self._is_removed = True
        self.namescope.increment_revision()

    @property
    def is_removed(self):
        """Whether this element has been removed (the `_is_removed` flag)."""
        return self._is_removed

    def all_children(self):
        """Returns this element's children plus repeated children of attachments.

        Returns:
          A new list holding this element's own children, followed, for each
          attached element in turn, by those of its `all_children()` whose spec is
          marked as repeated.
        """
        collected = list(self._children)
        for attached in self._attachments.values():
            collected.extend(
                child for child in attached.all_children()
                if child.spec.repeated)
        return collected

    def to_xml(self, prefix_root=None, debug_context=None,
               *,
               precision=constants.XML_DEFAULT_PRECISION,
               zero_threshold=0):
        """Builds the `etree._Element` tree that represents this MJCF element.

        Args:
          prefix_root: (optional) A `NameScope` object to be treated as root
            for the purpose of calculating the prefix. If `None`, this element's
            own namescope is used, so no prefix is included.
          debug_context: (optional) A `debugging.DebugContext` object to which
            the debugging information associated with the generated XML is written.
            This is intended for internal use within PyMJCF; users should never
            need to pass this argument manually.
          precision: (optional) Number of digits to output for floating point
            quantities.
          zero_threshold: (optional) When outputting XML, floating point quantities
            whose absolute value falls below this threshold will be treated as zero.

        Returns:
          An etree._Element object.
        """
        if not prefix_root:
            prefix_root = self.namescope
        number_format = {'precision': precision, 'zero_threshold': zero_threshold}
        xml_element = etree.Element(self._spec.name)
        self._attributes_to_xml(
            xml_element, prefix_root, debug_context, **number_format)
        self._children_to_xml(
            xml_element, prefix_root, debug_context, **number_format)
        return xml_element

    def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None,
                           *, precision, zero_threshold):
        """Writes this element's attributes onto `xml_element`.

        Attributes whose XML string is `None` are left out, except for the spec's
        identifier attribute, which then falls back to `self.full_identifier`.

        Args:
          xml_element: The XML element on which `set(name, value)` is called.
          prefix_root: Forwarded to each attribute's `to_xml_string`.
          debug_context: Unused.
          precision: Forwarded to each attribute's `to_xml_string`.
          zero_threshold: Forwarded to each attribute's `to_xml_string`.
        """
        del debug_context  # Unused.
        for attribute_name, attribute in self._attributes.items():
            xml_value = attribute.to_xml_string(prefix_root,
                                                precision=precision,
                                                zero_threshold=zero_threshold)
            if xml_value is not None:
                xml_element.set(attribute_name, xml_value)
            elif attribute_name == self._spec.identifier:
                xml_element.set(attribute_name, self.full_identifier)

    def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
                         *, precision, zero_threshold):
        """Appends the XML of this element's children to `xml_element`.

        A child is skipped when its XML has neither attributes nor sub-elements,
        unless its spec is marked as repeated or on-demand. In debug mode, and
        when a `debug_context` is given, a debug comment is appended right after
        each emitted child, and a copy of that comment becomes the first node of
        any child that has sub-elements.

        Args:
          xml_element: The XML element that receives the children.
          prefix_root: Forwarded to each child's `to_xml`.
          debug_context: (optional) Forwarded to each child's `to_xml` and used to
            register emitted children for debugging.
          precision: Forwarded to each child's `to_xml`.
          zero_threshold: Forwarded to each child's `to_xml`.
        """
        for child in self.all_children():
            child_xml = child.to_xml(prefix_root, debug_context,
                                     precision=precision,
                                     zero_threshold=zero_threshold)
            has_content = child_xml.attrib or len(child_xml)  # pylint: disable=g-explicit-length-test
            if not (has_content or child.spec.repeated or child.spec.on_demand):
                continue
            xml_element.append(child_xml)
            if not (debugging.debug_mode() and debug_context):
                continue
            debug_comment = debug_context.register_element_for_debugging(child)
            xml_element.append(debug_comment)
            if len(child_xml) > 0:  # pylint: disable=g-explicit-length-test
                child_xml.insert(0, copy.deepcopy(debug_comment))

    def to_xml_string(self, prefix_root=None,
                      self_only=False, pretty_print=True, debug_context=None,
                      *,
                      precision=constants.XML_DEFAULT_PRECISION,
                      zero_threshold=0):
        """Generates an XML string corresponding to this MJCF element.

        Args:
          prefix_root: (optional) A `NameScope` object to be treated as root
            for the purpose of calculating the prefix.
            If `None` then no prefix is included.
          self_only: (optional) A boolean, whether to generate an XML corresponding
            only to this element without any children. Stripped children are
            replaced by the text `...`.
          pretty_print: (optional) A boolean, whether the XML string should be
            properly indented.
          debug_context: (optional) A `debugging.DebugContext` object to which
            the debugging information associated with the generated XML is written.
            This is intended for internal use within PyMJCF; users should never
            need to pass this argument manually.
          precision: (optional) Number of digits to output for floating point
            quantities.
          zero_threshold: (optional) When outputting XML, floating point quantities
            whose absolute value falls below this threshold will be treated as zero.

        Returns:
          A string.
        """
        xml_element = self.to_xml(prefix_root, debug_context,
                                  precision=precision,
                                  zero_threshold=zero_threshold)
        if self_only:
            if len(xml_element) > 0:  # pylint: disable=g-explicit-length-test
                etree.strip_elements(xml_element, '*')
                xml_element.text = '...'
            identifier = self._spec.identifier
            # Drop the identifier when the attribute itself has no XML value, i.e.
            # when the emitted value is only the `full_identifier` fallback.
            if identifier and not self._attributes[identifier].to_xml_string(
                    prefix_root, precision=precision,
                    zero_threshold=zero_threshold):
                del xml_element.attrib[identifier]
        serialized = etree.tostring(xml_element, pretty_print=pretty_print)
        xml_string = util.to_native_string(serialized)
        if pretty_print and debug_context:
            return debug_context.commit_xml_string(xml_string)
        return xml_string

    def __str__(self):
        """Returns the compact, single-line XML of this element without children."""
        return self.to_xml_string(pretty_print=False, self_only=True)

    def __repr__(self):
        """Returns `str(self)` prefixed with `'MJCF Element: '`."""
        return 'MJCF Element: {}'.format(str(self))

    def _check_valid_attachment(self, other):
        """Checks that `other` has a spec compatible with this element's spec.

        A <worldbody> spec is replaced by the spec of its <body> child before the
        comparison, on both sides.

        Args:
          other: The element that is about to be attached to this one.

        Raises:
          ValueError: If the two (normalized) specs differ.
        """
        normalized = []
        for spec in (self._spec, other.spec):
            if spec.name == constants.WORLDBODY:
                spec = spec.children[constants.BODY]
            normalized.append(spec)
        self_spec, other_spec = normalized

        if other_spec != self_spec:
            raise ValueError(
                'The attachment must have the same spec as this element.')

    def _attach(self, other, exclude_worldbody=False, dry_run=False):
        """Attaches another element of the same spec to this element.

        All children of `other` will be treated as children of this element.
        All XML attributes which are defined in `other` but not defined in this
        element will be copied over, and any conflicting XML attribute value causes
        an error. After the attachment, any XML attribute modified in this element
        will also affect `other` and vice versa.

        Children of this element which are not repeated elements will also be
        attached by the corresponding children of `other`.

        Args:
          other: Another Element with the same spec.
          exclude_worldbody: (optional) A boolean. If `True`, then don't do anything
            if `other` is a worldbody.
          dry_run: (optional) A boolean, if `True` only verify that the operation
            is valid without actually committing any change.

        Raises:
          ValueError: If `other` has a different spec, or if there are conflicting
            XML attribute values.
        """
        self._check_valid_attachment(other)
        other_is_worldbody = other.tag == constants.WORLDBODY
        if exclude_worldbody and other_is_worldbody:
            return

        if not dry_run:
            self._attachments[other.namescope] = other
            self._sync_attributes(other, copying=False)
        else:
            self._check_conflicting_attributes(other, copying=False)
        self._attach_children(other, exclude_worldbody, dry_run)

        # Share this element's attribute dict with `other`, except for a worldbody.
        if not (dry_run or other_is_worldbody):
            other._alias_attributes_dict(self._attributes)  # pylint: disable=protected-access

    def _detach(self, other_namescope):
        """Detaches a model with the specified namescope.

        If this element holds an attachment registered under `other_namescope`,
        the attached element gets its own attribute dict restored and the
        registration is dropped. The call is then propagated to every child.

        Args:
          other_namescope: The namescope of the model being detached.
        """
        attached_element = self._attachments.get(other_namescope)
        if attached_element:
            attached_element._restore_attributes_dict()  # pylint: disable=protected-access
            del self._attachments[other_namescope]
        # Iterate over a snapshot: a child may remove itself from `_children`
        # while being detached (an attachment frame calls `remove()` on itself),
        # and mutating the list mid-iteration would skip the following sibling.
        for child in list(self._children):
            child._detach(other_namescope)  # pylint: disable=protected-access

    def _check_conflicting_attributes(self, other, copying):
        """Raises if `other` defines an attribute value that clashes with ours.

        Two values clash when the attribute does not allow conflicts, both values
        are set (not `None`), and they differ in at least one entry.

        Args:
          other: The element whose `get_attributes()` are compared with this
            element's attributes.
          copying: Accepted for symmetry with `_sync_attributes`; not used by this
            check.

        Raises:
          ValueError: If a conflicting attribute value is found.
        """
        for attribute_name, other_value in other.get_attributes().items():
            if attribute_name == constants.DCLASS:
                attribute_name = constants.CLASS
            own_attribute = self._attributes[attribute_name]
            if own_attribute.conflict_allowed:
                continue
            own_value = own_attribute.value
            if own_value is None or other_value is None:
                continue
            # `!=` may yield an array for vector-valued attributes, hence `.any()`.
            if np.asarray(own_value != other_value).any():
                raise ValueError(
                    'Conflicting values for attribute `{}`: {} vs {}'
                    .format(attribute_name, own_value, other_value))

    def _sync_attributes(self, other, copying):
        """Merges the attribute values defined by `other` into this element.

        Conflicts are checked first. Then, for every attribute that `other` sets:
        if the attribute's conflict behavior has a handler in
        `_CONFLICT_BEHAVIOR_FUNC`, the handler combines our value with the other
        one (or the other value is taken as-is when ours is unset); otherwise the
        other value overwrites ours when `copying` is true or when the attribute
        does not allow conflicts.

        Args:
          other: The element providing the attribute values.
          copying: A boolean, see above.

        Raises:
          ValueError: If there are conflicting attribute values.
        """
        self._check_conflicting_attributes(other, copying)
        for attribute_name, other_value in other.get_attributes().items():
            if attribute_name == constants.DCLASS:
                attribute_name = constants.CLASS

            self_attribute = self._attributes[attribute_name]
            if other_value is None:
                continue

            behavior = self_attribute.conflict_behavior
            if behavior in _CONFLICT_BEHAVIOR_FUNC:
                if self_attribute.value is None:
                    self_attribute.value = other_value
                else:
                    resolve = _CONFLICT_BEHAVIOR_FUNC[behavior]
                    self_attribute.value = resolve(
                        self_attribute.value, other_value)
            elif copying or not self_attribute.conflict_allowed:
                self_attribute.value = other_value

    def _attach_children(self, other, exclude_worldbody, dry_run=False):
        """Attaches each non-repeated child of `other` to our matching child.

        Args:
          other: The element whose `all_children()` are walked.
          exclude_worldbody: Forwarded to the child's `_attach`.
          dry_run: (optional) Forwarded to the child's `_attach`.
        """
        for other_child in other.all_children():
            if other_child.spec.repeated:
                continue
            counterpart = self.get_children(other_child.spec.name)
            counterpart._attach(other_child, exclude_worldbody, dry_run)  # pylint: disable=protected-access

    def resolve_references(self):
        """Replaces string-valued references by the elements they point to.

        For every `Reference` attribute whose value is a non-empty string, the
        element is looked up through `self.root.find` in the attribute's
        reference namespace; if something is found it becomes the new value. The
        same is then done for all children.
        """
        for attribute in self._attributes.values():
            if not isinstance(attribute, attribute_types.Reference):
                continue
            target = attribute.value
            if not (target and isinstance(target, str)):
                continue
            referred = self.root.find(attribute.reference_namespace, target)
            if referred:
                attribute.value = referred
        for child in self.all_children():
            child.resolve_references()

    def _update_references(self, reference_dict):
        """Remaps `Reference` attribute values according to `reference_dict`.

        Args:
          reference_dict: A mapping from current reference values to their
            replacements. Values that are not keys of the mapping are left alone.
        """
        for attribute in self._attributes.values():
            if (isinstance(attribute, attribute_types.Reference)
                    and attribute.value in reference_dict):
                attribute.value = reference_dict[attribute.value]
        for child in self.all_children():
            child._update_references(reference_dict)  # pylint: disable=protected-access


class _AttachableElement(_ElementImpl):
    """Specialized object representing a <site> or <worldbody> element.

    This element defines a frame to which another MJCF model can be attached.
    """
    __slots__ = []

    def attach(self, attachment):
        """Attaches another MJCF model at this site.

        An empty <body> will be created as an attachment frame. All children of
        `attachment`'s <worldbody> will be treated as children of this frame.
        Furthermore, all other elements in `attachment` are merged into the root
        of the MJCF model to which this element belongs.

        If the name of `attachment` is already registered in this element's
        namescope, the attachment is renamed to `<name>_<k>` using the smallest
        `k >= 1` that is still free.

        Args:
          attachment: An MJCF `RootElement`

        Returns:
          An `mjcf.Element` corresponding to the attachment frame. A joint can be
          added directly to this frame to give degrees of freedom to the attachment.

        Raises:
          ValueError: If `attachment` is not a valid attachment to this element.
        """
        if not isinstance(attachment, RootElement):
            raise ValueError('Expected a mjcf.RootElement: got {}'
                             .format(attachment))
        if attachment.namescope.parent is not None:
            raise ValueError('The model specified is already attached elsewhere')
        if attachment.namescope == self.namescope:
            raise ValueError('Cannot merge a model to itself')
        # Validate the merge first so that a failure leaves both models untouched.
        self.root._attach(attachment, exclude_worldbody=True, dry_run=True)  # pylint: disable=protected-access

        base_name = attachment.model
        if self.namescope.has_identifier('namescope', base_name):
            suffix = 1
            candidate = '{}_{}'.format(base_name, suffix)
            while self.namescope.has_identifier('namescope', candidate):
                suffix += 1
                candidate = '{}_{}'.format(base_name, suffix)
            attachment.model = candidate
        attachment.namescope.parent = self.namescope

        # A worldbody hosts the frame itself (appended last); for a site the frame
        # becomes a sibling, placed after the site and any frames already there.
        at_worldbody = self.tag == constants.WORLDBODY
        frame_parent = self if at_worldbody else self._parent
        frame_siblings = frame_parent._children  # pylint: disable=protected-access
        if at_worldbody:
            index = len(frame_siblings)
        else:
            index = frame_siblings.index(self) + 1
            while (index < len(frame_siblings)
                   and isinstance(frame_siblings[index], _AttachmentFrame)):
                index += 1

        frame = _AttachmentFrame(frame_parent, self, attachment)
        frame_siblings.insert(index, frame)
        self.root._attach(attachment, exclude_worldbody=True)  # pylint: disable=protected-access
        return frame


class _AttachmentFrame(_ElementImpl):
    """A specialized <body> representing a frame holding an external attachment.
    """
    __slots__ = ['_site', '_attachment']

    def __init__(self, parent, site, attachment):
        """Creates the frame that hosts `attachment` at `site`.

        Args:
          parent: The element that will contain this frame.
          site: The element at which the model is attached. Every frame attribute
            that `site` also has is initialized from it.
          attachment: The attached model.
        """
        if parent.tag == constants.WORLDBODY:
            spec = schema.WORLD_ATTACHMENT_FRAME
        else:
            spec = schema.ATTACHMENT_FRAME

        # Children that live in a namespace get a namespace private to this frame.
        namespaced_children = [
            (child_name, child_spec)
            for child_name, child_spec in spec.children.items()
            if child_spec.namespace]
        if namespaced_children:
            # Work on a copy so that the schema's own spec object is not modified.
            spec = copy.deepcopy(spec)
            for child_name, child_spec in namespaced_children:
                fields = child_spec._asdict()
                fields['namespace'] = '{}{}attachment_frame_{}'.format(
                    child_spec.namespace, constants.NAMESPACE_SEPARATOR, id(self))
                spec.children[child_name] = type(child_spec)(**fields)

        with debugging.freeze_current_stack_trace():
            attributes = {
                attribute_name: getattr(site, attribute_name)
                for attribute_name in spec.attributes
                if hasattr(site, attribute_name)}
        super().__init__(spec, parent, attributes)
        self._site = site
        self._attachment = attachment
        self._attachments[attachment.namescope] = attachment.worldbody
        model_name = attachment.namescope.name
        self.namescope.add('attachment_frame', model_name, self)
        self.namescope.add('attached_model', model_name, attachment)

    def prefixed_identifier(self, prefix_root=None):
        """Returns the scope prefix, the attached model's name and a separator."""
        scope_prefix = self.namescope.full_prefix(prefix_root)
        model_name = self._attachment.namescope.name
        return scope_prefix + model_name + constants.PREFIX_SEPARATOR

    def to_xml(self, prefix_root=None, debug_context=None,
               *,
               precision=constants.XML_DEFAULT_PRECISION,
               zero_threshold=0):
        """Like the base `to_xml`, with `name` set to the prefixed identifier."""
        frame_xml = super().to_xml(prefix_root, debug_context,
                                   precision=precision,
                                   zero_threshold=zero_threshold)
        frame_xml.set('name', self.prefixed_identifier(prefix_root))
        return frame_xml

    @property
    def full_identifier(self):
        """The prefixed identifier computed relative to the root namescope."""
        return self.prefixed_identifier(self.namescope.root)

    def _detach(self, other_namescope):
        """Detaches `other_namescope`; removes this frame if it hosted that model."""
        super()._detach(other_namescope)
        if other_namescope is not self._attachment.namescope:
            return
        model_name = self._attachment.namescope.name
        self.namescope.remove('attachment_frame', model_name)
        self.namescope.remove('attached_model', model_name)
        self.remove()


class _AttachmentFrameChild(_ElementImpl):
    """A child element of an attachment frame.

    Right now, this is always a <joint> or a <freejoint>. The name of the joint
    is not freely specifiable, but instead just inherits from the parent frame.
    This ensures uniqueness, as attachment frame identifiers always end in '/'.
    """
    __slots__ = []

    def to_xml(self, prefix_root=None, debug_context=None,
               *,
               precision=constants.XML_DEFAULT_PRECISION,
               zero_threshold=0):
        """Like the base `to_xml`; namespaced children get a frame-derived name."""
        child_xml = super().to_xml(prefix_root, debug_context,
                                   precision=precision,
                                   zero_threshold=zero_threshold)
        if self.spec.namespace is not None:
            child_xml.set('name', self.prefixed_identifier(prefix_root))
        return child_xml

    def prefixed_identifier(self, prefix_root=None):
        """Returns the parent frame's identifier, extended by our name if set.

        When this element has a name, the result is the parent's prefixed
        identifier followed by that name and a prefix separator; otherwise it is
        the parent's prefixed identifier unchanged.
        """
        own_name = self.name
        frame_identifier = self._parent.prefixed_identifier(prefix_root)
        if not own_name:
            return frame_identifier
        return frame_identifier + own_name + constants.PREFIX_SEPARATOR


class _DefaultElement(_ElementImpl):
    """Specialized object representing a <default> element.

    This is necessary for the proper handling of global defaults.
    """
    __slots__ = []

    def _attach(self, other, exclude_worldbody=False, dry_run=False):
        """Registers the global <default> of another model as an attachment.

        Args:
          other: The <default> element to attach.
          exclude_worldbody: (optional) Not used by this override.
          dry_run: (optional) A boolean, if `True` only validate.

        Raises:
          ValueError: If the specs differ, or if either element is not a direct
            child of a `RootElement`.
        """
        self._check_valid_attachment(other)
        both_global = (isinstance(self._parent, RootElement)
                       and isinstance(other.parent, RootElement))
        if not both_global:
            raise ValueError('Only global <{}> can be attached'
                             .format(constants.DEFAULT))
        if not dry_run:
            self._attachments[other.namescope] = other

    def all_children(self):
        """Returns a new list of this element's own children only."""
        return list(self._children)

    def to_xml(self, prefix_root=None, debug_context=None,
               *,
               precision=constants.XML_DEFAULT_PRECISION,
               zero_threshold=0):
        """Generates the XML for this <default>.

        For a global default (one whose parent is a `RootElement`) the result is a
        wrapper element with the same tag that contains this element's own XML,
        followed by the child elements of every attached default's XML.
        """
        prefix_root = prefix_root or self.namescope
        own_xml = super().to_xml(prefix_root, debug_context,
                                 precision=precision,
                                 zero_threshold=zero_threshold)
        if not isinstance(self._parent, RootElement):
            return own_xml

        root_default = etree.Element(self._spec.name)
        root_default.append(own_xml)
        for attachment in self._attachments.values():
            attachment_xml = attachment.to_xml(prefix_root, debug_context,
                                               precision=precision,
                                               zero_threshold=zero_threshold)
            for attachment_child_xml in attachment_xml:
                root_default.append(attachment_child_xml)
        return root_default


class _ActuatorElement(_ElementImpl):
    """Specialized object representing an <actuator> element."""

    __slots__ = ()

    def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
                         *,
                         precision=constants.XML_DEFAULT_PRECISION,
                         zero_threshold=0):
        """Appends the XML of every child to `xml_element`, unconditionally.

        In debug mode, and when a `debug_context` is given, a debug comment is
        appended right after each child, and a copy of it is inserted as the
        first node of any child that has sub-elements.
        """
        for child in self.all_children():
            child_xml = child.to_xml(prefix_root, debug_context,
                                     precision=precision,
                                     zero_threshold=zero_threshold)
            emit_debug = bool(debugging.debug_mode() and debug_context)
            if emit_debug:
                debug_comment = debug_context.register_element_for_debugging(child)
                if len(child_xml) > 0:  # pylint: disable=g-explicit-length-test
                    child_xml.insert(0, copy.deepcopy(debug_comment))
            xml_element.append(child_xml)
            if emit_debug:
                xml_element.append(debug_comment)


class RootElement(_ElementImpl):
    """The root `<mujoco>` element of an MJCF model."""
    __slots__ = ['_namescope']

    def __init__(self, model=None, model_dir='', assets=None):
        """Creates a root element together with its own `NameScope`.

        Args:
          model: (optional) Name of the model. A falsy value is replaced by
            `'unnamed_model'`.
          model_dir: (optional) Forwarded to `namescope.NameScope`.
          assets: (optional) Forwarded to `namescope.NameScope`.
        """
        model = model or 'unnamed_model'
        # NOTE(review): the namescope is created before the base initializer
        # runs; presumably the base class needs `self.namescope` -- keep the order.
        self._namescope = namescope.NameScope(
            model, self, model_dir=model_dir, assets=assets)
        super().__init__(
            spec=schema.MUJOCO, parent=None, attributes={'model': model})

    def _attach(self, other, exclude_worldbody=False, dry_run=False):
        """Attaches the root element `other` to this root element.

        Args:
          other: The root element to attach.
          exclude_worldbody: (optional) Forwarded to `_attach_children`.
          dry_run: (optional) A boolean, if `True` the attachment is not
            registered on this element.
        """
        self._check_valid_attachment(other)
        if not dry_run:
            self._attachments[other.namescope] = other
        self._attach_children(other, exclude_worldbody, dry_run)
        self.namescope.increment_revision()

    @property
    def namescope(self):
        """The `NameScope` owned by this root element."""
        return self._namescope

    @property
    def root(self):
        """A root element is its own root."""
        return self

    @property
    def model(self):
        """The model name, i.e. the name of this element's namescope."""
        return self._namescope.name

    @model.setter
    def model(self, new_name):
        old_name = self._namescope.name
        self._namescope.name = new_name
        self._attributes['model'].value = new_name
        # Keep the parent model's registrations in sync with the new name.
        if self.parent_model:
            self.parent_model.namescope.rename('attachment_frame', old_name, new_name)
            self.parent_model.namescope.rename('attached_model', old_name, new_name)

    def attach(self, other):
        """Attaches `other` at this model's worldbody and returns the result."""
        return self.worldbody.attach(other)

    def detach(self):
        """Detaches this model from the model it is attached to.

        Raises:
          RuntimeError: If this model is not attached to another model.
        """
        parent_model = self.parent_model
        if not parent_model:
            raise RuntimeError(
                'Cannot `detach` a model that is not attached to some other model.')
        parent_model._detach(self.namescope)  # pylint: disable=protected-access
        self.namescope.parent = None

    def include_copy(self, other, override_attributes=False):
        """Copies the contents of `other` into this model.

        Args:
          other: The element to copy from.
          override_attributes: (optional) Forwarded to `Copier.copy_into`.
        """
        other_copier = copier.Copier(other)
        new_elements = other_copier.copy_into(self, override_attributes)
        self._update_references(new_elements)
        self.namescope.increment_revision()

    @property
    def parent_model(self):
        """The RootElement of the MJCF model to which this one is attached."""
        namescope_parent = self._namescope.parent
        return namescope_parent.mjcf_model if namescope_parent else None

    @property
    def root_model(self):
        """The outermost model reached by following `parent_model` links."""
        return self.parent_model.root_model if self.parent_model else self

    def get_assets(self):
        """Returns a dict containing the binary assets referenced in this model.

        This will contain `{vfs_filename: contents}` pairs. `vfs_filename` will be
        the name of the asset in MuJoCo's Virtual File System, which corresponds to
        the filename given in the XML returned by `to_xml_string()`. `contents` is a
        bytestring.

        This dict can be used together with the result of `to_xml_string()` to
        construct a `mujoco.Physics` instance:

        ```python
        physics = mujoco.Physics.from_xml_string(
            xml_string=mjcf_model.to_xml_string(),
            assets=mjcf_model.get_assets())
        ```
        """
        # Get the assets referenced within this `RootElement`'s namescope.
        assets = {file_obj.to_xml_string(): file_obj.get_contents()
                  for file_obj in self.namescope.files
                  if file_obj.value}

        # Recursively add assets belonging to attachments.
        for attached_model in self._attachments.values():
            assets.update(attached_model.get_assets())

        return assets

    def get_assets_map(self):
        """Maps the name of each asset's parent element to its file name.

        Returns:
          A dict with one entry per file in this model's namescope that has a
          value: the key is the `name` of the element owning the file attribute,
          the value is the file's `prefix` followed by its `extension`. Files of
          attached models are not included.
        """
        # Get the assets referenced within this `RootElement`'s namescope.
        assets = {file_obj._parent.name: file_obj.value.prefix + file_obj.value.extension
                  for file_obj in self.namescope.files
                  if file_obj.value}
        return assets

    @property
    def full_identifier(self):
        """The full prefix of this model's namescope relative to the root scope."""
        return self._namescope.full_prefix(self._namescope.root)

    def __copy__(self):
        """Returns a new `RootElement` with the same name and model directory."""
        new_model = RootElement(model=self._namescope.name,
                                model_dir=self.namescope.model_dir)
        new_model.include_copy(self)
        return new_model

    def __deepcopy__(self, _):
        """Deep copies are the same as shallow copies for a root element."""
        return self.__copy__()

    def is_same_as(self, other):
        """Returns whether `other` has the same spec and the same children."""
        if other is None or other.spec != self._spec:
            return False
        return self._has_same_children_as(other)


class _ElementListView:
    """A hybrid list/dict-like view to a group of repeated MJCF elements."""

    def __init__(self, spec, parent):
        """Creates a view over the children of `parent` that match `spec`.

        Args:
          spec: The spec shared by the elements in this view.
          parent: The element whose children (and attachments) are viewed.
        """
        self._spec = spec
        self._parent = parent
        self._elements = self._parent._children  # pylint: disable=protected-access
        # One nested view per attachment of the parent, keyed by scope name.
        self._scoped_elements = collections.OrderedDict(
            [(scope_namescope.name, getattr(scoped_parent, self._spec.name))
             for scope_namescope, scoped_parent
             in self._parent._attachments.items()])  # pylint: disable=protected-access

    @property
    def spec(self):
        return self._spec

    @property
    def tag(self):
        return self._spec.name

    @property
    def namescope(self):
        return self._parent.namescope

    @property
    def parent(self):
        return self._parent

    def __len__(self):
        return len(self._full_list())

    def __iter__(self):
        return iter(self._full_list())

    def _identifier_not_found_error(self, index):
        """Builds (without raising) the `KeyError` for a missing identifier."""
        return KeyError('An element <{}> with {}={!r} does not exist'
                        .format(self._spec.name, self._spec.identifier, index))

    def _find_index(self, index):
        """Locates an element given the index among siblings with the same tag.

        Args:
          index: Either the identifier of the element (a `str`, when the spec has
            an identifier), or its position among the siblings with this tag.

        Returns:
          The position of the element within the parent's list of children.

        Raises:
          KeyError: If no element has the requested identifier.
          IndexError: If the positional index matches no element.
        """
        if isinstance(index, str) and self._spec.identifier:
            for i, element in enumerate(self._elements):
                if (element.tag == self._spec.name
                        and getattr(element, self._spec.identifier) == index):
                    return i
            raise self._identifier_not_found_error(index)
        else:
            count = 0
            for i, element in enumerate(self._elements):
                if element.tag == self._spec.name:
                    if index == count:
                        return i
                    else:
                        count += 1
            raise IndexError('list index out of range')

    def _full_list(self):
        """Returns own matching elements followed by those of each attachment."""
        out_list = [element for element in self._elements
                    if element.tag == self._spec.name]
        for scoped_elements in self._scoped_elements.values():
            out_list += scoped_elements[:]
        return out_list

    def clear(self):
        """Removes every element of this view."""
        for child in self._full_list():
            child.remove()

    def __getitem__(self, index):
        """Returns the element(s) selected by identifier, position or slice.

        A string containing the prefix separator is resolved inside the
        attachment scope named by its first component. Slices and negative
        integers index the full list, which includes attached elements.

        Raises:
          KeyError: If no element has the requested identifier, including when
            the scope named in a prefixed identifier does not exist.
          IndexError: If a positional index is out of range.
        """
        if (isinstance(index, str) and self._spec.identifier
                and constants.PREFIX_SEPARATOR in index):
            scope_name = index.split(constants.PREFIX_SEPARATOR)[0]
            try:
                # The scope lookup sits inside the `try` so that an unknown scope
                # is reported with the same descriptive error as an unknown name.
                scoped_elements = self._scoped_elements[scope_name]
                return scoped_elements[index[(len(scope_name) + 1):]]
            except KeyError:
                # Re-raise so that the error shows the full, un-stripped index string
                raise self._identifier_not_found_error(index)  # pylint: disable=raise-missing-from
        elif isinstance(index, slice) or (isinstance(index, int) and index < 0):
            return self._full_list()[index]
        else:
            return self._elements[self._find_index(index)]

    def __delitem__(self, index):
        """Removes the element selected by `index` (see `_find_index`)."""
        found_index = self._find_index(index)
        self._elements[found_index].remove()

    def __str__(self):
        return str(
            [element.to_xml_string(
                prefix_root=self.namescope, self_only=True, pretty_print=False)
             for element in self._full_list()])

    def __repr__(self):
        return 'MJCF Elements List: ' + str(self)


# This restores @property back to Python's built-in one.
del property
del _raw_property