# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from enum import Enum
from gym import spaces
from isaacgym.torch_utils import *
import rofunc as rf
from rofunc.learning.RofuncRL.tasks.isaacgymenv.hotu.humanoid import Humanoid, dof_to_obs
from rofunc.learning.RofuncRL.tasks.isaacgymenv.hotu.motion_lib import MotionLib, ObjectMotionLib
from rofunc.learning.RofuncRL.tasks.utils import torch_jit_utils as torch_utils
[docs]class HumanoidHOTUTask(Humanoid):
class StateInit(Enum):
    """Strategies for choosing the humanoid state when an environment is reset.

    Members are looked up by name from ``cfg["env"]["stateInit"]``.
    """

    # Restore the stored initial root and DoF state.
    Default = 0
    # Start from time zero of a sampled reference motion.
    Start = 1
    # Start from a sampled time within a sampled reference motion.
    Random = 2
    # Per environment, pick randomly between a reference-motion reset and a default reset.
    Hybrid = 3
def __init__(self, cfg, rl_device, sim_device, graphics_device_id, headless, virtual_screen_capture, force_render):
    """Create the task, load the reference motions and allocate the AMP buffers.

    :param cfg: task configuration; ``cfg["env"]`` must provide ``stateInit``,
        ``hybridInitProb``, ``numAMPObsSteps`` and ``motion_file``, and may
        provide ``object_motion_file``
    :param rl_device: forwarded unchanged to the parent constructor
    :param sim_device: forwarded unchanged to the parent constructor
    :param graphics_device_id: forwarded unchanged to the parent constructor
    :param headless: forwarded unchanged to the parent constructor
    :param virtual_screen_capture: forwarded unchanged to the parent constructor
    :param force_render: forwarded unchanged to the parent constructor
    :raises KeyError: if ``stateInit`` does not name a ``StateInit`` member
    :raises ValueError: if ``motion_file`` is missing, or if a motion path is
        neither absolute nor rooted at ``examples``
    """

    def _resolve_motion_path(path, label):
        # Absolute paths are used as given; paths whose first component is
        # "examples" are resolved six directories above this file.
        if rf.oslab.is_absl_path(path):
            return path
        if path.split("/")[0] == "examples":
            return os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "../../../../../../" + path,
            )
        raise ValueError(f"Unsupported {label} path")

    self.cfg = cfg

    # These attributes are set before the parent constructor runs.
    # NOTE(review): presumably the parent constructor triggers resets that read them - confirm.
    state_init = cfg["env"]["stateInit"]
    self._state_init = HumanoidHOTUTask.StateInit[state_init]
    self._hybrid_init_prob = cfg["env"]["hybridInitProb"]
    self._num_amp_obs_steps = cfg["env"]["numAMPObsSteps"]
    # The history buffer is ``_amp_obs_buf[:, 1:]``, so at least one past step is needed.
    assert self._num_amp_obs_steps >= 2, "numAMPObsSteps must be at least 2"
    self._reset_default_env_ids = []
    self._reset_ref_env_ids = []

    super().__init__(config=self.cfg, rl_device=rl_device, sim_device=sim_device,
                     graphics_device_id=graphics_device_id, headless=headless,
                     virtual_screen_capture=virtual_screen_capture, force_render=force_render)

    # Load motion file
    motion_file = cfg["env"].get("motion_file", None)
    if motion_file is None:
        # Fail with an explicit message instead of an AttributeError on ``None.split``.
        raise ValueError('cfg["env"]["motion_file"] must be set')
    self._load_motion(_resolve_motion_path(motion_file, "motion file"))

    # Load object motion file (optional)
    object_motion_file = cfg["env"].get("object_motion_file", None)
    if object_motion_file is not None:
        self._load_object_motion(_resolve_motion_path(object_motion_file, "object motion file"))

    # Set up the observation space for AMP.
    # ``np.inf`` is used because the ``np.Inf`` alias was removed in NumPy 2.0.
    num_amp_obs = self.get_num_amp_obs()
    self._amp_obs_space = spaces.Box(np.ones(num_amp_obs) * -np.inf,
                                     np.ones(num_amp_obs) * np.inf)

    self._amp_obs_buf = torch.zeros((self.num_envs, self._num_amp_obs_steps, self._num_amp_obs_per_step),
                                    device=self.device, dtype=torch.float)
    # Views into the buffer: slot 0 is the current step, slots 1.. are the history.
    self._curr_amp_obs_buf = self._amp_obs_buf[:, 0]
    self._hist_amp_obs_buf = self._amp_obs_buf[:, 1:]
    # Allocated lazily by ``fetch_amp_obs_demo``.
    self._amp_obs_demo_buf = None
