RofuncRL ASE (Adversarial Skill Embeddings)#

Paper: “ASE: Large-Scale Reusable Adversarial Skill Embeddings for Physically Simulated Characters”. Peng et al. 2022. https://arxiv.org/abs/2205.01906

1.  Algorithm#

2.  Demos#

2.1.  Pre-trained latent space model#

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEGetupSwordShield --motion_file reallusion_sword_shield/dataset_reallusion_sword_shield.yaml --inference

2.2.  Pre-trained latent space model with perturbation#

You can test the robustness of the latent space model by switching to the HumanoidASEPerturbSwordShield task (boxes are thrown at the humanoid robot). It uses the same pre-trained latent space model as the previous demo, but the reset function resets the environment when the maximum episode length is reached, rather than immediately when the robot falls to the ground.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEPerturbSwordShield --motion_file reallusion_sword_shield/dataset_reallusion_sword_shield.yaml --inference
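The difference in reset behaviour can be sketched roughly as follows. This is a hypothetical illustration; the function and argument names below do not come from the Rofunc code base.

# Hypothetical sketch of the two reset conditions; names are illustrative only.
import torch

def compute_reset(progress_buf, fall_buf, max_episode_length, reset_on_fall):
    """Return a boolean reset mask for each parallel environment."""
    timeout = progress_buf >= max_episode_length - 1
    if reset_on_fall:
        # Getup-style behaviour: reset as soon as the humanoid falls or the episode times out
        return torch.logical_or(timeout, fall_buf)
    # Perturb-style behaviour: ignore falls, so recovery under perturbation can be observed
    return timeout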

Note: With the pre-trained latent space model, we can train high-level policies for complex tasks using simple task-specific reward functions.
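As a hypothetical illustration of such a simple reward (not the reward actually implemented in the Rofunc tasks), a location-style objective can be expressed as a function of the distance between the humanoid's root and the target:

# Hypothetical example of a simple task-specific reward (location-style task);
# not the actual reward used in HumanoidASELocationSwordShield.
import torch

def location_task_reward(root_pos, target_pos, scale=2.0):
    """Reward approaching 1 as the humanoid's root gets close to the target position."""
    dist = torch.norm(target_pos - root_pos, dim=-1)
    return torch.exp(-scale * dist ** 2)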

2.3.  High-level policy learning with pre-trained latent space model (Heading)#

HumanoidASEHeadingSwordShield task: the humanoid robot should face the blue line and walk towards the red line.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEHeadingSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference

2.4.  High-level policy learning with pre-trained latent space model (Location)#

HumanoidASELocationSwordShield task: the humanoid robot should walk to the red location.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASELocationSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference

2.5.  High-level policy learning with pre-trained latent space model (Reach)#

HumanoidASEReachSwordShield task: the humanoid robot should reach the red point with its sword.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEReachSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference

2.6.  High-level policy learning with pre-trained latent space model (Strike)#

HumanoidASEStrikeSwordShield task: the humanoid robot should strike the block with its sword.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidASEStrikeSwordShield --motion_file reallusion_sword_shield/RL_Avatar_Idle_Ready_Motion.npy --inference

2.7.  Motion visualization#

If you want to visualize a motion clip, you can use the HumanoidViewMotion task. For example, the following command visualizes the motion reallusion_sword_shield/RL_Avatar_Atk_2xCombo01_Motion.npy.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidViewMotion --motion_file reallusion_sword_shield/RL_Avatar_Atk_2xCombo01_Motion.npy --inference --headless=False

You can also use the absolute path of the motion file.

python examples/learning_rl/IsaacGym_RofuncRL/example_HumanoidASE_RofuncRL.py --task HumanoidViewMotion --motion_file /home/ubuntu/Github/Rofunc/examples/data/amp/reallusion_sword_shield/RL_Avatar_Atk_Jump_Motion.npy --inference --headless=False

3.  Baseline comparison#

4.  Tricks#

5.  Network update function#

    def update_net(self):
        """
        Update the network
        """
        # update dataset of reference motions
        self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))

        '''Compute combined rewards'''
        rewards = self.memory.get_tensor_by_name("rewards")
        amp_states = self.memory.get_tensor_by_name("amp_states")
        ase_latents = self.memory.get_tensor_by_name("ase_latents")

        with torch.no_grad():
            # Compute style reward from discriminator
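            # least-squares discriminator: r_style = max(1 - 0.25 * (1 - D(s))^2, 1e-4)
            # standard discriminator:      r_style = -log(max(1 - sigmoid(D(s)), 1e-4))
            # both are scaled by self._discriminator_reward_scale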
            amp_logits = self.discriminator(self._amp_state_preprocessor(amp_states))
            if self._least_square_discriminator:
                style_rewards = torch.maximum(torch.tensor(1 - 0.25 * torch.square(1 - amp_logits)),
                                              torch.tensor(0.0001, device=self.device))
            else:
                style_rewards = -torch.log(torch.maximum(torch.tensor(1 - 1 / (1 + torch.exp(-amp_logits))),
                                                         torch.tensor(0.0001, device=self.device)))
            style_rewards *= self._discriminator_reward_scale

            # Compute encoder reward
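            # r_enc = max(enc(s) . z, 0): alignment between the normalized encoder output
            # and the sampled ASE latent, scaled by self._enc_reward_scale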
            if self.encoder is self.discriminator:
                enc_output = self.encoder.get_enc(self._amp_state_preprocessor(amp_states))
            else:
                enc_output = self.encoder(self._amp_state_preprocessor(amp_states))
            enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
            enc_reward = torch.clamp_min(torch.sum(enc_output * ase_latents, dim=-1, keepdim=True), 0.0)
            enc_reward *= self._enc_reward_scale

        combined_rewards = (self._task_reward_weight * rewards
                            + self._style_reward_weight * style_rewards
                            + self._enc_reward_weight * enc_reward)

        '''Compute Generalized Advantage Estimator (GAE)'''
        values = self.memory.get_tensor_by_name("values")
        next_values = self.memory.get_tensor_by_name("next_values")

        advantage = 0
        advantages = torch.zeros_like(combined_rewards)
        not_dones = self.memory.get_tensor_by_name("terminated").logical_not()
        memory_size = combined_rewards.shape[0]

        # advantages computation
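        # A_t = r_t - V_t + discount * (V'_t + td_lambda * (1 - d_t) * A_{t+1}), iterated backwards in time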
        for i in reversed(range(memory_size)):
            advantage = combined_rewards[i] - values[i] + self._discount * (
                    next_values[i] + self._td_lambda * not_dones[i] * advantage)
            advantages[i] = advantage
        # returns computation
        values_target = advantages + values
        # advantage normalization
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        self.memory.set_tensor_by_name("values", self._value_preprocessor(values, train=True))
        self.memory.set_tensor_by_name("returns", self._value_preprocessor(values_target, train=True))
        self.memory.set_tensor_by_name("advantages", advantages)

        '''Sample mini-batches from memory and update the network'''
        sampled_batches = self.memory.sample_all(names=self._tensors_names, mini_batches=self._mini_batch_size)
        sampled_motion_batches = self.motion_dataset.sample(names=["states"],
                                                            batch_size=self.memory.memory_size * self.memory.num_envs,
                                                            mini_batches=self._mini_batch_size)

        if len(self.replay_buffer):
            sampled_replay_batches = self.replay_buffer.sample(names=["states"],
                                                               batch_size=self.memory.memory_size * self.memory.num_envs,
                                                               mini_batches=self._mini_batch_size)
        else:
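            # the AMP replay buffer is still empty, so fall back to the AMP states collected in this rollout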
            sampled_replay_batches = [[batches[self._tensors_names.index("amp_states")]] for batches in sampled_batches]

        cumulative_policy_loss = 0
        cumulative_entropy_loss = 0
        cumulative_value_loss = 0
        cumulative_discriminator_loss = 0
        cumulative_encoder_loss = 0

        # learning epochs
        for epoch in range(self._learning_epochs):
            # mini-batches loop
            for i, (sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_terminated,
                    sampled_log_prob, sampled_values, sampled_returns, sampled_advantages, sampled_amp_states,
                    _, sampled_ase_latents) in enumerate(sampled_batches):
                sampled_states = self._state_preprocessor(torch.hstack((sampled_states, sampled_ase_latents)),
                                                          train=True)
                # sampled_states = self._state_preprocessor(sampled_states, train=True)
                _, log_prob_now = self.policy(sampled_states, sampled_actions)

                # compute entropy loss
                entropy_loss = -self._entropy_loss_scale * self.policy.get_entropy().mean()

                # compute policy loss
                ratio = torch.exp(log_prob_now - sampled_log_prob)
                surrogate = sampled_advantages * ratio
                surrogate_clipped = sampled_advantages * torch.clip(ratio, 1.0 - self._ratio_clip,
                                                                    1.0 + self._ratio_clip)

                policy_loss = -torch.min(surrogate, surrogate_clipped).mean()

                # compute value loss
                predicted_values = self.value(sampled_states)

                if self._clip_predicted_values:
                    predicted_values = sampled_values + torch.clip(predicted_values - sampled_values,
                                                                   min=-self._value_clip,
                                                                   max=self._value_clip)
                value_loss = self._value_loss_scale * F.mse_loss(sampled_returns, predicted_values)

                # compute discriminator loss
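                # agent and replay AMP states are labelled as "fake" samples,
                # reference-motion states as "real" samples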
                if self._discriminator_batch_size:
                    sampled_amp_states_batch = self._amp_state_preprocessor(
                        sampled_amp_states[0:self._discriminator_batch_size], train=True)
                    sampled_amp_replay_states = self._amp_state_preprocessor(
                        sampled_replay_batches[i][0][0:self._discriminator_batch_size], train=True)
                    sampled_amp_motion_states = self._amp_state_preprocessor(
                        sampled_motion_batches[i][0][0:self._discriminator_batch_size], train=True)
                else:
                    sampled_amp_states_batch = self._amp_state_preprocessor(sampled_amp_states, train=True)
                    sampled_amp_replay_states = self._amp_state_preprocessor(sampled_replay_batches[i][0], train=True)
                    sampled_amp_motion_states = self._amp_state_preprocessor(sampled_motion_batches[i][0], train=True)

                sampled_amp_motion_states.requires_grad_(True)
                amp_logits = self.discriminator(sampled_amp_states_batch)
                amp_replay_logits = self.discriminator(sampled_amp_replay_states)
                amp_motion_logits = self.discriminator(sampled_amp_motion_states)
                amp_cat_logits = torch.cat([amp_logits, amp_replay_logits], dim=0)

                # discriminator prediction loss
                if self._least_square_discriminator:
                    discriminator_loss = 0.5 * (
                            F.mse_loss(amp_cat_logits, -torch.ones_like(amp_cat_logits), reduction='mean') \
                            + F.mse_loss(amp_motion_logits, torch.ones_like(amp_motion_logits), reduction='mean'))
                else:
                    discriminator_loss = 0.5 * (nn.BCEWithLogitsLoss()(amp_cat_logits, torch.zeros_like(amp_cat_logits)) \
                                                + nn.BCEWithLogitsLoss()(amp_motion_logits,
                                                                         torch.ones_like(amp_motion_logits)))

                # discriminator logit regularization
                if self._discriminator_logit_regularization_scale:
                    logit_weights = torch.flatten(list(self.discriminator.modules())[-1].weight)
                    discriminator_loss += self._discriminator_logit_regularization_scale * torch.sum(
                        torch.square(logit_weights))

                # discriminator gradient penalty
                if self._discriminator_gradient_penalty_scale:
                    amp_motion_gradient = torch.autograd.grad(amp_motion_logits,
                                                              sampled_amp_motion_states,
                                                              grad_outputs=torch.ones_like(amp_motion_logits),
                                                              create_graph=True,
                                                              retain_graph=True,
                                                              only_inputs=True)
                    gradient_penalty = torch.sum(torch.square(amp_motion_gradient[0]), dim=-1).mean()
                    discriminator_loss += self._discriminator_gradient_penalty_scale * gradient_penalty

                # discriminator weight decay
                if self._discriminator_weight_decay_scale:
                    weights = [torch.flatten(module.weight) for module in self.discriminator.modules() \
                               if isinstance(module, torch.nn.Linear)]
                    weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
                    discriminator_loss += self._discriminator_weight_decay_scale * weight_decay

                discriminator_loss *= self._discriminator_loss_scale

                # encoder loss
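                # maximize alignment between the encoder output and the ASE latent that
                # conditioned the policy: enc_loss = -mean(enc(s) . z)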
                if self.encoder is self.discriminator:
                    enc_output = self.encoder.get_enc(self._amp_state_preprocessor(sampled_amp_states))
                else:
                    enc_output = self.encoder(self._amp_state_preprocessor(sampled_amp_states))
                enc_output = torch.nn.functional.normalize(enc_output, dim=-1)
                enc_err = -torch.sum(enc_output * sampled_ase_latents, dim=-1, keepdim=True)
                enc_loss = torch.mean(enc_err)

                # encoder gradient penalty
                if self._enc_gradient_penalty_scale:
                    enc_obs_grad = torch.autograd.grad(enc_err,
                                                       sampled_ase_latents,
                                                       grad_outputs=torch.ones_like(enc_err),
                                                       create_graph=True,
                                                       retain_graph=True,
                                                       only_inputs=True)
                    gradient_penalty = torch.sum(torch.square(enc_obs_grad[0]), dim=-1).mean()
                    enc_loss += self._enc_gradient_penalty_scale * gradient_penalty

                # encoder weight decay
                if self._enc_weight_decay_scale:
                    weights = [torch.flatten(module.weight) for module in self.encoder.modules() \
                               if isinstance(module, torch.nn.Linear)]
                    weight_decay = torch.sum(torch.square(torch.cat(weights, dim=-1)))
                    enc_loss += self._enc_weight_decay_scale * weight_decay

                # if self._enable_amp_diversity_bonus():
                #     diversity_loss = self._diversity_loss(batch_dict['obs'], mu, batch_dict['ase_latents'])
                #     diversity_loss = torch.sum(rand_action_mask * diversity_loss) / rand_action_sum
                #     loss += self._amp_diversity_bonus * diversity_loss
                #     a_info['amp_diversity_loss'] = diversity_loss

                '''Update networks'''
                # Update policy network
                self.optimizer_policy.zero_grad()
                (policy_loss + entropy_loss).backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
                self.optimizer_policy.step()

                # Update value network
                self.optimizer_value.zero_grad()
                value_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
                self.optimizer_value.step()

                # Update discriminator network
                self.optimizer_disc.zero_grad()
                if self.encoder is self.discriminator:
                    (discriminator_loss + enc_loss).backward()
                else:
                    discriminator_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.discriminator.parameters(), self._grad_norm_clip)
                self.optimizer_disc.step()

                # Update encoder network
                if self.encoder is not self.discriminator:
                    self.optimizer_enc.zero_grad()
                    enc_loss.backward()
                    if self._grad_norm_clip > 0:
                        nn.utils.clip_grad_norm_(self.encoder.parameters(), self._grad_norm_clip)
                    self.optimizer_enc.step()

                # update cumulative losses
                cumulative_policy_loss += policy_loss.item()
                cumulative_value_loss += value_loss.item()
                if self._entropy_loss_scale:
                    cumulative_entropy_loss += entropy_loss.item()
                cumulative_discriminator_loss += discriminator_loss.item()
                cumulative_encoder_loss += enc_loss.item()

            # update learning rate
            if self._lr_scheduler:
                self.scheduler_policy.step()
                self.scheduler_value.step()
                self.scheduler_disc.step()
                if self.encoder is not self.discriminator:
                    self.scheduler_enc.step()

        # update AMP replay buffer
        self.replay_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))

        # record data
        self.track_data("Info / Combined rewards", combined_rewards.mean().cpu())
        self.track_data("Info / Style rewards", style_rewards.mean().cpu())
        self.track_data("Info / Encoder rewards", enc_reward.mean().cpu())
        self.track_data("Info / Task rewards", rewards.mean().cpu())

        self.track_data("Loss / Policy loss",
                        cumulative_policy_loss / (self._learning_epochs * self._mini_batch_size))
        self.track_data("Loss / Value loss", cumulative_value_loss / (self._learning_epochs * self._mini_batch_size))
        self.track_data("Loss / Discriminator loss",
                        cumulative_discriminator_loss / (self._learning_epochs * self._mini_batch_size))
        self.track_data("Loss / Encoder loss",
                        cumulative_encoder_loss / (self._learning_epochs * self._mini_batch_size))
        if self._entropy_loss_scale:
            self.track_data("Loss / Entropy loss",
                            cumulative_entropy_loss / (self._learning_epochs * self._mini_batch_size))
        if self._lr_scheduler:
            self.track_data("Learning / Learning rate (policy)", self.scheduler_policy.get_last_lr()[0])
        self.track_data("Learning / Learning rate (value)", self.scheduler_value.get_last_lr()[0])
        self.track_data("Learning / Learning rate (discriminator)", self.scheduler_disc.get_last_lr()[0])
        if self.encoder is not self.discriminator:
            self.track_data("Learning / Learning rate (encoder)", self.scheduler_enc.get_last_lr()[0])