RofuncRL SAC (Soft Actor-Critic)#

Paper: “Soft Actor-Critic Algorithms and Applications”. Haarnoja et al. 2018. https://arxiv.org/abs/1812.05905

1.  Algorithm#

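The update implemented below is the standard SAC update: twin critics regress onto a shared soft Bellman target, the actor maximizes the entropy-regularized Q-value, and (optionally) the entropy coefficient $\alpha$ is tuned automatically. With discount $\gamma$, done flag $d$ and target critics $\bar{Q}_1, \bar{Q}_2$:

$$y = r + \gamma\,(1-d)\Big(\min_{i=1,2}\bar{Q}_i(s',a') - \alpha\log\pi(a'\mid s')\Big),\qquad a'\sim\pi(\cdot\mid s')$$

$$L_\pi = \mathbb{E}\big[\alpha\log\pi(a\mid s) - \min_{i=1,2}Q_i(s,a)\big],\qquad L_\alpha = -\,\mathbb{E}\big[\log\alpha\,\big(\log\pi(a\mid s) + \mathcal{H}_{\text{target}}\big)\big]$$

These are exactly the quantities computed step by step in `update_net`:
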
    def update_net(self):
        """
        Update the network
        :return:
        """
        # sample a batch from memory
        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

        # learning epochs
        for gradient_step in range(self._gradient_steps):
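            # update the preprocessor's running statistics only on the first gradient step
            # (train=not gradient_step is True only when gradient_step == 0)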
            sampled_states = self._state_preprocessor(sampled_states, train=not gradient_step)
            sampled_next_states = self._state_preprocessor(sampled_next_states)

            # compute target values
            with torch.no_grad():
                next_actions, next_log_prob = self.actor(sampled_next_states)

                target_q1_values = self.target_critic_1(sampled_next_states, next_actions)
                target_q2_values = self.target_critic_2(sampled_next_states, next_actions)
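                # clipped double-Q: take the smaller of the two target-critic estimates and
                # subtract the entropy bonus alpha * log pi(a'|s') to form the soft value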
                target_q_values = torch.min(target_q1_values,
                                            target_q2_values) - self._entropy_coefficient * next_log_prob
                target_values = sampled_rewards + self._discount * sampled_dones.logical_not() * target_q_values

            # compute critic loss
            critic_1_values = self.critic_1(sampled_states, sampled_actions)
            critic_2_values = self.critic_2(sampled_states, sampled_actions)

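            # both critics regress onto the same soft Bellman target; their MSE losses are averaged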
            critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2

            # optimization step (critic)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()),
                                         self._grad_norm_clip)
            self.critic_optimizer.step()

            # compute actor (policy) loss
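            # actions are re-sampled from the current policy (reparameterization), so the
            # gradient of the Q-values flows back through the actor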
            actions, log_prob = self.actor(sampled_states)
            critic_1_values = self.critic_1(sampled_states, actions)
            critic_2_values = self.critic_2(sampled_states, actions)

            actor_loss = (self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values)).mean()

            # optimization step (actor)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.actor.parameters(), self._grad_norm_clip)
            self.actor_optimizer.step()

            # entropy learning
            if self._learn_entropy:
                # compute entropy loss
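                # automatic temperature tuning: adjust log(alpha) so that the policy entropy
                # tracks the target entropy; (log_prob + target_entropy) is detached so only
                # log(alpha) receives gradients here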
                entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean()

                # optimization step (entropy)
                self.entropy_optimizer.zero_grad()
                entropy_loss.backward()
                self.entropy_optimizer.step()

                # compute entropy coefficient
                self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach())

            # update target networks
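            # soft (Polyak) update: blend the online critic weights into the target critics
            # so the bootstrapped targets move slowly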
            self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak)
            self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak)

            # update learning rate
            if self._lr_scheduler:
                self.actor_scheduler.step()
                self.critic_scheduler.step()

            # record data
            self.track_data("Loss / Actor loss", actor_loss.item())
            self.track_data("Loss / Critic loss", critic_loss.item())

            self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item())
            self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item())
            self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item())

            self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item())
            self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item())
            self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item())

            self.track_data("Target / Target (max)", torch.max(target_values).item())
            self.track_data("Target / Target (min)", torch.min(target_values).item())
            self.track_data("Target / Target (mean)", torch.mean(target_values).item())

            if self._learn_entropy:
                self.track_data("Loss / Entropy loss", entropy_loss.item())
                self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item())

            if self._lr_scheduler:
                self.track_data("Learning / Actor learning rate", self.actor_scheduler.get_last_lr()[0])
                self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0])

2.  Performance comparison#

We compare the performance of the SAC algorithm with different tricks (activation function and batch size) against an open-source baseline (SKRL). These experiments were conducted on the Pendulum environment; the results are shown below, and a short sketch after the legend illustrates how the varied settings enter the model.

2.1.  Pendulum#

[Figure: SAC training curves on Pendulum; the runs are color-coded as listed below.]

  • Pink: SKRL SAC

  • Blue: Rofunc SAC with ReLU activation function and batch size of 64

  • Orange: Rofunc SAC with Tanh activation function and batch size of 64

  • Gray: Rofunc SAC with ELU activation function and batch size of 512

  • Red: Rofunc SAC with ELU activation function and batch size of 64
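
The settings varied above are the hidden-layer activation function and the mini-batch size. The batch size corresponds to `self._batch_size` in `update_net`; the activation is a property of the actor and critic MLPs. The following is a minimal, hypothetical sketch of such a network builder (`make_mlp` and the hidden sizes are illustrative, not part of Rofunc's API):

    import torch.nn as nn

    def make_mlp(in_dim: int, out_dim: int, hidden=(256, 256), activation=nn.ELU):
        """Hypothetical helper: an MLP whose activation (ReLU / Tanh / ELU) is swappable,
        matching the variants compared in the curves above."""
        layers, last = [], in_dim
        for size in hidden:
            layers += [nn.Linear(last, size), activation()]
            last = size
        layers.append(nn.Linear(last, out_dim))
        return nn.Sequential(*layers)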