def post_physics_step(self):
    """Run the parent post-physics step, then refresh and publish the AMP observations.

    The AMP buffer, flattened to ``(-1, get_num_amp_obs())``, is stored in
    ``self.extras["amp_obs"]``.
    """
    super().post_physics_step()

    # Shift the history before the current-step slot is overwritten.
    self._update_hist_amp_obs()
    self._compute_amp_observations()

    self.extras["amp_obs"] = self._amp_obs_buf.view(-1, self.get_num_amp_obs())
def get_num_amp_obs(self):
    """Return the flattened AMP observation size: number of steps times per-step size."""
    steps = self._num_amp_obs_steps
    per_step = self._num_amp_obs_per_step
    return steps * per_step
@property
def amp_observation_space(self):
    """The AMP observation space object created in the constructor."""
    space = self._amp_obs_space
    return space
def fetch_amp_obs_demo(self, num_samples):
    """Sample AMP observation windows from the reference motion library.

    :param num_samples: number of motion windows to sample
    :return: tensor of shape ``(num_samples, get_num_amp_obs())`` that views
        the internal demo buffer; the buffer is overwritten by the next call
    """
    motion_ids = self._motion_lib.sample_motions(num_samples)

    # (Re)allocate the staging buffer when it does not exist yet or when the
    # requested sample count differs from the one it was built for.
    if self._amp_obs_demo_buf is None or self._amp_obs_demo_buf.shape[0] != num_samples:
        self._build_amp_obs_demo_buf(num_samples)

    # since negative times are added to these values in build_amp_obs_demo,
    # we shift them into the range [0 + truncate_time, end of clip]
    truncate_time = self.dt * (self._num_amp_obs_steps - 1)
    motion_times0 = self._motion_lib.sample_time(motion_ids, truncate_time=truncate_time)
    motion_times0 += truncate_time

    amp_obs_demo = self.build_amp_obs_demo(motion_ids, motion_times0)
    self._amp_obs_demo_buf[:] = amp_obs_demo.view(self._amp_obs_demo_buf.shape)
    return self._amp_obs_demo_buf.view(-1, self.get_num_amp_obs())
def build_amp_obs_demo(self, motion_ids, motion_times0):
    """Build AMP observations for windows of reference motion.

    For every ``(motion_id, time)`` pair, ``_num_amp_obs_steps`` samples are
    taken at ``time, time - dt, time - 2 * dt, ...``.

    :param motion_ids: 1-D tensor of motion indices, one per window
    :param motion_times0: 1-D tensor holding the latest time of each window
    :return: per-step AMP observations, one row per (window, step) pair
    """
    num_steps = self._num_amp_obs_steps

    # Step k of a window lies k * dt before the window's latest time.
    step_offsets = -self.dt * torch.arange(0, num_steps, device=self.device)
    window_times = (motion_times0.unsqueeze(-1) + step_offsets).view(-1)
    window_ids = torch.tile(motion_ids.unsqueeze(-1), [1, num_steps]).view(-1)

    (root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos,
     _unused_a, _unused_b) = self._motion_lib.get_motion_state(window_ids, window_times)

    return build_amp_observations(
        root_pos, root_rot, root_vel, root_ang_vel,
        dof_pos, dof_vel, key_pos,
        self._local_root_obs, self._root_height_obs,
        self._dof_obs_size, self._dof_offsets,
    )
