Source code for rofunc.learning.RofuncRL.tasks.isaacgymenv.ase.humanoid_amp

# Copyright (c) 2018-2022, NVIDIA Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
#    contributors may be used to endorse or promote products derived from
#    this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
from enum import Enum

from gym import spaces
from isaacgym.torch_utils import *

import rofunc as rf
from rofunc.learning.RofuncRL.tasks.isaacgymenv.ase.humanoid import Humanoid, dof_to_obs
from rofunc.learning.RofuncRL.tasks.isaacgymenv.ase.motion_lib import MotionLib
from rofunc.learning.RofuncRL.tasks.utils import torch_jit_utils as torch_utils


class HumanoidAMP(Humanoid):
    """Humanoid task that additionally produces AMP observations.

    On top of the base ``Humanoid`` environment this class keeps, for every
    environment, a short rolling window of per-step "AMP" observations
    (presumably Adversarial Motion Priors -- the acronym is not expanded in
    this file) and can sample matching observation windows from a reference
    motion library.

    The window is stored in ``self._amp_obs_buf`` with shape
    ``(num_envs, numAMPObsSteps, num_amp_obs_per_step)``; slot ``0`` holds
    the current step and slots ``1:`` hold progressively older steps.
    """

    class StateInit(Enum):
        """How environments are (re-)initialised in ``_reset_actors``."""

        Default = 0  # use the stored initial root / DoF state
        Start = 1  # use the reference motion state at time 0
        Random = 2  # use the reference motion state at a sampled time
        Hybrid = 3  # per-env random choice between reference and default

    def __init__(self, cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture,
                 force_render):
        """Build the environment, load the motion library and allocate AMP buffers.

        :param cfg: task configuration; ``cfg["env"]`` must provide
            ``stateInit``, ``hybridInitProb``, ``numAMPObsSteps`` and
            ``motion_file``.
        :param rl_device: forwarded unchanged to ``Humanoid.__init__``.
        :param sim_device: forwarded unchanged to ``Humanoid.__init__``.
        :param graphics_device_id: forwarded unchanged to ``Humanoid.__init__``.
        :param headless: forwarded unchanged to ``Humanoid.__init__``.
        :param virtual_screen_capture: forwarded unchanged to ``Humanoid.__init__``.
        :param force_render: forwarded unchanged to ``Humanoid.__init__``.
        :raises KeyError: if ``stateInit`` is not a ``StateInit`` member name.
        :raises ValueError: if ``motion_file`` is missing from ``cfg["env"]``.
        """
        self.cfg = cfg

        state_init = cfg["env"]["stateInit"]
        self._state_init = HumanoidAMP.StateInit[state_init]
        self._hybrid_init_prob = cfg["env"]["hybridInitProb"]
        self._num_amp_obs_steps = cfg["env"]["numAMPObsSteps"]
        # At least one history slot is required next to the current step.
        assert self._num_amp_obs_steps >= 2

        # Must exist before the base constructor runs, since resets triggered
        # there may already go through ``_init_amp_obs``.
        self._reset_default_env_ids = []
        self._reset_ref_env_ids = []

        super().__init__(config=self.cfg, rl_device=rl_device, sim_device=sim_device,
                         graphics_device_id=graphics_device_id, headless=headless,
                         virtual_screen_capture=virtual_screen_capture, force_render=force_render)

        motion_file = cfg["env"].get("motion_file", None)
        if motion_file is None:
            # Previously this fell through to ``motion_file.split`` and failed
            # with an opaque error; report the actual configuration problem.
            raise ValueError('cfg["env"]["motion_file"] must be set for HumanoidAMP')

        module_dir = os.path.dirname(os.path.abspath(__file__))
        if rf.oslab.is_absl_path(motion_file):
            motion_file_path = motion_file
        elif motion_file.split("/")[0] == "examples":
            # Path given relative to the repository root.
            motion_file_path = os.path.join(
                module_dir,
                "../../../../../../" + motion_file,
            )
        else:
            # Bare file name: look it up in the bundled AMP example data.
            motion_file_path = os.path.join(
                module_dir,
                "../../../../../../examples/data/amp/" + motion_file,
            )

        self._load_motion(motion_file_path)

        # ``np.inf`` instead of the ``np.Inf`` alias, which NumPy 2.0 removed.
        num_amp_obs = self.get_num_amp_obs()
        self._amp_obs_space = spaces.Box(np.ones(num_amp_obs) * -np.inf, np.ones(num_amp_obs) * np.inf)

        self._amp_obs_buf = torch.zeros((self.num_envs, self._num_amp_obs_steps, self._num_amp_obs_per_step),
                                        device=self.device, dtype=torch.float)
        # Both are slices of ``_amp_obs_buf``: slot 0 is the current step,
        # slots 1: are the history.
        self._curr_amp_obs_buf = self._amp_obs_buf[:, 0]
        self._hist_amp_obs_buf = self._amp_obs_buf[:, 1:]

        # Allocated lazily on the first ``fetch_amp_obs_demo`` call.
        self._amp_obs_demo_buf = None

    def post_physics_step(self):
        """Run the base post-step, then refresh and publish the AMP window.

        The flattened window, shape ``(-1, get_num_amp_obs())``, is stored in
        ``self.extras["amp_obs"]``.
        """
        super().post_physics_step()

        # Shift the history first so slot 0 can receive the new observation.
        self._update_hist_amp_obs()
        self._compute_amp_observations()

        amp_obs_flat = self._amp_obs_buf.view(-1, self.get_num_amp_obs())
        self.extras["amp_obs"] = amp_obs_flat

    def get_num_amp_obs(self):
        """Return the flattened AMP observation size (steps * per-step size)."""
        return self._num_amp_obs_steps * self._num_amp_obs_per_step

    @property
    def amp_observation_space(self):
        """The unbounded ``spaces.Box`` describing one flattened AMP observation."""
        return self._amp_obs_space

    # NOTE: an earlier duplicate ``fetch_amp_obs_demo`` that delegated to
    # ``self.task`` used to precede this method. It was shadowed by this
    # definition and therefore never callable, so it has been removed.
    def fetch_amp_obs_demo(self, num_samples):
        """Sample ``num_samples`` AMP observation windows from the motion library.

        :param num_samples: number of windows to sample; must stay the same
            across calls because the output buffer is allocated once.
        :return: tensor of shape ``(-1, get_num_amp_obs())`` viewing the
            internal demo buffer (overwritten by the next call).
        """
        motion_ids = self._motion_lib.sample_motions(num_samples)

        if self._amp_obs_demo_buf is None:
            self._build_amp_obs_demo_buf(num_samples)
        else:
            assert self._amp_obs_demo_buf.shape[0] == num_samples

        # since negative times are added to these values in build_amp_obs_demo,
        # we shift them into the range [0 + truncate_time, end of clip]
        truncate_time = self.dt * (self._num_amp_obs_steps - 1)
        motion_times0 = self._motion_lib.sample_time(motion_ids, truncate_time=truncate_time)
        motion_times0 += truncate_time

        amp_obs_demo = self.build_amp_obs_demo(motion_ids, motion_times0)
        self._amp_obs_demo_buf[:] = amp_obs_demo.view(self._amp_obs_demo_buf.shape)
        amp_obs_demo_flat = self._amp_obs_demo_buf.view(-1, self.get_num_amp_obs())

        return amp_obs_demo_flat

    def build_amp_obs_demo(self, motion_ids, motion_times0):
        """Build AMP observations for windows ending at ``motion_times0``.

        For every sample the motion is queried at ``t, t - dt, ...`` for
        ``numAMPObsSteps`` steps.

        :param motion_ids: 1-D tensor of motion indices, one per sample.
        :param motion_times0: 1-D tensor with the latest time of each window.
        :return: per-step observations with samples and steps flattened into
            the leading dimension.
        """
        dt = self.dt

        motion_ids = torch.tile(motion_ids.unsqueeze(-1), [1, self._num_amp_obs_steps])
        motion_times = motion_times0.unsqueeze(-1)
        # Non-positive offsets: step k of a window lies k * dt in the past.
        time_steps = -dt * torch.arange(0, self._num_amp_obs_steps, device=self.device)
        motion_times = motion_times + time_steps

        motion_ids = motion_ids.view(-1)
        motion_times = motion_times.view(-1)
        root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos \
            = self._motion_lib.get_motion_state(motion_ids, motion_times)
        amp_obs_demo = build_amp_observations(root_pos, root_rot, root_vel, root_ang_vel,
                                              dof_pos, dof_vel, key_pos,
                                              self._local_root_obs, self._root_height_obs,
                                              self._dof_obs_size, self._dof_offsets)
        return amp_obs_demo

    def _build_amp_obs_demo_buf(self, num_samples):
        """Allocate the zero-filled buffer that ``fetch_amp_obs_demo`` writes into."""
        self._amp_obs_demo_buf = torch.zeros((num_samples, self._num_amp_obs_steps, self._num_amp_obs_per_step),
                                             device=self.device, dtype=torch.float32)

    def _setup_character_props(self, key_bodies):
        """Set ``_num_amp_obs_per_step`` for the configured humanoid asset.

        :param key_bodies: sequence of key bodies; only its length is used here.
        :raises ValueError: if the asset file has no known AMP layout.
        """
        super()._setup_character_props(key_bodies)

        asset_file = self.cfg["env"]["asset"]["assetFileName"]
        num_key_bodies = len(key_bodies)

        # Layout of one step:
        # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos]
        # The leading 13 presumably covers the four root terms and the
        # per-asset constant the dof_vel width -- TODO confirm against the assets.
        if asset_file == "mjcf/amp_humanoid.xml":
            self._num_amp_obs_per_step = 13 + self._dof_obs_size + 28 + 3 * num_key_bodies
        elif asset_file == "mjcf/amp_humanoid_sword_shield.xml":
            self._num_amp_obs_per_step = 13 + self._dof_obs_size + 31 + 3 * num_key_bodies
        elif asset_file in ["mjcf/amp_humanoid_spoon_pan_fixed.xml", "mjcf/hotu_humanoid.xml"]:
            self._num_amp_obs_per_step = 13 + self._dof_obs_size + 34 + 3 * num_key_bodies
        elif asset_file == "mjcf/hotu_humanoid_w_qbhand.xml":
            self._num_amp_obs_per_step = 13 + self._dof_obs_size + 64 + 3 * num_key_bodies
        else:
            raise ValueError(f"Unsupported humanoid body num: {asset_file}")

    def _load_motion(self, motion_file):
        """Create ``self._motion_lib`` from ``motion_file``."""
        assert self._dof_offsets[-1] == self.num_dof
        self._motion_lib = MotionLib(
            motion_file=motion_file,
            dof_body_ids=self._dof_body_ids,
            dof_offsets=self._dof_offsets,
            key_body_ids=self._key_body_ids.cpu().numpy(),
            device=self.device,
        )

    def reset_idx(self, env_ids):
        """Reset the given environments and re-seed their AMP windows."""
        # Cleared first so ``_init_amp_obs`` only sees ids recorded by this reset.
        self._reset_default_env_ids = []
        self._reset_ref_env_ids = []

        super().reset_idx(env_ids)
        self._init_amp_obs(env_ids)

    def _reset_actors(self, env_ids):
        """Dispatch to the reset routine selected by ``self._state_init``."""
        if self._state_init == HumanoidAMP.StateInit.Default:
            self._reset_default(env_ids)
        elif (
            self._state_init == HumanoidAMP.StateInit.Start
            or self._state_init == HumanoidAMP.StateInit.Random
        ):
            self._reset_ref_state_init(env_ids)
        elif self._state_init == HumanoidAMP.StateInit.Hybrid:
            self._reset_hybrid_state_init(env_ids)
        else:
            assert False, "Unsupported state initialization strategy: {:s}".format(
                str(self._state_init)
            )

    def _reset_default(self, env_ids):
        """Restore the stored initial root / DoF state for ``env_ids``."""
        self._humanoid_root_states[env_ids] = self._initial_humanoid_root_states[env_ids]
        self._dof_pos[env_ids] = self._initial_dof_pos[env_ids]
        self._dof_vel[env_ids] = self._initial_dof_vel[env_ids]
        # Remembered so ``_init_amp_obs`` can fill the history for these envs.
        self._reset_default_env_ids = env_ids

    def _reset_ref_state_init(self, env_ids):
        """Initialise ``env_ids`` from sampled reference-motion states.

        ``Random`` and ``Hybrid`` sample a time within the motion; ``Start``
        uses time zero.
        """
        num_envs = env_ids.shape[0]
        motion_ids = self._motion_lib.sample_motions(num_envs)

        if (
            self._state_init == HumanoidAMP.StateInit.Random
            or self._state_init == HumanoidAMP.StateInit.Hybrid
        ):
            motion_times = self._motion_lib.sample_time(motion_ids)
        elif self._state_init == HumanoidAMP.StateInit.Start:
            motion_times = torch.zeros(num_envs, device=self.device)
        else:
            assert (
                False
            ), f"Unsupported state initialization strategy: {self._state_init}"

        (
            root_pos,
            root_rot,
            dof_pos,
            root_vel,
            root_ang_vel,
            dof_vel,
            key_pos,
        ) = self._motion_lib.get_motion_state(motion_ids, motion_times)

        self._set_env_state(
            env_ids=env_ids,
            root_pos=root_pos,
            root_rot=root_rot,
            dof_pos=dof_pos,
            root_vel=root_vel,
            root_ang_vel=root_ang_vel,
            dof_vel=dof_vel,
        )

        # Remembered so ``_init_amp_obs`` can rebuild a matching history.
        self._reset_ref_env_ids = env_ids
        self._reset_ref_motion_ids = motion_ids
        self._reset_ref_motion_times = motion_times

    def _reset_hybrid_state_init(self, env_ids):
        """Reset each env from the reference motion with prob. ``_hybrid_init_prob``, else default."""
        num_envs = env_ids.shape[0]
        ref_probs = to_torch(
            np.array([self._hybrid_init_prob] * num_envs), device=self.device
        )
        ref_init_mask = torch.bernoulli(ref_probs) == 1.0

        ref_reset_ids = env_ids[ref_init_mask]
        if len(ref_reset_ids) > 0:
            self._reset_ref_state_init(ref_reset_ids)

        # ``ref_init_mask`` is already a bool tensor; wrapping it in
        # ``torch.tensor(...)`` only produced a copy and a UserWarning.
        default_reset_ids = env_ids[torch.logical_not(ref_init_mask)]
        if len(default_reset_ids) > 0:
            self._reset_default(default_reset_ids)

    def _init_amp_obs(self, env_ids):
        """Fill the whole AMP window of freshly reset environments."""
        self._compute_amp_observations(env_ids)

        if len(self._reset_default_env_ids) > 0:
            self._init_amp_obs_default(self._reset_default_env_ids)

        if len(self._reset_ref_env_ids) > 0:
            self._init_amp_obs_ref(
                self._reset_ref_env_ids,
                self._reset_ref_motion_ids,
                self._reset_ref_motion_times,
            )

    def _init_amp_obs_default(self, env_ids):
        """Copy the current observation into every history slot of ``env_ids``."""
        # The added step axis lets the assignment broadcast over the history.
        curr_amp_obs = self._curr_amp_obs_buf[env_ids].unsqueeze(-2)
        self._hist_amp_obs_buf[env_ids] = curr_amp_obs

    def _init_amp_obs_ref(self, env_ids, motion_ids, motion_times):
        """Fill the history of ``env_ids`` from the reference motion.

        History slot ``k`` (0-based) is taken at ``motion_times - (k + 1) * dt``.
        """
        dt = self.dt
        motion_ids = torch.tile(
            motion_ids.unsqueeze(-1), [1, self._num_amp_obs_steps - 1]
        )
        motion_times = motion_times.unsqueeze(-1)
        # ``+ 1`` skips the current step, which slot 0 already holds.
        time_steps = -dt * (
            torch.arange(0, self._num_amp_obs_steps - 1, device=self.device) + 1
        )
        motion_times = motion_times + time_steps

        motion_ids = motion_ids.view(-1)
        motion_times = motion_times.view(-1)
        (
            root_pos,
            root_rot,
            dof_pos,
            root_vel,
            root_ang_vel,
            dof_vel,
            key_pos,
        ) = self._motion_lib.get_motion_state(motion_ids, motion_times)
        amp_obs_demo = build_amp_observations(
            root_pos,
            root_rot,
            root_vel,
            root_ang_vel,
            dof_pos,
            dof_vel,
            key_pos,
            self._local_root_obs,
            self._root_height_obs,
            self._dof_obs_size,
            self._dof_offsets,
        )
        self._hist_amp_obs_buf[env_ids] = amp_obs_demo.view(
            self._hist_amp_obs_buf[env_ids].shape
        )

    def _set_env_state(
        self, env_ids, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel
    ):
        """Write root and DoF state of ``env_ids`` into the simulation buffers.

        Root state columns: ``0:3`` position, ``3:7`` rotation, ``7:10``
        linear velocity, ``10:13`` angular velocity.
        """
        self._humanoid_root_states[env_ids, 0:3] = root_pos
        self._humanoid_root_states[env_ids, 3:7] = root_rot
        self._humanoid_root_states[env_ids, 7:10] = root_vel
        self._humanoid_root_states[env_ids, 10:13] = root_ang_vel

        self._dof_pos[env_ids] = dof_pos
        self._dof_vel[env_ids] = dof_vel

    def _update_hist_amp_obs(self, env_ids=None):
        """Shift every AMP step one slot towards the past.

        :param env_ids: environments to shift; ``None`` shifts all of them.
        """
        # Walk from the oldest slot backwards so each source slot is copied
        # before it is overwritten.
        if env_ids is None:
            for i in reversed(range(self._amp_obs_buf.shape[1] - 1)):
                self._amp_obs_buf[:, i + 1] = self._amp_obs_buf[:, i]
        else:
            for i in reversed(range(self._amp_obs_buf.shape[1] - 1)):
                self._amp_obs_buf[env_ids, i + 1] = self._amp_obs_buf[env_ids, i]

    def _compute_amp_observations(self, env_ids=None):
        """Write the current-step AMP observation into slot 0.

        Rigid body index ``0`` is used as the root.

        :param env_ids: environments to update; ``None`` updates all of them.
        """
        key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :]
        if env_ids is None:
            self._curr_amp_obs_buf[:] = build_amp_observations(
                self._rigid_body_pos[:, 0, :],
                self._rigid_body_rot[:, 0, :],
                self._rigid_body_vel[:, 0, :],
                self._rigid_body_ang_vel[:, 0, :],
                self._dof_pos,
                self._dof_vel,
                key_body_pos,
                self._local_root_obs,
                self._root_height_obs,
                self._dof_obs_size,
                self._dof_offsets,
            )
        else:
            self._curr_amp_obs_buf[env_ids] = build_amp_observations(
                self._rigid_body_pos[env_ids][:, 0, :],
                self._rigid_body_rot[env_ids][:, 0, :],
                self._rigid_body_vel[env_ids][:, 0, :],
                self._rigid_body_ang_vel[env_ids][:, 0, :],
                self._dof_pos[env_ids],
                self._dof_vel[env_ids],
                key_body_pos[env_ids],
                self._local_root_obs,
                self._root_height_obs,
                self._dof_obs_size,
                self._dof_offsets,
            )
