RofuncRL TD3 (Twin Delayed Deep Deterministic Policy Gradient)#

Paper: “Addressing Function Approximation Error in Actor-Critic Methods”. Fujimoto et al. 2018. https://arxiv.org/abs/1802.09477

1.  Algorithm#

    def update_net(self):
        """
        Update the network
        :return:
        """
        # sample a batch from memory
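        # TD3 is off-policy: a mini-batch of past transitions is drawn from the replay buffer
        # and reused for every gradient step below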
        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

        # learning epochs
        for gradient_step in range(self._gradient_steps):
            sampled_states = self._state_preprocessor(sampled_states, train=not gradient_step)
            sampled_next_states = self._state_preprocessor(sampled_next_states)

            with torch.no_grad():
                # target policy smoothing
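                # add clipped noise to the target action so that similar actions receive similar target values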
                next_actions, _ = self.target_actor(sampled_next_states)
                noises = torch.clamp(self._smooth_regularization_noise.sample(next_actions.shape),
                                     min=-self._smooth_regularization_clip,
                                     max=self._smooth_regularization_clip)
                next_actions.add_(noises)
                next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max)

                # compute target values
                target_q1_values = self.target_critic_1(sampled_next_states, next_actions)
                target_q2_values = self.target_critic_2(sampled_next_states, next_actions)
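                # clipped double-Q learning: take the smaller of the two target critics to reduce overestimation bias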
                target_q_values = torch.min(target_q1_values, target_q2_values)
                target_values = sampled_rewards + self._discount * sampled_dones.logical_not() * target_q_values

            # compute critic loss
            critic_1_values = self.critic_1(sampled_states, sampled_actions)
            critic_2_values = self.critic_2(sampled_states, sampled_actions)

            critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2

            # optimization step (critic)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()),
                                         self._grad_norm_clip)
            self.critic_optimizer.step()

            self._critic_update_times += 1
            if self._critic_update_times % self._policy_update_delay == 0:
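                # delayed policy updates: the actor and the target networks are updated only once
                # every self._policy_update_delay critic updates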
                # compute actor (policy) loss
                actions, log_prob = self.actor(sampled_states)
                critic_1_values = self.critic_1(sampled_states, actions)
                critic_2_values = self.critic_2(sampled_states, actions)

                policy_loss = - torch.min(critic_1_values, critic_2_values).mean()

                # optimization step (actor)
                self.actor_optimizer.zero_grad()
                policy_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.actor.parameters(), self._grad_norm_clip)
                self.actor_optimizer.step()

                # update target networks
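                # soft (Polyak) update: blend the online parameters into the target parameters at rate self._polyak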
                self.target_actor.update_parameters(self.actor, polyak=self._polyak)
                self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak)
                self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak)

                self.track_data("Loss / Actor loss", policy_loss.item())

            # update learning rate
            if self._lr_scheduler:
                self.actor_scheduler.step()
                self.critic_scheduler.step()

            # record data
            self.track_data("Loss / Critic loss", critic_loss.item())

            self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item())
            self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item())
            self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item())

            self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item())
            self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item())
            self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item())

            self.track_data("Target / Target (max)", torch.max(target_values).item())
            self.track_data("Target / Target (min)", torch.min(target_values).item())
            self.track_data("Target / Target (mean)", torch.mean(target_values).item())

            if self._lr_scheduler:
                self.track_data("Learning / Actor learning rate", self.actor_scheduler.get_last_lr()[0])
                self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0])

2.  Performance comparison#

We compare the performance of the TD3 algorithm with different tricks enabled and against an open-source baseline (SKRL). These experiments were conducted on the Pendulum environment; the results are shown below:

2.1.  Pendulum#

Figure: TD3 training curves on the Pendulum environment.

  • Gray: SKRL TD3

  • Green: Rofunc TD3