def _build_amp_obs_demo_buf(self, num_samples):
self._amp_obs_demo_buf = torch.zeros((num_samples, self._num_amp_obs_steps, self._num_amp_obs_per_step),
device=self.device, dtype=torch.float32)
def _setup_character_props(self, key_bodies):
    """Set the per-step AMP observation size for the configured humanoid asset.

    :param key_bodies: sequence of key bodies; each one contributes three
        values to the per-step observation
    :raises ValueError: if ``cfg["env"]["asset"]["assetFileName"]`` is not one
        of the supported humanoid assets
    """
    super()._setup_character_props(key_bodies)

    # Asset-specific term of the per-step observation size.
    # NOTE(review): presumably the DoF-velocity count of each model - confirm against the MJCF files.
    asset_specific_sizes = {
        "mjcf/amp_humanoid.xml": 28,
        "mjcf/amp_humanoid_sword_shield.xml": 31,
        "mjcf/amp_humanoid_spoon_pan_fixed.xml": 34,
        "mjcf/hotu_humanoid.xml": 34,
        "mjcf/hotu_humanoid_w_qbhand.xml": 64,
        "mjcf/hotu_humanoid_w_qbhand_no_virtual.xml": 64,
        "mjcf/hotu_humanoid_w_qbhand_no_virtual_no_quat.xml": 64,
    }

    asset_file = self.cfg["env"]["asset"]["assetFileName"]
    if asset_file not in asset_specific_sizes:
        # Raise explicitly: an ``assert`` is removed under ``python -O`` and would
        # leave ``_num_amp_obs_per_step`` unset.
        raise ValueError(f"Unsupported humanoid body num: {asset_file}")

    # [root_h, root_rot, root_vel, root_ang_vel, dof_pos, dof_vel, key_body_pos]
    self._num_amp_obs_per_step = (
        13 + self._dof_obs_size + asset_specific_sizes[asset_file] + 3 * len(key_bodies)
    )
def _load_motion(self, motion_file):
    """Create the humanoid ``MotionLib`` for ``motion_file`` and store it on the task.

    :param motion_file: path of the motion file handed to ``MotionLib``
    """
    # The last DoF offset must equal the number of DoFs of the character.
    assert self._dof_offsets[-1] == self.num_dof

    key_body_ids = self._key_body_ids.cpu().numpy()
    self._motion_lib = MotionLib(
        motion_file=motion_file,
        dof_body_ids=self._dof_body_ids,
        dof_offsets=self._dof_offsets,
        key_body_ids=key_body_ids,
        device=self.device,
    )
def _load_object_motion(self, object_motion_file):
    """Create the ``ObjectMotionLib`` for ``object_motion_file`` and store it on the task.

    Requires ``_load_motion`` to have run, because the height offset is read
    from the humanoid motion library.

    :param object_motion_file: path of the object motion file
    """
    object_names = self.cfg["env"]["object_asset"]["assetName"]
    # TODO: make it for multiple motions (only the first height offset is used)
    height_offset = self._motion_lib.humanoid_height_offsets[0]
    self._object_motion_lib = ObjectMotionLib(
        object_motion_file=object_motion_file,
        object_names=object_names,
        device=self.device,
        height_offset=height_offset,
    )
def reset_idx(self, env_ids):
    """Reset the given environments and re-initialise their AMP observation history.

    :param env_ids: indices of the environments to reset
    """
    # Clear the bookkeeping of the previous reset; the ``_reset_*`` helpers fill it again.
    # NOTE(review): presumably the parent ``reset_idx`` dispatches to ``_reset_actors`` - confirm.
    self._reset_default_env_ids, self._reset_ref_env_ids = [], []

    super().reset_idx(env_ids)
    self._init_amp_obs(env_ids)
def _reset_actors(self, env_ids):
    """Dispatch the reset of ``env_ids`` according to the configured ``StateInit`` strategy."""
    modes = HumanoidHOTUTask.StateInit
    strategy = self._state_init

    if strategy == modes.Default:
        self._reset_default(env_ids)
        return

    # Start and Random share one code path; the start time is chosen inside it.
    if strategy in (modes.Start, modes.Random):
        self._reset_ref_state_init(env_ids)
        return

    if strategy == modes.Hybrid:
        self._reset_hybrid_state_init(env_ids)
        return

    assert False, "Unsupported state initialization strategy: {:s}".format(
        str(self._state_init)
    )