# Assemble one AMP observation row per input row. The pieces are concatenated
# along the last dimension in this order: root height, root rotation
# (tan-norm encoding), root linear velocity, root angular velocity, DoF
# observation, DoF velocity, flattened key-body positions.
# Velocities and key-body offsets are rotated by the inverse heading of the root.
@torch.jit.script
def build_amp_observations(
    root_pos,
    root_rot,
    root_vel,
    root_ang_vel,
    dof_pos,
    dof_vel,
    key_body_pos,
    local_root_obs,
    root_height_obs,
    dof_obs_size,
    dof_offsets,
):
    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, int, List[int]) -> Tensor
    inv_heading = torch_utils.calc_heading_quat_inv(root_rot)

    # Root height is column 2 of the position, kept 2-D; zeroed when disabled.
    height = root_pos[:, 2:3]
    if root_height_obs:
        height_obs = height
    else:
        height_obs = torch.zeros_like(height)

    # Root rotation, optionally with the heading removed first.
    if local_root_obs:
        rot_for_obs = quat_mul(inv_heading, root_rot)
    else:
        rot_for_obs = root_rot
    rot_obs = torch_utils.quat_to_tan_norm(rot_for_obs)

    # Root velocities rotated by the inverse heading.
    lin_vel_obs = quat_rotate(inv_heading, root_vel)
    ang_vel_obs = quat_rotate(inv_heading, root_ang_vel)

    # Key-body positions relative to the root, rotated by the inverse heading.
    rel_key_pos = key_body_pos - root_pos.unsqueeze(-2)
    num_rows = rel_key_pos.shape[0]
    num_keys = rel_key_pos.shape[1]
    pos_dim = rel_key_pos.shape[2]

    # One inverse-heading quaternion per key body so both can be flattened
    # into matching row lists for quat_rotate.
    heading_per_key = inv_heading.unsqueeze(-2).repeat([1, num_keys, 1])
    flat_heading = heading_per_key.view(
        heading_per_key.shape[0] * heading_per_key.shape[1],
        heading_per_key.shape[2],
    )
    flat_rel_pos = rel_key_pos.view(num_rows * num_keys, pos_dim)
    rotated_key_pos = quat_rotate(flat_heading, flat_rel_pos)
    key_pos_obs = rotated_key_pos.view(num_rows, num_keys * pos_dim)

    dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets)

    return torch.cat(
        [
            height_obs,
            rot_obs,
            lin_vel_obs,
            ang_vel_obs,
            dof_obs,
            dof_vel,
            key_pos_obs,
        ],
        dim=-1,
    )