RofuncRL ASE (Adversarial Skill Embeddings)#
Paper: “ASE: Large-Scale Reusable Adversarial Skill Embeddings for Physically Simulated Characters”. Peng et al. 2022. https://arxiv.org/abs/2205.01906
1. Algorithm#
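ASE pre-trains a low-level, latent-conditioned control policy π(a|s, z) on a large motion dataset using an adversarial imitation objective (an AMP-style discriminator), combined with a skill-discovery objective in which an encoder q(z|s) must recover the latent skill z from the motion the policy produces. The learned latent space can then be reused: lightweight high-level policies output latents z to drive the fixed low-level controller on downstream tasks. See the paper above for the full derivation.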
2. Demos#
2.1. Pre-trained latent space model#
Evaluate the pre-trained latent space model (the `HumanoidASEGetupSwordShield` task) in inference mode:

```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEGetupSwordShield --motion_file reallusion_sword_shield/dataset_reallusion_sword_shield.yaml --inference
```
2.2. Pre-trained latent space model with perturbation#
You can test the robustness of the latent space model by switching to the `HumanoidASEPerturbSwordShield` task, which throws boxes at the humanoid robot. It uses the same pre-trained latent space model as the previous demo, but changes the reset condition: episodes reset only when the maximum episode length is reached, rather than immediately when the robot falls to the ground.
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEPerturbSwordShield --motion_file reallusion_sword_shield/dataset_reallusion_sword_shield.yaml --inference
```
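The difference in reset behavior can be illustrated with a minimal sketch. This is an assumption-level illustration, not the Rofunc implementation: `compute_reset`, `progress_buf`, `has_fallen`, and `max_episode_length` are hypothetical names following common Isaac Gym task conventions.

```python
import torch

def compute_reset(progress_buf: torch.Tensor,
                  has_fallen: torch.Tensor,
                  max_episode_length: int,
                  enable_early_termination: bool) -> torch.Tensor:
    """Hypothetical per-environment reset flag."""
    # always reset once the episode reaches its maximum length
    reset = progress_buf >= max_episode_length - 1
    if enable_early_termination:
        # the getup demo also terminates early when the robot falls;
        # the perturbation demo disables this so the policy must recover
        reset = reset | has_fallen
    return reset
```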
Note: By using the pre-trained latent space model, we can train high-level policies for complex downstream tasks with simple task-specific reward functions.
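As an illustration of how simple such a task reward can be, here is a minimal sketch of a location-style reward. This is hypothetical: the function name, the (x, y) slicing, and the coefficient 0.5 are illustrative assumptions, not the exact reward definitions used in the tasks below.

```python
import torch

def location_task_reward(root_pos: torch.Tensor, target_pos: torch.Tensor) -> torch.Tensor:
    # reward approaches 1 as the character's root reaches the target (x, y) location
    dist_sq = torch.sum((target_pos[..., :2] - root_pos[..., :2]) ** 2, dim=-1)
    return torch.exp(-0.5 * dist_sq)
```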
2.3. High-level policy learning with pre-trained latent space model (Heading)#
In the `HumanoidASEHeadingSwordShield` task, the humanoid robot should face the blue line and walk towards the red line:
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEHeadingSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference
```
2.4. High-level policy learning with pre-trained latent space model (Location)#
In the `HumanoidASELocationSwordShield` task, the humanoid robot should walk to the red location:
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASELocationSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference
```
2.5. High-level policy learning with pre-trained latent space model (Reach)#
In the `HumanoidASEReachSwordShield` task, the humanoid robot should reach the red point with its sword:
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEReachSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference
```
2.6. High-level policy learning with pre-trained latent space model (Strike)#
In the `HumanoidASEStrikeSwordShield` task, the humanoid robot should strike the block with its sword:
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEStrikeSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference
```
2.7. Motion visualization#
If you want to visualize a motion clip, you can use the `HumanoidViewMotion` task. For example, the following command visualizes the motion `reallusion_sword_shield/RL_Avatar_Atk_2xCombo01_Motion.npy`:
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidViewMotion --motion_file reallusion_sword_shield/RL_Avatar_Atk_2xCombo01_Motion.npy --inference --headless=False
```
You can also pass the absolute path of the motion file:
```shell
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidViewMotion --motion_file /home/ubuntu/Github/Rofunc/examples/data/amp/reallusion_sword_shield/RL_Avatar_Atk_Jump_Motion.npy --inference --headless=False
```
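To inspect a motion file outside the simulator, you can load it with poselib, which the original ASE codebase uses to store these `.npy` clips. This is a sketch under that assumption; the exact module path may differ in your installation:

```python
# assumes the .npy motion files are poselib SkeletonMotion archives,
# as in the original ASE codebase
from poselib.skeleton.skeleton3d import SkeletonMotion

motion = SkeletonMotion.from_file(
    "examples/data/amp/reallusion_sword_shield/RL_Avatar_Atk_2xCombo01_Motion.npy")
print(motion.fps, motion.root_translation.shape)
```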
3. Baseline comparison#
4. Tricks#
5. Network update function#
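Before the PPO-style update, the agent mixes three reward signals, exactly as computed in the code below (σ is the sigmoid, D the discriminator, q the normalized encoder output, and z_t the ASE latent; the per-term reward scales are omitted for brevity):

```latex
r_t = w_{\text{task}}\, r_t^{\text{task}}
    + w_{\text{style}}\, r_t^{\text{style}}
    + w_{\text{enc}}\, r_t^{\text{enc}},
\qquad
r_t^{\text{style}} = -\log\!\big(\max\big(1 - \sigma(D(s_t)),\, 10^{-4}\big)\big),
\qquad
r_t^{\text{enc}} = \max\!\big(0,\; q(s_t)^{\top} z_t\big)
```

The style reward shown is for the standard (cross-entropy) discriminator; the least-squares variant in the code uses $1 - 0.25\,(1 - D(s_t))^2$ instead.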
```python
def update_net(self):
    """
    Update the network
    """
    # update dataset of reference motions
    self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))

    '''Compute combined rewards'''
    rewards = self.memory.get_tensor_by_name("rewards")
    amp_states = self.memory.get_tensor_by_name("amp_states")
    ase_latents = self.memory.get_tensor_by_name("ase_latents")

    with torch.no_grad():
        # Compute style reward from discriminator
        amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
        if self._least_square_discriminator:
            style_rewards = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
                                          torch.tensor(0.0001, device=self.device))
        else:
            style_rewards = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
                                                     torch.tensor(0.0001, device=self.device)))
        style_rewards *= self._discriminator_reward_scale

        # Compute encoder reward
        if self.encoder is self.discriminator:
            enc_output = self.encoder.get_enc(self._amp_state_preprocessor(amp_states))
        else:
            enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
        enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
        enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
        enc_reward *= self._enc_reward_scale

    combined_rewards = (self._task_reward_weight * rewards
                        + self._style_reward_weight * style_rewards
                        + self._enc_reward_weight * enc_reward)

    '''Compute Generalized Advantage Estimator (GAE)'''
    values = self.memory.get_tensor_by_name("values")
    next_values = self.memory.get_tensor_by_name("next_values")

    advantage = 0
    advantages = torch.zeros_like(combined_rewards)
    not_dones = self.memory.get_tensor_by_name("terminated").logical_not()
    memory_size = combined_rewards.shape[0]

    # advantages computation
    for i in reversed(range(memory_size)):
        advantage = combined_rewards[i] - values[i] + self._discount * (
                next_values[i] + self._td_lambda * not_dones[i] * advantage)
        advantages[i] = advantage
    # returns computation
    values_target = advantages + values
    # advantage normalization
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True))
    self.memory.set_tensor_by_name("returns", self._value_preprocessor(values_target, train=True))
    self.memory.set_tensor_by_name("advantages", advantages)

    '''Sample mini-batches from memory and update the network'''
    sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batch_size)
    sampled_motion_batches = self.motion_dataset.sample(names=["states"],
                                                        batch_size=self.memory.memory_size * self.memory.num_envs,
                                                        mini_batches=self._mini_batch_size)
    if len(self.replay_buffer):
        sampled_replay_batches = self.replay_buffer.sample(names=["states"],
                                                           batch_size=self.memory.memory_size * self.memory.num_envs,
                                                           mini_batches=self._mini_batch_size)
    else:
        sampled_replay_batches = [[batches[self._tensors_names.index("amp_states")]]
                                  for batches in sampled_batches]

    cumulative_policy_loss = 0
    cumulative_entropy_loss = 0
    cumulative_value_loss = 0
    cumulative_discriminator_loss = 0
    cumulative_encoder_loss = 0

    # learning epochs
    for epoch in range(self._learning_epochs):
        # mini-batches loop
        for i, (sampled_states, sampled_actions, sampled_rewards, samples_next_states, samples_terminated,
                sampled_log_prob, sampled_values, sampled_returns, sampled_advantages, sampled_amp_states,
                _, sampled_ase_latents) in enumerate(sampled_batches):
            sampled_states = self._state_preprocessor(torch.hstack((sampled_states, sampled_ase_latents)),
                                                      train=True)
            # sampled_states = self._state_preprocessor(sampled_states, train=True)
            _, log_prob_now = self.policy(sampled_states, sampled_actions)

            # compute entropy loss
            entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy().mean()

            # compute policy loss
            ratio = torch.exp(log_prob_now - sampled_log_prob)
            surrogate = sampled_advantages * ratio
            surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip,
                                                                1.0 + self._ratio_clip)
            policy_loss = -torch.min(surrogate, surrogate_clipped).mean()

            # compute value loss
            predicted_values = self.value(sampled_states)
            if self._clip_predicted_values:
                predicted_values = sampled_values + torch.clip(predicted_values - sampled_values,
                                                               min=-self._value_clip,
                                                               max=self._value_clip)
            value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)

            # compute discriminator loss
            if self._discriminator_batch_size:
                sampled_amp_states_batch = self._amp_state_preprocessor(
                    sampled_amp_states[0:self._discriminator_batch_size], train=True)
                sampled_amp_replay_states = self._amp_state_preprocessor(
                    sampled_replay_batches[i][0][0:self._discriminator_batch_size], train=True)
                sampled_amp_motion_states = self._amp_state_preprocessor(
                    sampled_motion_batches[i][0][0:self._discriminator_batch_size], train=True)
            else:
                sampled_amp_states_batch = self._amp_state_preprocessor(sampled_amp_states, train=True)
                sampled_amp_replay_states = self._amp_state_preprocessor(sampled_replay_batches[i][0], train=True)
                sampled_amp_motion_states = self._amp_state_preprocessor(sampled_motion_batches[i][0], train=True)

            sampled_amp_motion_states.requires_grad_(True)
            amp_logits = self.discriminator(sampled_amp_states_batch)
            amp_replay_logits = self.discriminator(sampled_amp_replay_states)
            amp_motion_logits = self.discriminator(sampled_amp_motion_states)
            amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0)

            # discriminator prediction loss
            if self._least_square_discriminator:
                discriminator_loss = 0.5 * (
                        F.mse_loss(amp_cat_logits, -torch.ones_like(amp_cat_logits), reduction='mean')
                        + F.mse_loss(amp_motion_logits, torch.ones_like(amp_motion_logits), reduction='mean'))
            else:
                discriminator_loss = 0.5 * (
                        nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
                        + nn.BCEWithLogitsLoss()(amp_motion_logits, torch.ones_like(amp_motion_logits)))

            # discriminator logit regularization
            if self._discriminator_logit_regularization_scale:
                logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight)
                discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(
                    torch.square(logit_weights))

            # discriminator gradient penalty
            if self._discriminator_gradient_penalty_scale:
                amp_motion_gradient = torch.autograd.grad(amp_motion_logits,
                                                          sampled_amp_motion_states,
                                                          grad_outputs=torch.ones_like(amp_motion_logits),
                                                          create_graph=True,
                                                          retain_graph=True,
                                                          only_inputs=True)
                gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
                discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty

            # discriminator weight decay
            if self._discriminator_weight_decay_scale:
                weights = [torch.flatten(module.weight) for module in self.discriminator.modules()
                           if isinstance(module, torch.nn.Linear)]
                weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
                discriminator_loss += self._discriminator_weight_decay_scale * weight_decay

            discriminator_loss *= self._discriminator_loss_scale

            # encoder loss
            if self.encoder is self.discriminator:
                enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
            else:
                enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
            enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
            enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
            enc_loss = torch.mean(enc_err)

            # encoder gradient penalty
            if self._enc_gradient_penalty_scale:
                enc_obs_grad = torch.autograd.grad(enc_err,
                                                   sampled_ase_latents,
                                                   grad_outputs=torch.ones_like(enc_err),
                                                   create_graph=True,
                                                   retain_graph=True,
                                                   only_inputs=True)
                gradient_penalty = torch.sum(torch.square(enc_obs_grad[0]), dim=-1).mean()
                enc_loss += self._enc_gradient_penalty_scale * gradient_penalty

            # encoder weight decay
            if self._enc_weight_decay_scale:
                weights = [torch.flatten(module.weight) for module in self.encoder.modules()
                           if isinstance(module, torch.nn.Linear)]
                weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
                enc_loss += self._enc_weight_decay_scale * weight_decay

            # if self._enable_amp_diversity_bonus():
            #     diversity_loss = self._diversity_loss(batch_dict['obs'], mu, batch_dict['ase_latents'])
            #     diversity_loss = torch.sum(rand_action_mask * diversity_loss) / rand_action_sum
            #     loss += self._amp_diversity_bonus * diversity_loss
            #     a_info['amp_diversity_loss'] = diversity_loss

            '''Update networks'''
            # Update policy network
            self.optimizer_policy.zero_grad()
            (policy_loss + entropy_loss).backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
            self.optimizer_policy.step()

            # Update value network
            self.optimizer_value.zero_grad()
            value_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
            self.optimizer_value.step()

            # Update discriminator network
            self.optimizer_disc.zero_grad()
            if self.encoder is self.discriminator:
                (discriminator_loss + enc_loss).backward()
            else:
                discriminator_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
            self.optimizer_disc.step()

            # Update encoder network
            if self.encoder is not self.discriminator:
                self.optimizer_enc.zero_grad()
                enc_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
                self.optimizer_enc.step()

            # update cumulative losses
            cumulative_policy_loss += policy_loss.item()
            cumulative_value_loss += value_loss.item()
            if self._entropy_loss_scale:
                cumulative_entropy_loss += entropy_loss.item()
            cumulative_discriminator_loss += discriminator_loss.item()
            cumulative_encoder_loss += enc_loss.item()

        # update learning rate
        if self._lr_scheduler:
            self.scheduler_policy.step()
            self.scheduler_value.step()
            self.scheduler_disc.step()
            if self.encoder is not self.discriminator:
                self.scheduler_enc.step()

    # update AMP replay buffer
    self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))

    # record data
    self.track_data("Info / Combined rewards", combined_rewards.mean().cpu())
    self.track_data("Info / Style rewards", style_rewards.mean().cpu())
    self.track_data("Info / Encoder rewards", enc_reward.mean().cpu())
    self.track_data("Info / Task rewards", rewards.mean().cpu())

    self.track_data("Loss / Policy loss",
                    cumulative_policy_loss / (self._learning_epochs * self._mini_batch_size))
    self.track_data("Loss / Value loss",
                    cumulative_value_loss / (self._learning_epochs * self._mini_batch_size))
    self.track_data("Loss / Discriminator loss",
                    cumulative_discriminator_loss / (self._learning_epochs * self._mini_batch_size))
    self.track_data("Loss / Encoder loss",
                    cumulative_encoder_loss / (self._learning_epochs * self._mini_batch_size))
    if self._entropy_loss_scale:
        self.track_data("Loss / Entropy loss",
                        cumulative_entropy_loss / (self._learning_epochs * self._mini_batch_size))

    if self._lr_scheduler:
        self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
        self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
        self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])
        if self.encoder is not self.discriminator:
            self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])
```