def _reset_default(self, env_ids):
self._humanoid_root_states[env_ids] = self._initial_humanoid_root_states[
env_ids
]
self._dof_pos[env_ids] = self._initial_dof_pos[env_ids]
self._dof_vel[env_ids] = self._initial_dof_vel[env_ids]
self._reset_default_env_ids = env_ids
def _reset_ref_state_init(self, env_ids):
    """Reset ``env_ids`` to states sampled from the reference motion library.

    One motion is sampled per environment. The start time is sampled for the
    ``Random`` and ``Hybrid`` strategies and is zero for ``Start``. The
    environment indices, motion ids and motion times are remembered for the
    later AMP history initialisation.
    """
    modes = HumanoidHOTUTask.StateInit
    num_envs = env_ids.shape[0]
    motion_ids = self._motion_lib.sample_motions(num_envs)

    if self._state_init in (modes.Random, modes.Hybrid):
        motion_times = self._motion_lib.sample_time(motion_ids)
    elif self._state_init == modes.Start:
        motion_times = torch.zeros(num_envs, device=self.device)
    else:
        assert (
            False
        ), f"Unsupported state initialization strategy: {self._state_init}"

    # Key-body positions and the two trailing values are not needed here.
    (root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel,
     _key_pos, _unused_a, _unused_b) = self._motion_lib.get_motion_state(motion_ids, motion_times)

    self._set_env_state(
        env_ids=env_ids,
        root_pos=root_pos,
        root_rot=root_rot,
        dof_pos=dof_pos,
        root_vel=root_vel,
        root_ang_vel=root_ang_vel,
        dof_vel=dof_vel,
    )

    self._reset_ref_env_ids = env_ids
    self._reset_ref_motion_ids = motion_ids
    self._reset_ref_motion_times = motion_times
def _reset_hybrid_state_init(self, env_ids):
num_envs = env_ids.shape[0]
ref_probs = to_torch(
np.array([self._hybrid_init_prob] * num_envs), device=self.device
)
ref_init_mask = torch.bernoulli(ref_probs) == 1.0
ref_reset_ids = env_ids[ref_init_mask]
if len(ref_reset_ids) > 0:
self._reset_ref_state_init(ref_reset_ids)
default_reset_ids = env_ids[torch.logical_not(torch.tensor(ref_init_mask))]
if len(default_reset_ids) > 0:
self._reset_default(default_reset_ids)
def _init_amp_obs(self, env_ids):
self._compute_amp_observations(env_ids)
if len(self._reset_default_env_ids) > 0:
self._init_amp_obs_default(self._reset_default_env_ids)
if len(self._reset_ref_env_ids) > 0:
self._init_amp_obs_ref(
self._reset_ref_env_ids,
self._reset_ref_motion_ids,
self._reset_ref_motion_times,
)
def _init_amp_obs_default(self, env_ids):
curr_amp_obs = self._curr_amp_obs_buf[env_ids].unsqueeze(-2)
self._hist_amp_obs_buf[env_ids] = curr_amp_obs
def _init_amp_obs_ref(self, env_ids, motion_ids, motion_times):
    """Fill the AMP history of ``env_ids`` from the reference motion before the reset time.

    :param env_ids: environments whose history is written
    :param motion_ids: 1-D tensor with the motion each environment was reset from
    :param motion_times: 1-D tensor with the motion time each environment was reset to
    """
    num_hist = self._num_amp_obs_steps - 1

    # History slot k (0-based) lies (k + 1) * dt before the reset time.
    back_offsets = -self.dt * torch.arange(1, num_hist + 1, device=self.device)
    hist_times = (motion_times.unsqueeze(-1) + back_offsets).view(-1)
    hist_ids = torch.tile(motion_ids.unsqueeze(-1), [1, num_hist]).view(-1)

    (root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel, key_pos,
     _unused_a, _unused_b) = self._motion_lib.get_motion_state(hist_ids, hist_times)

    hist_obs = build_amp_observations(
        root_pos, root_rot, root_vel, root_ang_vel,
        dof_pos, dof_vel, key_pos,
        self._local_root_obs, self._root_height_obs,
        self._dof_obs_size, self._dof_offsets,
    )

    target_shape = self._hist_amp_obs_buf[env_ids].shape
    self._hist_amp_obs_buf[env_ids] = hist_obs.view(target_shape)
