RofuncRL AMP (Adversarial Motion Priors)#

Paper: “AMP: Adversarial Motion Priors for Stylized Physics-Based Character Control”. Peng et al. 2021. https://arxiv.org/abs/2104.02180

1.  Algorithm#

AMP framework

AMP is a mixline (mixed online and offline) method that combines imitation learning and reinforcement learning. This is achieved by combining two reward terms, \(r^G\) and \(r^S\):

\[ r\left(\mathrm{s}_t, \mathrm{a}_t, \mathrm{s}_{t+1}, \mathrm{g}\right)=w^G r^G\left(\mathrm{s}_t, \mathrm{a}_t, \mathrm{s}_{t+1}, \mathrm{g}\right)+w^S r^S\left(\mathrm{s}_t, \mathrm{s}_{t+1}\right) \]

where \(r^G\left(\mathrm{s}_t, \mathrm{a}_t, \mathrm{s}_{t+1}, \mathrm{g}\right)\) is the task-specific reinforcement learning reward, which defines high-level objectives \(\mathrm{g}\) that the character should satisfy (e.g., moving to a target location). \(r^S\left(\mathrm{s}_t, \mathrm{s}_{t+1}\right)\) is the task-agnostic imitation learning (style) reward, which specifies the low-level details of the behaviors the character should adopt when performing the task (e.g., walking vs. running to a target). \(w^G\) and \(w^S\) are the weights of the two reward terms.
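
In the RofuncRL implementation, the style reward \(r^S\) is computed from the discriminator output on the agent's transitions and then combined with the task reward according to the weights above: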

# task rewards and AMP observations collected during the rollout
rewards = self.memory.get_tensor_by_name("rewards")
amp_states = self.memory.get_tensor_by_name("amp_states")

with torch.no_grad():
    # discriminator prediction for the agent's transitions
    amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
    if self._least_square_discriminator:
        # least-squares style reward: max(1 - 0.25 * (1 - D)^2, 0.0001)
        style_reward = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
                                     torch.tensor(0.0001, device=self.device))
    else:
        # standard style reward: -log(max(1 - sigmoid(D), 0.0001))
        style_reward = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
                                                torch.tensor(0.0001, device=self.device)))
    style_reward *= self._discriminator_reward_scale

    # weighted combination of the task reward (r^G) and the style reward (r^S)
    combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_reward
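
Up to the clipping constant 0.0001 and the reward scale, the two branches above correspond to the style rewards

\[ r^S\left(\mathrm{s}_t, \mathrm{s}_{t+1}\right)=\max \left[0,\; 1-\frac{1}{4}\left(1-D\left(\mathrm{s}_t, \mathrm{s}_{t+1}\right)\right)^2\right] \quad \text{(least-squares discriminator)} \]

\[ r^S\left(\mathrm{s}_t, \mathrm{s}_{t+1}\right)=-\log \left(1-\sigma\left(D\left(\mathrm{s}_t, \mathrm{s}_{t+1}\right)\right)\right) \quad \text{(standard discriminator)} \]

where \(D\) is the raw discriminator output and \(\sigma\) is the sigmoid function. The first form is the style reward used in the original AMP paper.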

2.  Demos#

2.1.  Humanoid Run#

HumanoidAMPRun Inference

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_run --inference

2.2.  Humanoid BackFlip#

HumanoidAMPFlip Inference

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_backflip --inference

2.3.  Humanoid Dance#

HumanoidAMPDance Inference

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_dance --inference

2.4.  Humanoid Hop#

HumanoidAMPHop Inference

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidAMP_RofuncRL.py --task HumanoidAMP_hop --inference

3.  Baseline comparison#

We compare the performance of the AMP algorithm with different tricks and with an open-source baseline (SKRL). These experiments were conducted in the Humanoid environment. The results are shown below:

3.1.  Humanoid Run#

HumanoidAMPRun

  • Pink: SKRL AMP

  • Green: Rofunc AMP

3.2.  Humanoid BackFlip#

HumanoidAMPFlip

  • Pink: Rofunc AMP

4.  Tricks#

4.1.  Least-squares discriminator#

The standard GAN objective typically uses a sigmoid cross-entropy loss function.

\[ \underset{D}{\arg \min } -\mathbb{E}_{d^{\mathcal{M}}\left(\mathrm{s}, \mathrm{s}^{\prime}\right)}\left[\log \left(D\left(\mathrm{s}, \mathrm{s}^{\prime}\right)\right)\right]-\mathbb{E}_{d^\pi\left(\mathrm{s}, \mathrm{s}^{\prime}\right)}\left[\log \left(1-D\left(\mathrm{s}, \mathrm{s}^{\prime}\right)\right)\right] \]

However, this loss tends to lead to optimization challenges due to vanishing gradients as the sigmoid function saturates, which can hamper training of the policy. To address this issue, AMP adopts the loss function proposed for least-squares GAN (LSGAN) which is given by:

\[ \underset{D}{\arg \min } \ \mathbb{E}_{d^{\mathcal{M}}\left(\mathrm{s}, \mathrm{s}^{\prime}\right)}\left[\left(D\left(\mathrm{s}, \mathrm{s}^{\prime}\right)-1\right)^2\right]+\mathbb{E}_{d^\pi\left(\mathrm{s}, \mathrm{s}^{\prime}\right)}\left[\left(D\left(\mathrm{s}, \mathrm{s}^{\prime}\right)+1\right)^2\right] \]

The discriminator is trained by solving a least-squares regression problem to predict a score of 1 for samples from the dataset and −1 for samples recorded from the policy.
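
Both loss functions are implemented in RofuncRL and selected via the self._least_square_discriminator flag. In the snippet below, amp_motion_logits are the discriminator outputs for reference motion samples, while amp_cat_logits stacks the outputs for policy and replay samples: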

# discriminator prediction loss
if self._least_square_discriminator:
    # least-squares targets: -1 for policy/replay samples, +1 for reference motion samples
    discriminator_loss = 0.5 * (
            F.mse_loss(amp_cat_logits, -torch.ones_like(amp_cat_logits), reduction='mean')
            + F.mse_loss(amp_motion_logits, torch.ones_like(amp_motion_logits), reduction='mean'))
else:
    # binary cross-entropy targets: 0 for policy/replay samples, 1 for reference motion samples
    discriminator_loss = 0.5 * (nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
                                + nn.BCEWithLogitsLoss()(amp_motion_logits,
                                                         torch.ones_like(amp_motion_logits)))

We compare the performance of AMP with different discriminator loss functions on the HumanoidAMP_Dance task.

  • Red: Rofunc AMP with standard discriminator

  • Green: Rofunc AMP with least-squares discriminator

We compare the performance of AMP with different discriminator loss functions on the HumanoidAMP_Walk task.

  • Light blue: Rofunc AMP with standard discriminator

  • Pink: Rofunc AMP with least-squares discriminator

Attention

The least-squares discriminator is a stabilization trick used in the original AMP paper, but in our experiments it does not appear to be necessary.

4.2.  Gradient penalty#
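
To further stabilize adversarial training, the AMP paper adds a gradient penalty term to the discriminator objective, which penalizes nonzero gradients of the discriminator on samples from the reference motion dataset:

\[ \mathbb{E}_{d^{\mathcal{M}}\left(\mathrm{s}, \mathrm{s}^{\prime}\right)}\left[\left\|\left.\nabla_\phi D(\phi)\right|_{\phi=\left(\mathrm{s}, \mathrm{s}^{\prime}\right)}\right\|^2\right] \]

In RofuncRL, this penalty is computed as the mean squared gradient norm of the discriminator output with respect to the sampled reference motion states and is added to the discriminator loss, scaled by self._discriminator_gradient_penalty_scale: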

# gradient of the discriminator output w.r.t. the reference motion states
amp_motion_gradient = torch.autograd.grad(amp_motion_logits,
                                          sampled_amp_motion_states,
                                          grad_outputs=torch.ones_like(amp_motion_logits),
                                          create_graph=True,
                                          retain_graph=True,
                                          only_inputs=True)
# mean squared L2 norm of the gradient, weighted and added to the discriminator loss
gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty

5.  Network update function#
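
The complete update function of the RofuncRL AMP agent is shown below. It refreshes the reference motion dataset, computes the combined task and style rewards, estimates advantages with GAE, and then updates the policy, value, and discriminator networks from mini-batches of rollout, replay, and reference motion samples.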

    def update_net(self):
        """
        Update the network
        """
        # update dataset of reference motions
        self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))

        '''Compute combined rewards'''
        rewards = self.memory.get_tensor_by_name("rewards")
        amp_states = self.memory.get_tensor_by_name("amp_states")

        with torch.no_grad():
            amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
            if self._least_square_discriminator:
                style_rewards = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
                                              torch.tensor(0.0001, device=self.device))
            else:
                style_rewards = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
                                                         torch.tensor(0.0001, device=self.device)))
            style_rewards *= self._discriminator_reward_scale

        combined_rewards = self._task_reward_weight * rewards + self._style_reward_weight * style_rewards

        '''Compute Generalized Advantage Estimator (GAE)'''
        values = self.memory.get_tensor_by_name("values")
        next_values = self.memory.get_tensor_by_name("next_values")

        advantage = 0
        advantages = torch.zeros_like(combined_rewards)
        not_dones = self.memory.get_tensor_by_name("terminated").logical_not()
        memory_size = combined_rewards.shape[0]

        # advantages computation
        for i in reversed(range(memory_size)):
            advantage = combined_rewards[i] - values[i] + self._discount * (
                    next_values[i] + self._td_lambda * not_dones[i] * advantage)
            advantages[i] = advantage
        # returns computation
        values_target = advantages + values
        # advantage normalization
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True))
        self.memory.set_tensor_by_name("returns", self._value_preprocessor(values_target, train=True))
        self.memory.set_tensor_by_name("advantages", advantages)

        '''Sample mini-batches from memory and update the network'''
        sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batch_size)
        sampled_motion_batches = self.motion_dataset.sample(names=["states"],
                                                            batch_size=self.memory.memory_size * self.memory.num_envs,
                                                            mini_batches=self._mini_batch_size)

        if len(self.replay_buffer):
            sampled_replay_batches = self.replay_buffer.sample(names=["states"],
                                                               batch_size=self.memory.memory_size * self.memory.num_envs,
                                                               mini_batches=self._mini_batch_size)
        else:
            sampled_replay_batches = [[batches[self._tensors_names.index("amp_states")]] for batches in sampled_batches]

        cumulative_policy_loss = 0
        cumulative_entropy_loss = 0
        cumulative_value_loss = 0
        cumulative_discriminator_loss = 0

        # learning epochs
        for epoch in range(self._learning_epochs):
            # mini-batches loop
            for i, (sampled_states, sampled_actions, sampled_rewards, samples_next_states, samples_terminated,
                    sampled_log_prob, sampled_values, sampled_returns, sampled_advantages, sampled_amp_states,
                    _) in enumerate(sampled_batches):
                sampled_states = self._state_preprocessor(sampled_states, train=True)
                _, log_prob_now = self.policy(sampled_states, sampled_actions)

                # compute entropy loss
                entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy().mean()

                # compute policy loss
                ratio = torch.exp(log_prob_now - sampled_log_prob)
                surrogate = sampled_advantages * ratio
                surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip,
                                                                    1.0 + self._ratio_clip)

                policy_loss = -torch.min(surrogate, surrogate_clipped).mean()

                # compute value loss
                predicted_values = self.value(sampled_states)

                if self._clip_predicted_values:
                    predicted_values = sampled_values + torch.clip(predicted_values - sampled_values,
                                                                   min=-self._value_clip,
                                                                   max=self._value_clip)
                value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)

                # compute discriminator loss
                if self._discriminator_batch_size:
                    sampled_amp_states_batch = self._amp_state_preprocessor(
                        sampled_amp_states[0:self._discriminator_batch_size], train=True)
                    sampled_amp_replay_states = self._amp_state_preprocessor(
                        sampled_replay_batches[i][0][0:self._discriminator_batch_size], train=True)
                    sampled_amp_motion_states = self._amp_state_preprocessor(
                        sampled_motion_batches[i][0][0:self._discriminator_batch_size], train=True)
                else:
                    sampled_amp_states_batch = self._amp_state_preprocessor(sampled_amp_states, train=True)
                    sampled_amp_replay_states = self._amp_state_preprocessor(sampled_replay_batches[i][0], train=True)
                    sampled_amp_motion_states = self._amp_state_preprocessor(sampled_motion_batches[i][0], train=True)

                sampled_amp_motion_states.requires_grad_(True)
                amp_logits = self.discriminator(sampled_amp_states_batch)
                amp_replay_logits = self.discriminator(sampled_amp_replay_states)
                amp_motion_logits = self.discriminator(sampled_amp_motion_states)
                amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0)

                # discriminator prediction loss
                if self._least_square_discriminator:
                    discriminator_loss = 0.5 * (
                            F.mse_loss(amp_cat_logits, -torch.ones_like(amp_cat_logits), reduction='mean')
                            + F.mse_loss(amp_motion_logits, torch.ones_like(amp_motion_logits), reduction='mean'))
                else:
                    discriminator_loss = 0.5 * (nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits))
                                                + nn.BCEWithLogitsLoss()(amp_motion_logits,
                                                                         torch.ones_like(amp_motion_logits)))

                # discriminator logit regularization
                if self._discriminator_logit_regularization_scale:
                    logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight)
                    discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(
                        torch.square(logit_weights))

                # discriminator gradient penalty
                if self._discriminator_gradient_penalty_scale:
                    amp_motion_gradient = torch.autograd.grad(amp_motion_logits,
                                                              sampled_amp_motion_states,
                                                              grad_outputs=torch.ones_like(amp_motion_logits),
                                                              create_graph=True,
                                                              retain_graph=True,
                                                              only_inputs=True)
                    gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
                    discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty

                # discriminator weight decay
                if self._discriminator_weight_decay_scale:
                    weights = [torch.flatten(module.weight) for module in self.discriminator.modules()
                               if isinstance(module, torch.nn.Linear)]
                    weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
                    discriminator_loss += self._discriminator_weight_decay_scale * weight_decay

                discriminator_loss *= self._discriminator_loss_scale

                '''Update networks'''
                # Update policy network
                self.optimizer_policy.zero_grad()
                (policy_loss + entropy_loss).backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
                self.optimizer_policy.step()

                # Update value network
                self.optimizer_value.zero_grad()
                value_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
                self.optimizer_value.step()

                # Update discriminator network
                self.optimizer_disc.zero_grad()
                discriminator_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
                self.optimizer_disc.step()

                # update cumulative losses
                cumulative_policy_loss += policy_loss.item()
                cumulative_value_loss += value_loss.item()
                if self._entropy_loss_scale:
                    cumulative_entropy_loss += entropy_loss.item()
                cumulative_discriminator_loss += discriminator_loss.item()

            # update learning rate
            if self._lr_scheduler:
                self.scheduler_policy.step()
                self.scheduler_value.step()
                self.scheduler_disc.step()

        # update AMP replay buffer
        self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))

        # record data
        self.track_data("Info / Combined rewards", combined_rewards.mean().cpu())
        self.track_data("Info / Style rewards", style_rewards.mean().cpu())
        self.track_data("Info / Task rewards", rewards.mean().cpu())

        self.track_data("Loss / Policy loss", cumulative_policy_loss / (self._learning_epochs * self._mini_batch_size))
        self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batch_size))
        self.track_data("Loss / Discriminator loss",
                        cumulative_discriminator_loss / (self._learning_epochs * self._mini_batch_size))
        if self._entropy_loss_scale:
            self.track_data("Loss / Entropy loss",
                            cumulative_entropy_loss / (self._learning_epochs * self._mini_batch_size))
        if self._lr_scheduler:
            self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
            self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
            self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])