Source code for rofunc.utils.robolab.formatter.mjcf_parser.attribute

# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Classes representing various MJCF attribute data types."""

import abc
import collections
import hashlib
import io
import os

import numpy as np
from rofunc.utils.robolab.formatter.mjcf_parser import util
from rofunc.utils.robolab.formatter.mjcf_parser import io as resources

from rofunc.utils.robolab.formatter.mjcf_parser import base
from rofunc.utils.robolab.formatter.mjcf_parser import constants
from rofunc.utils.robolab.formatter.mjcf_parser import debugging
from rofunc.utils.robolab.formatter.mjcf_parser import skin

# Copybara placeholder for internal file handling dependency.

_INVALID_REFERENCE_TYPE = (
    'Reference should be an MJCF Element whose type is {valid_type!r}: '
    'got {actual_type!r}.')

_MESH_EXTENSIONS = ('.stl', '.msh', '.obj')

# MuJoCo's compiler enforces this.
_INVALID_MESH_EXTENSION = (
    'Mesh files must have one of the following extensions: {}, got {{}}.'
    .format(_MESH_EXTENSIONS))


class _Attribute(metaclass=abc.ABCMeta):
    """Abstract base class for MJCF attribute data types."""

    def __init__(self, name, required, parent, value,
                 conflict_allowed, conflict_behavior):
        self._name = name
        self._required = required
        self._parent = parent
        self._value = None
        self._conflict_allowed = conflict_allowed
        self._conflict_behavior = conflict_behavior
        self._check_and_assign(value)

    def _check_and_assign(self, new_value):
        if new_value is None:
            self.clear()
        elif isinstance(new_value, str):
            self._assign_from_string(new_value)
        else:
            self._assign(new_value)
        if debugging.debug_mode():
            self._last_modified_stack = debugging.get_current_stack_trace()

    @property
    def last_modified_stack(self):
        if debugging.debug_mode():
            return self._last_modified_stack

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        self._check_and_assign(new_value)

    @abc.abstractmethod
    def _assign(self, value):
        raise NotImplementedError  # pragma: no cover

    def clear(self):
        if self._required:
            raise AttributeError(
                'Attribute {!r} of element <{}> is required'
                .format(self._name, self._parent.tag))
        else:
            self._force_clear()

    def _force_clear(self):
        self._before_clear()
        self._value = None
        if debugging.debug_mode():
            self._last_modified_stack = debugging.get_current_stack_trace()

    def _before_clear(self):
        pass

    def _assign_from_string(self, string):
        self._assign(string)

    def to_xml_string(self, prefix_root, **kwargs):  # pylint: disable=unused-argument
        if self._value is None:
            return None
        else:
            return str(self._value)

    @property
    def conflict_allowed(self):
        return self._conflict_allowed

    @property
    def conflict_behavior(self):
        return self._conflict_behavior


