# Source code for rofunc.utils.robolab.formatter.mjcf_parser.schema

# Copyright 2018 The dm_control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A Python object representation of Mujoco's MJCF schema.

The root schema is provided as a module-level constant `schema.MUJOCO`.
"""

import collections
import copy
import os

from rofunc.utils.robolab.formatter.mjcf_parser import io as resources
from lxml import etree

from rofunc.utils.robolab.formatter.mjcf_parser import attribute

# Path of the default schema file: `schema.xml` in the same directory as this
# module. It is parsed at import time to build `MUJOCO`.
_SCHEMA_XML_PATH = os.path.join(os.path.dirname(__file__), 'schema.xml')

# Maps the `array_type` string of an array attribute in the schema to the
# Python type that `_parse_attribute` stores under `other_kwargs['dtype']`.
_ARRAY_DTYPE_MAP = {
    'int': int,
    'float': float,
    'string': str
}

# Maps the `type` string of a scalar attribute in the schema to the attribute
# class used for it. `_parse_attribute` consults this map only after all the
# specially handled types (keyword, array, identifier, ...) have been ruled out.
_SCALAR_TYPE_MAP = {
    'int': attribute.Integer,
    'float': attribute.Float,
    'string': attribute.String
}

# Specification of one schema element, as produced by `_parse_element`:
#   name:       the element's name.
#   repeated:   boolean parsed from the `repeated` XML attribute.
#   on_demand:  boolean parsed from the `on_demand` XML attribute.
#   identifier: name of the attribute whose type is `attribute.Identifier`,
#               or None if the element has no such attribute.
#   namespace:  namespace string, or None.
#   attributes: ordered mapping from attribute name to `AttributeSpec`.
#   children:   ordered mapping from child element name to `ElementSpec`.
ElementSpec = collections.namedtuple(
    'ElementSpec', ('name', 'repeated', 'on_demand', 'identifier', 'namespace',
                    'attributes', 'children'))

# Specification of one element attribute, as produced by `_parse_attribute`:
#   name:              the attribute's name.
#   type:              the class from the `attribute` module selected for the
#                      attribute's schema type.
#   required:          boolean parsed from the `required` XML attribute.
#   conflict_allowed:  boolean parsed from the `conflict_allowed` XML attribute.
#   conflict_behavior: string from the schema, 'replace' when unspecified.
#   other_kwargs:      dict of extra, type-specific settings (for example
#                      `valid_values`, `length`, `dtype`, `path_namespace`).
AttributeSpec = collections.namedtuple(
    'AttributeSpec', ('name', 'type', 'required',
                      'conflict_allowed', 'conflict_behavior', 'other_kwargs'))

# Additional namespaces that are not present in the MJCF schema but can
# be used in `find` and `find_all`. They are merged into
# `FINDABLE_NAMESPACES` alongside the namespaces collected from the schema.
_ADDITIONAL_FINDABLE_NAMESPACES = frozenset(['attachment_frame'])


def _str2bool(string):
    """Converts either 'true' or 'false' (not case-sensitively) into a boolean."""
    if string is None:
        return False
    else:
        string = string.lower()

    if string == 'true':
        return True
    elif string == 'false':
        return False
    else:
        raise ValueError(
            'String should either be `true` or `false`: got {}'.format(string))


def parse_schema(schema_path):
    """Reads a schema XML file and converts it into an `ElementSpec` tree.

    Args:
      schema_path: Path to the schema XML file.

    Returns:
      An `ElementSpec` for the root element in the schema.
    """
    # Only the XML parsing needs the open file; the conversion into specs
    # works on the in-memory tree.
    with resources.GetResourceAsFile(schema_path) as schema_file:
        root_xml = etree.parse(schema_file).getroot()
    return _parse_element(root_xml)


def _parse_element(element_xml):
    """Builds an `ElementSpec` from an <element> node of the schema.

    Child elements are parsed recursively. An element marked `recursive` in
    the schema is registered as a child of itself.

    Args:
      element_xml: The XML node describing the element.

    Returns:
      The `ElementSpec` for `element_xml`.

    Raises:
      ValueError: If the element has no name.
      RuntimeError: If an attribute and a child element share a name.
    """
    name = element_xml.get('name')
    if not name:
        raise ValueError('Element must always have a name')
    repeated = _str2bool(element_xml.get('repeated'))
    on_demand = _str2bool(element_xml.get('on_demand'))

    attributes_node = element_xml.find('attributes')
    if attributes_node is None:
        attribute_nodes = ()
    else:
        attribute_nodes = attributes_node.findall('attribute')
    attributes = collections.OrderedDict(
        (node.get('name'), _parse_attribute(node)) for node in attribute_nodes)

    # An element only gets a namespace when it has an identifier attribute.
    # If several attributes are identifiers, the last one wins.
    identifier = None
    namespace = None
    identifier_names = [spec.name for spec in attributes.values()
                        if spec.type == attribute.Identifier]
    if identifier_names:
        identifier = identifier_names[-1]
        namespace = element_xml.get('namespace') or name

    children_node = element_xml.find('children')
    if children_node is None:
        child_nodes = ()
    else:
        child_nodes = children_node.findall('element')
    children = collections.OrderedDict(
        (node.get('name'), _parse_element(node)) for node in child_nodes)

    element_spec = ElementSpec(
        name, repeated, on_demand, identifier, namespace, attributes, children)

    # A recursive element may contain itself, so the spec is inserted into
    # its own children mapping.
    if _str2bool(element_xml.get('recursive')):
        children[name] = element_spec

    common_keys = set(attributes).intersection(children)
    if common_keys:
        raise RuntimeError(
            'Element \'{}\' contains the following attributes and children with '
            'the same name: \'{}\'. This violates the design assumptions of '
            'this library. Please file a bug report. Thank you.'
            .format(name, sorted(common_keys)))

    return element_spec


def _parse_attribute(attribute_xml):
    """Builds an `AttributeSpec` from an <attribute> node of the schema.

    Args:
      attribute_xml: The XML node describing the attribute.

    Returns:
      The `AttributeSpec` for `attribute_xml`. Its `type` is the class from
      the `attribute` module chosen for the schema type, and `other_kwargs`
      holds the type-specific settings read from the node.

    Raises:
      ValueError: If the attribute type is not recognised.
    """
    name = attribute_xml.get('name')
    required = _str2bool(attribute_xml.get('required'))
    conflict_allowed = _str2bool(attribute_xml.get('conflict_allowed'))
    conflict_behavior = attribute_xml.get('conflict_behavior', 'replace')
    type_name = attribute_xml.get('type')

    extra_kwargs = {}
    if type_name == 'keyword':
        attribute_class = attribute.Keyword
        # Valid keywords are listed space-separated in the schema.
        extra_kwargs['valid_values'] = (
            attribute_xml.get('valid_values').split(' '))
    elif type_name == 'array':
        size_text = attribute_xml.get('array_size')
        attribute_class = attribute.Array
        # No `array_size` in the schema means the length is stored as None.
        extra_kwargs['length'] = int(size_text) if size_text else None
        extra_kwargs['dtype'] = (
            _ARRAY_DTYPE_MAP[attribute_xml.get('array_type')])
    elif type_name == 'identifier':
        attribute_class = attribute.Identifier
    elif type_name == 'reference':
        attribute_class = attribute.Reference
        # The referenced namespace defaults to the attribute's own name.
        extra_kwargs['reference_namespace'] = (
            attribute_xml.get('reference_namespace') or name)
    elif type_name in ('basepath', 'file'):
        if type_name == 'basepath':
            attribute_class = attribute.BasePath
        else:
            attribute_class = attribute.File
        extra_kwargs['path_namespace'] = attribute_xml.get('path_namespace')
    else:
        # Anything else must be one of the plain scalar types.
        try:
            attribute_class = _SCALAR_TYPE_MAP[type_name]
        except KeyError as exc:
            raise ValueError(
                'Invalid attribute type: {}'.format(type_name)
            ) from exc

    return AttributeSpec(
        name=name,
        type=attribute_class,
        required=required,
        conflict_allowed=conflict_allowed,
        conflict_behavior=conflict_behavior,
        other_kwargs=extra_kwargs)


def collect_namespaces(root_spec):
    """Constructs a set of namespaces in a given ElementSpec.

    The spec tree is walked iteratively and every spec is visited at most
    once (tracked by object identity). This handles self-referencing specs as
    well as longer reference cycles (for example A -> B -> A) without
    unbounded recursion.

    Args:
      root_spec: An `ElementSpec` for the root element in the schema.

    Returns:
      A set containing the `namespace` of every spec reachable from
      `root_spec` (including `root_spec` itself). Specs without a namespace
      contribute `None`.
    """
    findable_namespaces = set()
    # Specs hold dicts and are therefore unhashable, so identity is tracked
    # through `id()`. All specs stay alive via `root_spec` during the walk,
    # which keeps the ids unique.
    visited_ids = set()
    pending = [root_spec]
    while pending:
        spec = pending.pop()
        if id(spec) in visited_ids:
            continue
        visited_ids.add(id(spec))
        findable_namespaces.add(spec.namespace)
        pending.extend(spec.children.values())
    return findable_namespaces


# Root `ElementSpec` of the default schema, parsed once at import time.
MUJOCO = parse_schema(_SCHEMA_XML_PATH)

# Every namespace found in the schema plus the extra findable ones.
FINDABLE_NAMESPACES = frozenset(
    collect_namespaces(MUJOCO).union(_ADDITIONAL_FINDABLE_NAMESPACES))


def _attachment_frame_spec(is_world_attachment):
    """Create specs for attachment frames.

    Attachment frames are specialized <body> without an identifier. The only
    allowed children are joints which also don't have identifiers.

    Args:
      is_world_attachment: Whether we are creating a spec for attachments to
        worldbody. If `True`, allow <freejoint> as child.

    Returns:
      An `ElementSpec`.
    """
    body_spec = MUJOCO.children['worldbody'].children['body']

    # 'name' and 'childclass' attributes are excluded.
    kept_attribute_names = (
        'mocap', 'pos', 'quat', 'axisangle', 'xyaxes', 'zaxis', 'euler', 'user')
    frame_attributes = collections.OrderedDict(
        (attrib_name, copy.deepcopy(body_spec.attributes[attrib_name]))
        for attrib_name in kept_attribute_names)

    # Everything is deep-copied so that the frame spec never shares mutable
    # state with the schema it is derived from.
    frame_children = collections.OrderedDict()
    frame_children['inertial'] = copy.deepcopy(body_spec.children['inertial'])

    joint_spec = body_spec.children['joint']
    frame_children['joint'] = ElementSpec(
        'joint', repeated=True, on_demand=False, identifier=None,
        namespace='joint',
        attributes=copy.deepcopy(joint_spec.attributes),
        children=collections.OrderedDict())

    if is_world_attachment:
        freejoint_spec = body_spec.children['freejoint']
        frame_children['freejoint'] = ElementSpec(
            'freejoint', repeated=False, on_demand=True, identifier=None,
            namespace='joint',
            attributes=copy.deepcopy(freejoint_spec.attributes),
            children=collections.OrderedDict())

    return ElementSpec(
        'body', repeated=True, on_demand=False, identifier=None,
        namespace='body',
        attributes=frame_attributes,
        children=frame_children)


# Attachment-frame specs, built once from the default schema at import time.
ATTACHMENT_FRAME = _attachment_frame_spec(is_world_attachment=False)
WORLD_ATTACHMENT_FRAME = _attachment_frame_spec(is_world_attachment=True)


def override_schema(schema_xml_path):
    """Override the schema with a custom xml.

    This method updates several global variables and care should be taken not
    to call it if the pre-update values have already been used.

    Only `MUJOCO` and `FINDABLE_NAMESPACES` are replaced. `ATTACHMENT_FRAME`
    and `WORLD_ATTACHMENT_FRAME` are not recomputed here and keep the values
    built when this module was imported.

    Args:
      schema_xml_path: Path to schema xml file.
    """
    global MUJOCO, FINDABLE_NAMESPACES

    # `MUJOCO` is rebound first, so it keeps the new schema even if
    # collecting the namespaces fails afterwards.
    MUJOCO = parse_schema(schema_xml_path)
    schema_namespaces = collect_namespaces(MUJOCO)
    FINDABLE_NAMESPACES = frozenset(
        schema_namespaces.union(_ADDITIONAL_FINDABLE_NAMESPACES))