def _set_env_state(
self, env_ids, root_pos, root_rot, dof_pos, root_vel, root_ang_vel, dof_vel
):
self._humanoid_root_states[env_ids, 0:3] = root_pos
self._humanoid_root_states[env_ids, 3:7] = root_rot
self._humanoid_root_states[env_ids, 7:10] = root_vel
self._humanoid_root_states[env_ids, 10:13] = root_ang_vel
# self._dof_pos[env_ids] = dof_pos + self.init_dof_pose.to('cuda:0')
# self._dof_pos[env_ids] = torch.zeros_like(dof_pos).to('cuda:0')
self._dof_pos[env_ids] = dof_pos
# self._dof_pos[env_ids, 6] = -1
# self._dof_pos[env_ids, 28] = 1
self._dof_vel[env_ids] = dof_vel
def _update_hist_amp_obs(self, env_ids=None):
if env_ids is None:
for i in reversed(range(self._amp_obs_buf.shape[1] - 1)):
self._amp_obs_buf[:, i + 1] = self._amp_obs_buf[:, i]
else:
for i in reversed(range(self._amp_obs_buf.shape[1] - 1)):
self._amp_obs_buf[env_ids, i + 1] = self._amp_obs_buf[env_ids, i]
def _compute_amp_observations(self, env_ids=None):
    """Compute the current-step AMP observation from the simulation state.

    Rigid body 0 supplies the root position, rotation and velocities.

    :param env_ids: environments to update, or ``None`` for all of them
    """
    # ``slice(None)`` addresses every environment, exactly like ``[:]``.
    rows = slice(None) if env_ids is None else env_ids
    key_body_pos = self._rigid_body_pos[:, self._key_body_ids, :]

    self._curr_amp_obs_buf[rows] = build_amp_observations(
        self._rigid_body_pos[rows][:, 0, :],
        self._rigid_body_rot[rows][:, 0, :],
        self._rigid_body_vel[rows][:, 0, :],
        self._rigid_body_ang_vel[rows][:, 0, :],
        self._dof_pos[rows],
        self._dof_vel[rows],
        key_body_pos[rows],
        self._local_root_obs,
        self._root_height_obs,
        self._dof_obs_size,
        self._dof_offsets,
    )
# Build one AMP observation row per input row.
#
# The output concatenates, along the last axis: root height, root rotation
# (converted by ``quat_to_tan_norm``), root linear and angular velocity rotated
# by the heading quaternion, DoF observations from ``dof_to_obs``, the raw DoF
# velocities, and the key-body positions relative to the root, rotated by the
# same heading quaternion and flattened.
#
# ``local_root_obs``: when True the root rotation is pre-multiplied by the
# heading quaternion. ``root_height_obs``: when False the height is zeroed.
# NOTE(review): ``calc_heading_quat_inv`` presumably returns the inverse yaw
# rotation of the root - confirm in ``torch_jit_utils``.
@torch.jit.script
def build_amp_observations(
    root_pos,
    root_rot,
    root_vel,
    root_ang_vel,
    dof_pos,
    dof_vel,
    key_body_pos,
    local_root_obs,
    root_height_obs,
    dof_obs_size,
    dof_offsets,
):
    # type: (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, bool, bool, int, List[int]) -> Tensor
    heading_rot = torch_utils.calc_heading_quat_inv(root_rot)

    # Root height: third position component, kept as a column; zeroed when disabled.
    root_h = root_pos[:, 2:3]
    if root_height_obs:
        root_h_obs = root_h
    else:
        root_h_obs = torch.zeros_like(root_h)

    # Root rotation, optionally expressed relative to the heading.
    if local_root_obs:
        root_rot_src = quat_mul(heading_rot, root_rot)
    else:
        root_rot_src = root_rot
    root_rot_obs = torch_utils.quat_to_tan_norm(root_rot_src)

    local_root_vel = quat_rotate(heading_rot, root_vel)
    local_root_ang_vel = quat_rotate(heading_rot, root_ang_vel)

    # Key-body offsets from the root, rotated by the per-row heading quaternion.
    key_offsets = key_body_pos - root_pos.unsqueeze(-2)
    num_rows = key_offsets.shape[0]
    num_keys = key_offsets.shape[1]
    key_dim = key_offsets.shape[2]

    # Repeat the heading quaternion once per key body so both tensors flatten row-for-row.
    key_heading = heading_rot.unsqueeze(-2).repeat((1, num_keys, 1))
    flat_heading = key_heading.view(
        key_heading.shape[0] * key_heading.shape[1],
        key_heading.shape[2],
    )
    flat_offsets = key_offsets.view(num_rows * num_keys, key_dim)
    rotated_offsets = quat_rotate(flat_heading, flat_offsets)
    flat_local_key_pos = rotated_offsets.view(num_rows, num_keys * key_dim)

    dof_obs = dof_to_obs(dof_pos, dof_obs_size, dof_offsets)

    obs = torch.cat(
        (
            root_h_obs,
            root_rot_obs,
            local_root_vel,
            local_root_ang_vel,
            dof_obs,
            dof_vel,
            flat_local_key_pos,
        ),
        dim=-1,
    )
    return obs