RofuncRL AMP (Adversarial Motion Priors)#
Paper: “AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control”. Peng et al. 2021. https://arxiv.org/abs/2104.02180
1. Algorithm#
AMP is a mixline (mixed online and offline) method that combines imitation learning and reinforcement learning. It does so by combining two reward terms \(r^G\) and \(r^S\):

\[r\left(s_t, a_t, s_{t+1}, g\right) = w^G r^G\left(s_t, a_t, s_{t+1}, g\right) + w^S r^S\left(s_t, s_{t+1}\right)\]

where \(r^G\left(s_t, a_t, s_{t+1}, g\right)\) is the task-specific reinforcement learning reward, which defines high-level objectives \(g\) that the character should satisfy (e.g., moving to a target location); \(r^S\left(s_t, s_{t+1}\right)\) is the task-agnostic imitation learning reward, which specifies the low-level details of the behaviors that the character should adopt when performing the task (e.g., walking vs. running to a target); and \(w^G\) and \(w^S\) are the weights of the two reward terms.
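In the RofuncRL AMP agent, the style reward \(r^S\) is computed from the discriminator's output on the transitions collected by the policy and is then weighted against the task reward: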
rewards = self.memory.get_tensor_by_name("rewards")
amp_states = self.memory.get_tensor_by_name("amp_states")

with torch.no_grad():
    amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
    if self._least_square_discriminator:
        style_reward = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
                                     torch.tensor(0.0001, device=self.device))
    else:
        style_reward = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
                                                torch.tensor(0.0001, device=self.device)))
    style_reward *= self._discriminator_reward_scale

combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_reward
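Writing \(d\) for the raw discriminator logit of a transition \((s_t, s_{t+1})\), the standard discriminator uses \(r^S = -\log\left(\max\left(1 - \sigma(d),\, 10^{-4}\right)\right)\) with \(\sigma\) the sigmoid, while the least-squares variant uses \(r^S = \max\left(1 - 0.25\,(1 - d)^2,\, 10^{-4}\right)\); both are scaled by the discriminator reward scale before being mixed with the task reward.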
2. Demos#
2.1. Humanoid Run#
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_run --inference
2.2. Humanoid BackFlip#
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_backflip --inference
2.3. Humanoid Dance#
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_dance --inference
2.4. Humanoid Hop#
python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_hop --inference
3. Baseline comparison#
We compare the performance of the AMP algorithm with different tricks, as well as against an open-source baseline (SKRL). These experiments were conducted on the Humanoid environment.
The results are shown below:
3.1. Humanoid Run#
Pink: SKRL AMP; Green: Rofunc AMP
3.2. Humanoid BackFlip#
Pink: Rofunc AMP
4. Tricks#
4.1. Least-squares discriminator#
The standard GAN objective typically uses a sigmoid cross-entropy loss function.
However, this loss tends to lead to optimization challenges due to vanishing gradients as the sigmoid function saturates, which can hamper training of the policy. To address this issue, AMP adopts the loss function proposed for the least-squares GAN (LSGAN), which is given by:

\[\arg\min_{D} \; \mathbb{E}_{d^{\mathcal{M}}(s, s')}\!\left[\left(D(s, s') - 1\right)^2\right] + \mathbb{E}_{d^{\pi}(s, s')}\!\left[\left(D(s, s') + 1\right)^2\right]\]

where \(d^{\mathcal{M}}(s, s')\) and \(d^{\pi}(s, s')\) denote the distributions of state transitions in the reference motion dataset and under the policy, respectively.
The discriminator is trained by solving a least-squares regression problem to predict a score of 1 for samples from the dataset and −1 for samples recorded from the policy.
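In the implementation, amp_cat_logits are the discriminator logits for transitions produced by the policy (the current rollout concatenated with samples from the AMP replay buffer), and amp_motion_logits are the logits for transitions drawn from the reference-motion dataset: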
# discriminator prediction loss
if self._least_square_discriminator:
    discriminator_loss = 0.5 * (
            F.mse_loss(amp_cat_logits, -torch.ones_like(amp_cat_logits), reduction='mean')
            + F.mse_loss(amp_motion_logits, torch.ones_like(amp_motion_logits), reduction='mean'))
else:
    discriminator_loss = 0.5 * (nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
                                + nn.BCEWithLogitsLoss()(amp_motion_logits,
                                                         torch.ones_like(amp_motion_logits)))
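With the least-squares loss, the regression targets are \(-1\) for policy and replay transitions and \(+1\) for reference transitions, matching the description above; the standard branch instead uses a sigmoid cross-entropy loss (BCEWithLogitsLoss) with targets \(0\) and \(1\).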
We compare the performance of AMP with the two discriminator loss functions on the HumanoidAMP_Dance task.

Red: Rofunc AMP with standard discriminator; Green: Rofunc AMP with least-squares discriminator
We compare the performance of AMP with the two discriminator loss functions on the HumanoidAMP_Walk task.

Light blue: Rofunc AMP with standard discriminator; Pink: Rofunc AMP with least-squares discriminator
Attention
The least-squares discriminator is a stabilization trick used in the original AMP paper, but it does not appear to be necessary in practice.
4.2. Gradient penalty#
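Following the AMP paper, the discriminator is further regularized with a zero-centered gradient penalty that penalizes nonzero gradients of the discriminator output with respect to the reference-motion inputs, \(\mathbb{E}_{d^{\mathcal{M}}(s, s')}\left[\left\|\nabla_{\phi} D(\phi)\right\|^2\right]\) evaluated at \(\phi = (s, s')\), weighted by the gradient-penalty coefficient. In the agent this is applied to the sampled reference-motion states, which therefore have requires_grad enabled: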
amp_motion_gradient = torch.autograd.grad(amp_motion_logits,
                                          sampled_amp_motion_states,
                                          grad_outputs=torch.ones_like(amp_motion_logits),
                                          create_graph=True,
                                          retain_graph=True,
                                          only_inputs=True)
gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty
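Unlike the WGAN-GP penalty, which is evaluated on interpolated samples and pushes the gradient norm towards one, this penalty is evaluated only at the reference-motion samples and drives the gradient norm towards zero. As a standalone illustration (a minimal sketch, not the RofuncRL API; the function name, the toy discriminator, and the input sizes are made up for this example), the same penalty can be computed for any discriminator as follows:

import torch
import torch.nn as nn

def amp_gradient_penalty(discriminator: nn.Module, motion_states: torch.Tensor) -> torch.Tensor:
    # hypothetical helper: zero-centered gradient penalty on reference-motion samples
    motion_states = motion_states.clone().requires_grad_(True)
    logits = discriminator(motion_states)
    # gradient of the discriminator output w.r.t. its inputs
    grad = torch.autograd.grad(logits, motion_states,
                               grad_outputs=torch.ones_like(logits),
                               create_graph=True)[0]
    # mean squared gradient norm over the batch
    return torch.sum(torch.square(grad), dim=-1).mean()

# usage with a toy discriminator on 8-dimensional AMP observations
disc = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 1))
penalty = amp_gradient_penalty(disc, torch.randn(32, 8))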
5. Network update function#
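For completeness, the full update function below combines all of the pieces above: it refreshes the reference-motion dataset, computes the combined task and style rewards, estimates advantages with GAE, and then loops over mini-batches to update the policy, value, and discriminator networks (including the logit regularization, gradient penalty, and weight-decay terms), before pushing the collected AMP states into the replay buffer and logging training statistics.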
def update_net(self):
    """
    Update the network
    """
    # update dataset of reference motions
    self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))

    '''Compute combined rewards'''
    rewards = self.memory.get_tensor_by_name("rewards")
    amp_states = self.memory.get_tensor_by_name("amp_states")

    with torch.no_grad():
        amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
        if self._least_square_discriminator:
            style_rewards = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
                                          torch.tensor(0.0001, device=self.device))
        else:
            style_rewards = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
                                                     torch.tensor(0.0001, device=self.device)))
        style_rewards *= self._discriminator_reward_scale

    combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_rewards

    '''Compute Generalized Advantage Estimator (GAE)'''
    values = self.memory.get_tensor_by_name("values")
    next_values = self.memory.get_tensor_by_name("next_values")

    advantage = 0
    advantages = torch.zeros_like(combined_rewards)
    not_dones = self.memory.get_tensor_by_name("terminated").logical_not()
    memory_size = combined_rewards.shape[0]

    # advantages computation
    for i in reversed(range(memory_size)):
        advantage = combined_rewards[i] - values[i] + self._discount * (
                next_values[i] + self._td_lambda * not_dones[i] * advantage)
        advantages[i] = advantage
    # returns computation
    values_target = advantages + values
    # advantage normalization
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True))
    self.memory.set_tensor_by_name("returns", self._value_preprocessor(values_target, train=True))
    self.memory.set_tensor_by_name("advantages", advantages)

    '''Sample mini-batches from memory and update the network'''
    sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batch_size)
    sampled_motion_batches = self.motion_dataset.sample(names=["states"],
                                                        batch_size=self.memory.memory_size * self.memory.num_envs,
                                                        mini_batches=self._mini_batch_size)
    if len(self.replay_buffer):
        sampled_replay_batches = self.replay_buffer.sample(names=["states"],
                                                           batch_size=self.memory.memory_size * self.memory.num_envs,
                                                           mini_batches=self._mini_batch_size)
    else:
        sampled_replay_batches = [[batches[self._tensors_names.index("amp_states")]] for batches in sampled_batches]

    cumulative_policy_loss = 0
    cumulative_entropy_loss = 0
    cumulative_value_loss = 0
    cumulative_discriminator_loss = 0

    # learning epochs
    for epoch in range(self._learning_epochs):
        # mini-batches loop
        for i, (sampled_states, sampled_actions, sampled_rewards, samples_next_states, samples_terminated,
                sampled_log_prob, sampled_values, sampled_returns, sampled_advantages, sampled_amp_states,
                _) in enumerate(sampled_batches):
            sampled_states = self._state_preprocessor(sampled_states, train=True)
            _, log_prob_now = self.policy(sampled_states, sampled_actions)

            # compute entropy loss
            entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy().mean()

            # compute policy loss
            ratio = torch.exp(log_prob_now - sampled_log_prob)
            surrogate = sampled_advantages * ratio
            surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip,
                                                                1.0 + self._ratio_clip)
            policy_loss = -torch.min(surrogate, surrogate_clipped).mean()

            # compute value loss
            predicted_values = self.value(sampled_states)
            if self._clip_predicted_values:
                predicted_values = sampled_values + torch.clip(predicted_values - sampled_values,
                                                               min=-self._value_clip,
                                                               max=self._value_clip)
            value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)

            # compute discriminator loss
            if self._discriminator_batch_size:
                sampled_amp_states_batch = self._amp_state_preprocessor(
                    sampled_amp_states[0:self._discriminator_batch_size], train=True)
                sampled_amp_replay_states = self._amp_state_preprocessor(
                    sampled_replay_batches[i][0][0:self._discriminator_batch_size], train=True)
                sampled_amp_motion_states = self._amp_state_preprocessor(
                    sampled_motion_batches[i][0][0:self._discriminator_batch_size], train=True)
            else:
                sampled_amp_states_batch = self._amp_state_preprocessor(sampled_amp_states, train=True)
                sampled_amp_replay_states = self._amp_state_preprocessor(sampled_replay_batches[i][0], train=True)
                sampled_amp_motion_states = self._amp_state_preprocessor(sampled_motion_batches[i][0], train=True)

            sampled_amp_motion_states.requires_grad_(True)
            amp_logits = self.discriminator(sampled_amp_states_batch)
            amp_replay_logits = self.discriminator(sampled_amp_replay_states)
            amp_motion_logits = self.discriminator(sampled_amp_motion_states)
            amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0)

            # discriminator prediction loss
            if self._least_square_discriminator:
                discriminator_loss = 0.5 * (
                        F.mse_loss(amp_cat_logits, -torch.ones_like(amp_cat_logits), reduction='mean')
                        + F.mse_loss(amp_motion_logits, torch.ones_like(amp_motion_logits), reduction='mean'))
            else:
                discriminator_loss = 0.5 * (nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
                                            + nn.BCEWithLogitsLoss()(amp_motion_logits,
                                                                     torch.ones_like(amp_motion_logits)))

            # discriminator logit regularization
            if self._discriminator_logit_regularization_scale:
                logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight)
                discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(
                    torch.square(logit_weights))

            # discriminator gradient penalty
            if self._discriminator_gradient_penalty_scale:
                amp_motion_gradient = torch.autograd.grad(amp_motion_logits,
                                                          sampled_amp_motion_states,
                                                          grad_outputs=torch.ones_like(amp_motion_logits),
                                                          create_graph=True,
                                                          retain_graph=True,
                                                          only_inputs=True)
                gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
                discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty

            # discriminator weight decay
            if self._discriminator_weight_decay_scale:
                weights = [torch.flatten(module.weight) for module in self.discriminator.modules()
                           if isinstance(module, torch.nn.Linear)]
                weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
                discriminator_loss += self._discriminator_weight_decay_scale * weight_decay

            discriminator_loss *= self._discriminator_loss_scale

            '''Update networks'''
            # Update policy network
            self.optimizer_policy.zero_grad()
            (policy_loss + entropy_loss).backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
            self.optimizer_policy.step()

            # Update value network
            self.optimizer_value.zero_grad()
            value_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
            self.optimizer_value.step()

            # Update discriminator network
            self.optimizer_disc.zero_grad()
            discriminator_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
            self.optimizer_disc.step()

            # update cumulative losses
            cumulative_policy_loss += policy_loss.item()
            cumulative_value_loss += value_loss.item()
            if self._entropy_loss_scale:
                cumulative_entropy_loss += entropy_loss.item()
            cumulative_discriminator_loss += discriminator_loss.item()

        # update learning rate
        if self._lr_scheduler:
            self.scheduler_policy.step()
            self.scheduler_value.step()
            self.scheduler_disc.step()

    # update AMP replay buffer
    self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))

    # record data
    self.track_data("Info / Combined rewards", combined_rewards.mean().cpu())
    self.track_data("Info / Style rewards", style_rewards.mean().cpu())
    self.track_data("Info / Task rewards", rewards.mean().cpu())
    self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batch_size))
    self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batch_size))
    self.track_data("Loss / Discriminator loss",
                    cumulative_discriminator_loss / (self._learning_epochs * self._mini_batch_size))
    if self._entropy_loss_scale:
        self.track_data("Loss / Entropy loss",
                        cumulative_entropy_loss / (self._learning_epochs * self._mini_batch_size))
    if self._lr_scheduler:
        self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
        self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
        self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])