[docs]class String(_Attribute): """A string MJCF attribute.""" def _assign(self, value): if not isinstance(value, str): raise ValueError('Expect a string value: got {}'.format(value)) elif not value: self.clear() else: self._value = value
[docs]class Integer(_Attribute): """An integer MJCF attribute.""" def _assign(self, value): try: float_value = float(value) int_value = int(float(value)) if float_value != int_value: raise ValueError except ValueError: raise ValueError( 'Expect an integer value: got {}'.format(value)) from None self._value = int_value
[docs]class Float(_Attribute): """An float MJCF attribute.""" def _assign(self, value): try: float_value = float(value) except ValueError: raise ValueError('Expect a float value: got {}'.format(value)) from None self._value = float_value
[docs] def to_xml_string(self, prefix_root=None, *, precision=constants.XML_DEFAULT_PRECISION, zero_threshold=0, **kwargs): if self._value is None: return None else: out = io.BytesIO() value = self._value if abs(value) < zero_threshold: value = 0.0 np.savetxt(out, [value], fmt=f'%.{precision:d}g', newline=' ') return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
[docs]class Keyword(_Attribute): """A keyword MJCF attribute.""" def __init__(self, name, required, parent, value, conflict_allowed, conflict_behavior, valid_values): self._valid_values = collections.OrderedDict( (value.lower(), value) for value in valid_values) super().__init__(name, required, parent, value, conflict_allowed, conflict_behavior) def _assign(self, value): if value is None or value == '': # pylint: disable=g-explicit-bool-comparison self.clear() else: try: self._value = self._valid_values[str(value).lower()] except KeyError: raise ValueError('Expect keyword to be one of {} but got: {}'.format( list(self._valid_values.values()), value)) from None @property def valid_values(self): return list(self._valid_values.keys())
[docs]class Array(_Attribute): """An array MJCF attribute.""" def __init__(self, name, required, parent, value, conflict_allowed, conflict_behavior, length, dtype): self._length = length self._dtype = dtype super().__init__(name, required, parent, value, conflict_allowed, conflict_behavior) def _assign(self, value): self._value = self._check_shape(np.array(value, dtype=self._dtype)) def _assign_from_string(self, string): self._assign(np.fromstring(string, dtype=self._dtype, sep=' '))
[docs] def to_xml_string(self, prefix_root=None, *, precision=constants.XML_DEFAULT_PRECISION, zero_threshold=0, **kwargs): if self._value is None: return None else: out = io.BytesIO() value = self._value if zero_threshold: value = np.copy(value) value[np.abs(value) < zero_threshold] = 0 np.savetxt(out, value, fmt=f'%.{precision:d}g', newline=' ') return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
def _check_shape(self, array): actual_length = array.shape[0] if len(array.shape) > 1: raise ValueError('Expect one-dimensional array: got {}'.format(array)) if self._length and actual_length > self._length: raise ValueError('Expect array with no more than {} entries: got {}' .format(self._length, array)) return array
[docs]class Identifier(_Attribute): """A string attribute that represents a unique identifier of an element.""" def _assign(self, value): if not isinstance(value, str): raise ValueError('Expect a string value: got {}'.format(value)) elif not value: self.clear() elif self._parent.spec.namespace == 'body' and value == 'world': raise ValueError('A body cannot be named \'world\'. ' 'The name \'world\' is used by MuJoCo to refer to the ' '<worldbody>.') elif constants.PREFIX_SEPARATOR in value: raise ValueError( 'An identifier cannot contain a {!r}, ' 'as this is reserved for scoping purposes: got {!r}' .format(constants.PREFIX_SEPARATOR, value)) else: old_value = self._value if value != old_value: self._parent.namescope.add( self._parent.spec.namespace, value, self._parent) if old_value: self._parent.namescope.remove(self._parent.spec.namespace, old_value) self._value = value def _before_clear(self): if self._value: self._parent.namescope.remove(self._parent.spec.namespace, self._value) def _defaults_string(self, prefix_root): prefix = self._parent.namescope.full_prefix(prefix_root, as_list=True) prefix.append(self._value or '') return constants.PREFIX_SEPARATOR.join(prefix) or constants.PREFIX_SEPARATOR
[docs] def to_xml_string(self, prefix_root=None, **kwargs): if self._parent.tag == constants.DEFAULT: return self._defaults_string(prefix_root) elif self._value: prefix = self._parent.namescope.full_prefix(prefix_root, as_list=True) prefix.append(self._value) return constants.PREFIX_SEPARATOR.join(prefix) else: return self._value
[docs]class Reference(_Attribute): """A string attribute that represents a reference to an identifier.""" def __init__(self, name, required, parent, value, conflict_allowed, conflict_behavior, reference_namespace): self._reference_namespace = reference_namespace super().__init__(name, required, parent, value, conflict_allowed, conflict_behavior) def _check_dead_reference(self): if isinstance(self._value, base.Element) and self._value.is_removed: self.clear() @property def value(self): self._check_dead_reference() return super().value @value.setter def value(self, new_value): super(Reference, self.__class__).value.fset(self, new_value) @property def reference_namespace(self): if isinstance(self._reference_namespace, _Attribute): return constants.INDIRECT_REFERENCE_ATTRIB.get( self._reference_namespace.value, self._reference_namespace.value) else: return self._reference_namespace def _assign(self, value): if not isinstance(value, (base.Element, str)): raise ValueError( 'Expect a string or `mjcf.Element` value: got {}'.format(value)) elif not value: self.clear() else: if isinstance(value, base.Element): value_namespace = ( value.spec.namespace.split(constants.NAMESPACE_SEPARATOR)[0]) if value_namespace != self.reference_namespace: raise ValueError(_INVALID_REFERENCE_TYPE.format( valid_type=self.reference_namespace, actual_type=value_namespace)) self._value = value def _before_clear(self): if isinstance(self._value, base.Element): if isinstance(self._reference_namespace, _Attribute): self._reference_namespace._force_clear() # pylint: disable=protected-access def _defaults_string(self, prefix_root): """Generates the XML string if this is a reference to a defaults class. To prevent global defaults from clashing, we turn all global defaults into a properly named defaults class. Therefore, care must be taken when this attribute is not explicitly defined. If the parent element can be traced up to a body with a nontrivial 'childclass' then must continue to leave this attribute undefined. Args: prefix_root: A `NameScope` object to be treated as root for the purpose of calculating the prefix. Returns: A string to be used in the generated XML. """ self._check_dead_reference() prefix = self._parent.namescope.full_prefix(prefix_root) if not self._value: defaults_root = self._parent.parent while defaults_root is not None: if (hasattr(defaults_root, constants.CHILDCLASS) and defaults_root.childclass): break defaults_root = defaults_root.parent if defaults_root is None: # This element doesn't belong to a childclass'd body. global_class = self._parent.root.default.dclass or '' out_string = (prefix + global_class) or constants.PREFIX_SEPARATOR else: out_string = None else: out_string = prefix + self._value return out_string
[docs] def to_xml_string(self, prefix_root, **kwargs): self._check_dead_reference() if isinstance(self._value, base.Element): return self._value.prefixed_identifier(prefix_root) elif (self.reference_namespace == constants.DEFAULT and self._name != constants.CHILDCLASS): return self._defaults_string(prefix_root) elif self._value: return self._parent.namescope.full_prefix(prefix_root) + self._value else: return None
[docs]class BasePath(_Attribute): """A string attribute that represents a base path for an asset type.""" def __init__(self, name, required, parent, value, conflict_allowed, conflict_behavior, path_namespace): self._path_namespace = path_namespace super().__init__(name, required, parent, value, conflict_allowed, conflict_behavior) def _assign(self, value): if not isinstance(value, str): raise ValueError('Expect a string value: got {}'.format(value)) elif not value: self.clear() else: self._parent.namescope.replace( constants.BASEPATH, self._path_namespace, value) self._value = value def _before_clear(self): if self._value: self._parent.namescope.remove(constants.BASEPATH, self._path_namespace)
[docs] def to_xml_string(self, prefix_root=None, **kwargs): return None
[docs]class BaseAsset: """Base class for binary assets.""" __slots__ = ('extension', 'prefix') def __init__(self, extension, prefix=''): self.extension = extension self.prefix = prefix def __eq__(self, other): return self.get_vfs_filename() == other.get_vfs_filename()
[docs] def get_vfs_filename(self): """Returns the name of the asset file as registered in MuJoCo's VFS.""" # Hash the contents of the asset to get a unique identifier. hash_string = hashlib.sha1(util.to_binary_string(self.contents)).hexdigest() # Prepend the prefix, if one exists. if self.prefix: prefix = self.prefix raw_length = len(prefix) + len(hash_string) + len(self.extension) + 1 if raw_length > constants.MAX_VFS_FILENAME_LENGTH: trim_amount = raw_length - constants.MAX_VFS_FILENAME_LENGTH prefix = prefix[:-trim_amount] filename = '-'.join([prefix, hash_string]) else: filename = hash_string # An extension is needed because MuJoCo's compiler looks at this when # deciding how to load meshes and heightfields. return filename + self.extension
[docs]class Asset(BaseAsset): """Class representing a binary asset.""" __slots__ = ('contents',) def __init__(self, contents, extension, prefix=''): """Initializes a new `Asset`. Args: contents: The contents of the file as a bytestring. extension: A string specifying the file extension (e.g. '.png', '.stl'). prefix: (optional) A prefix applied to the filename given in MuJoCo's VFS. """ self.contents = contents super().__init__(extension, prefix)
[docs]class SkinAsset(BaseAsset): """Class representing a binary asset corresponding to a skin.""" __slots__ = ('skin', 'parent', '_cached_revision', '_cached_contents') def __init__(self, contents, parent, extension, prefix=''): self.skin = skin.parse( contents, lambda body_name: parent.root.find('body', body_name)) self.parent = parent self._cached_revision = -1 self._cached_contents = None super().__init__(extension, prefix) @property def contents(self): if self._cached_revision < self.parent.namescope.revision: self._cached_contents = skin.serialize(self.skin) self._cached_revision = self.parent.namescope.revision return self._cached_contents
[docs]class File(_Attribute): """Attribute representing an asset file.""" def __init__(self, name, required, parent, value, conflict_allowed, conflict_behavior, path_namespace): self._path_namespace = path_namespace super().__init__(name, required, parent, value, conflict_allowed, conflict_behavior) parent.namescope.files.add(self) def _assign(self, value): if not value: self.clear() else: if isinstance(value, str): asset = self._get_asset_from_path(value) elif isinstance(value, Asset): asset = value else: raise ValueError('Expect either a string or `Asset` value: got {}' .format(value)) self._validate_extension(asset.extension) self._value = asset def _get_asset_from_path(self, path): """Constructs a `Asset` given a file path.""" _, basename = os.path.split(path) filename, extension = os.path.splitext(basename) assetdir = None if self._parent.namescope.has_identifier( constants.BASEPATH, constants.ASSETDIR_NAMESPACE ): assetdir = self._parent.namescope.get( constants.BASEPATH, constants.ASSETDIR_NAMESPACE ) if path in self._parent.namescope.assets: # Look in the dict of pre-loaded assets before checking the filesystem. contents = self._parent.namescope.assets[path] else: # Construct the full path to the asset file, prefixed by the path to the # model directory, and by `meshdir` or `texturedir` if appropriate. path_parts = [] if self._parent.namescope.model_dir: path_parts.append(self._parent.namescope.model_dir) if self._parent.namescope.has_identifier( constants.BASEPATH, self._path_namespace ): base_path = self._parent.namescope.get( constants.BASEPATH, self._path_namespace ) path_parts.append(base_path) elif ( self._path_namespace in (constants.TEXTUREDIR_NAMESPACE, constants.MESHDIR_NAMESPACE) and assetdir is not None ): path_parts.append(assetdir) path_parts.append(path) full_path = os.path.join(*path_parts) # pylint: disable=no-value-for-parameter contents = resources.GetResource(full_path) if self._parent.tag == constants.SKIN: return SkinAsset(contents=contents, parent=self._parent, extension=extension, prefix=filename) else: return Asset(contents=contents, extension=extension, prefix=filename) def _validate_extension(self, extension): if self._parent.tag == constants.MESH: if extension.lower() not in _MESH_EXTENSIONS: raise ValueError(_INVALID_MESH_EXTENSION.format(extension))
[docs] def get_contents(self): """Returns a bytestring representing the contents of the asset.""" if self._value is None: raise RuntimeError('You must assign a value to this attribute before ' 'querying the contents.') return self._value.contents
[docs] def to_xml_string(self, prefix_root=None, **kwargs): """Returns the asset filename as it will appear in the generated XML.""" del prefix_root # Unused if self._value is not None: return self._value.get_vfs_filename() else: return None