RofuncRL SAC (Soft Actor-Critic)#
Paper: “Soft Actor-Critic Algorithms and Applications”. Haarnoja et al. 2018. https://arxiv.org/abs/1812.05905
1. Algorithm#
def update_net(self):
    """
    Update the networks (actor, critics and entropy coefficient)
    """
    # sample a batch from memory
    sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
        self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

    # learning epochs
    for gradient_step in range(self._gradient_steps):
        sampled_states = self._state_preprocessor(sampled_states, train=not gradient_step)
        sampled_next_states = self._state_preprocessor(sampled_next_states)

        # compute target values
        with torch.no_grad():
            next_actions, next_log_prob = self.actor(sampled_next_states)

            target_q1_values = self.target_critic_1(sampled_next_states, next_actions)
            target_q2_values = self.target_critic_2(sampled_next_states, next_actions)
            target_q_values = torch.min(target_q1_values,
                                        target_q2_values) - self._entropy_coefficient * next_log_prob
            target_values = sampled_rewards + self._discount * sampled_dones.logical_not() * target_q_values

        # compute critic loss
        critic_1_values = self.critic_1(sampled_states, sampled_actions)
        critic_2_values = self.critic_2(sampled_states, sampled_actions)
        critic_loss = (F.mse_loss(critic_1_values, target_values) + F.mse_loss(critic_2_values, target_values)) / 2

        # optimization step (critic)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip > 0:
            nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()),
                                     self._grad_norm_clip)
        self.critic_optimizer.step()

        # compute actor loss
        actions, log_prob = self.actor(sampled_states)
        critic_1_values = self.critic_1(sampled_states, actions)
        critic_2_values = self.critic_2(sampled_states, actions)
        actor_loss = (self._entropy_coefficient * log_prob - torch.min(critic_1_values, critic_2_values)).mean()

        # optimization step (actor)
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        if self._grad_norm_clip > 0:
            nn.utils.clip_grad_norm_(self.actor.parameters(), self._grad_norm_clip)
        self.actor_optimizer.step()

        # entropy learning
        if self._learn_entropy:
            # compute entropy loss
            entropy_loss = -(self.log_entropy_coefficient * (log_prob + self._target_entropy).detach()).mean()

            # optimization step (entropy)
            self.entropy_optimizer.zero_grad()
            entropy_loss.backward()
            self.entropy_optimizer.step()

            # compute entropy coefficient
            self._entropy_coefficient = torch.exp(self.log_entropy_coefficient.detach())

        # update target networks
        self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak)
        self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak)

        # update learning rate
        if self._lr_scheduler:
            self.actor_scheduler.step()
            self.critic_scheduler.step()

        # record data
        self.track_data("Loss / Actor loss", actor_loss.item())
        self.track_data("Loss / Critic loss", critic_loss.item())

        self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item())
        self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item())
        self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item())
        self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item())
        self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item())
        self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item())

        self.track_data("Target / Target (max)", torch.max(target_values).item())
        self.track_data("Target / Target (min)", torch.min(target_values).item())
        self.track_data("Target / Target (mean)", torch.mean(target_values).item())

        if self._learn_entropy:
            self.track_data("Loss / Entropy loss", entropy_loss.item())
            self.track_data("Coefficient / Entropy coefficient", self._entropy_coefficient.item())

        if self._lr_scheduler:
            self.track_data("Learning / Actor learning rate", self.actor_scheduler.get_last_lr()[0])
            self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0])
2. Performance comparison#
We compare the performance of the SAC algorithm under different design choices (activation function and batch size) and against an open-source baseline (SKRL). These experiments were conducted on the Pendulum environment. The results are shown below:
2.1. Pendulum#
- Pink: SKRL SAC
- Blue: Rofunc SAC with ReLU activation function and batch size of 64
- Orange: Rofunc SAC with Tanh activation function and batch size of 64
- Gray: Rofunc SAC with ELU activation function and batch size of 512
- Red: Rofunc SAC with ELU activation function and batch size of 64
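For context, the Rofunc variants above differ only in the hidden-layer activation and the replay batch size. The snippet below is a minimal, hypothetical PyTorch sketch of that design choice (the two hidden layers of 256 units are an assumption, not Rofunc's actual network configuration):

import torch.nn as nn

def make_mlp(in_dim, out_dim, hidden=(256, 256), activation=nn.ELU):
    """Simple MLP builder; swapping `activation` (nn.ReLU / nn.Tanh / nn.ELU)
    reproduces the activation variants compared above."""
    layers, dim = [], in_dim
    for h in hidden:
        layers += [nn.Linear(dim, h), activation()]
        dim = h
    layers.append(nn.Linear(dim, out_dim))
    return nn.Sequential(*layers)

# e.g. the "Red" variant: ELU activations and a replay batch size of 64
critic_body = make_mlp(in_dim=3 + 1, out_dim=1, activation=nn.ELU)  # Pendulum: 3-D obs + 1-D action -> Q-value
batch_size = 64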