RofuncRL TD3 (Twin Delayed Deep Deterministic Policy Gradient)#
Paper: “Addressing Function Approximation Error in Actor-Critic Methods”. Fujimoto et al. 2018. https://arxiv.org/abs/1802.09477
1. Algorithm#
def update_net(self):
    """
    Perform one TD3 network update: sample a batch once, then run
    ``self._gradient_steps`` optimization steps of critic regression with
    target-policy smoothing, delayed (every ``_policy_update_delay``-th step)
    actor updates, and Polyak-averaged target-network updates.

    :return: None
    """
    # Sample a single batch from memory; it is reused for every gradient step.
    sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
        self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

    # learning epochs
    for gradient_step in range(self._gradient_steps):
        # Preprocess into fresh locals each step. The previous code reassigned
        # `sampled_states`, so with _gradient_steps > 1 the same tensor was
        # normalized repeatedly (double standardization). Preprocessor running
        # statistics are only updated on the first step (train=True iff
        # gradient_step == 0), matching the original `train=not gradient_step`.
        states = self._state_preprocessor(sampled_states, train=not gradient_step)
        next_states = self._state_preprocessor(sampled_next_states)

        with torch.no_grad():
            # target policy smoothing: perturb the target action with clipped noise
            next_actions, _ = self.target_actor(next_states)
            noises = torch.clamp(self._smooth_regularization_noise.sample(next_actions.shape),
                                 min=-self._smooth_regularization_clip,
                                 max=self._smooth_regularization_clip)
            next_actions.add_(noises)
            next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max)

            # compute target values: clipped double-Q (min of the two target critics)
            target_q1_values = self.target_critic_1(next_states, next_actions)
            target_q2_values = self.target_critic_2(next_states, next_actions)
            target_q_values = torch.min(target_q1_values, target_q2_values)
            # logical_not() zeroes the bootstrap term at terminal transitions
            target_values = sampled_rewards + self._discount * sampled_dones.logical_not() * target_q_values

        # compute critic loss: mean of the two MSE regressions onto the shared target
        critic_1_values = self.critic_1(states, sampled_actions)
        critic_2_values = self.critic_2(states, sampled_actions)
        critic_loss = (F.mse_loss(critic_1_values, target_values)
                       + F.mse_loss(critic_2_values, target_values)) / 2

        # optimization step (critic)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        if self._grad_norm_clip > 0:
            nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(),
                                                     self.critic_2.parameters()),
                                     self._grad_norm_clip)
        self.critic_optimizer.step()

        # delayed policy update: actor and targets move only every
        # `_policy_update_delay` critic updates
        self._critic_update_times += 1
        if self._critic_update_times % self._policy_update_delay == 0:
            # compute actor (policy) loss; the TD3 actor is deterministic,
            # so the second return value is discarded
            actions, _ = self.actor(states)
            critic_1_values = self.critic_1(states, actions)
            critic_2_values = self.critic_2(states, actions)
            # NOTE(review): vanilla TD3 maximizes Q1 only; taking
            # min(Q1, Q2) here is a deliberate, more pessimistic variant
            # kept from the original implementation.
            policy_loss = - torch.min(critic_1_values, critic_2_values).mean()

            # optimization step (actor)
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(self.actor.parameters(), self._grad_norm_clip)
            self.actor_optimizer.step()

            # update target networks via Polyak averaging after the delayed actor step
            self.target_actor.update_parameters(self.actor, polyak=self._polyak)
            self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak)
            self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak)
            self.track_data("Loss / Actor loss", policy_loss.item())

        # update learning rate
        if self._lr_scheduler:
            self.actor_scheduler.step()
            self.critic_scheduler.step()

        # record data (critic values reflect the most recent forward pass:
        # post-actor-update values on delayed-update steps, otherwise the
        # regression values)
        self.track_data("Loss / Critic loss", critic_loss.item())
        self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item())
        self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item())
        self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item())
        self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item())
        self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item())
        self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item())
        self.track_data("Target / Target (max)", torch.max(target_values).item())
        self.track_data("Target / Target (min)", torch.min(target_values).item())
        self.track_data("Target / Target (mean)", torch.mean(target_values).item())
        if self._lr_scheduler:
            self.track_data("Learning / Actor learning rate", self.actor_scheduler.get_last_lr()[0])
            self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0])
2. Performance comparison#
We compare the performance of the TD3 algorithm with different tricks and an open source baseline
(SKRL). These experiments were conducted on the Pendulum
environment.
The results are shown below:
2.1. Pendulum#
Gray: SKRL TD3
Green: Rofunc TD3