Source code for rofunc.learning.RofuncRL.agents.online.td3_agent

#  Copyright (C) 2024, Junjia Liu
# 
#  This file is part of Rofunc.
# 
#  Rofunc is licensed under the GNU General Public License v3.0.
#  You may use, distribute, and modify this code under the terms of the GPL-3.0.
# 
#  Additional Terms for Commercial Use:
#  Commercial use requires sharing 50% of net profits with the copyright holder.
#  Financial reports and regular payments must be provided as agreed in writing.
#  Non-compliance results in revocation of commercial rights.
# 
#  For more details, see <https://www.gnu.org/licenses/>.
#  Contact: skylark0924@gmail.com

import gym
import gymnasium
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import DictConfig
from typing import Union, Tuple, Optional

import rofunc as rf
from rofunc.learning.RofuncRL.agents.base_agent import BaseAgent
from rofunc.learning.RofuncRL.models.actor_models import ActorTD3
from rofunc.learning.RofuncRL.models.critic_models import Critic
from rofunc.learning.RofuncRL.processors.noises import GaussianNoise
from rofunc.learning.RofuncRL.processors.schedulers import KLAdaptiveRL
from rofunc.learning.RofuncRL.processors.standard_scaler import RunningStandardScaler
from rofunc.learning.RofuncRL.processors.standard_scaler import empty_preprocessor
from rofunc.learning.RofuncRL.utils.memory import Memory


class TD3Agent(BaseAgent):
    """
    Twin Delayed Deep Deterministic Policy Gradient (TD3) agent \n
    “Addressing Function Approximation Error in Actor-Critic Methods”. Fujimoto. et al. 2018. https://arxiv.org/abs/1802.09477 \n
    Rofunc documentation: https://rofunc.readthedocs.io/en/latest/lfd/RofuncRL/TD3.html
    """

    def __init__(self,
                 cfg: DictConfig,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                 memory: Optional[Union[Memory, Tuple[Memory]]] = None,
                 device: Optional[Union[str, torch.device]] = None,
                 experiment_dir: Optional[str] = None,
                 rofunc_logger: Optional[rf.logger.BeautyLogger] = None):
        """
        Build the actor, the twin critics, their target copies, the replay-memory
        tensors and the hyper-parameters, then call :meth:`_set_up`.

        :param cfg: Configurations
        :param observation_space: Observation space
        :param action_space: Action space
        :param memory: Memory for storing transitions
        :param device: Device on which the torch tensor is allocated
        :param experiment_dir: Directory for storing experiment data
        :param rofunc_logger: Rofunc logger
        """
        super().__init__(cfg, observation_space, action_space, memory, device, experiment_dir, rofunc_logger)

        # ---- Define models for TD3 ----
        # The critics receive the observation and the action space together.
        concat_space = [observation_space, action_space]
        self.actor = ActorTD3(cfg.Model, observation_space, action_space, self.se).to(self.device)
        self.target_actor = ActorTD3(cfg.Model, observation_space, action_space, self.se).to(self.device)
        self.critic_1 = Critic(cfg.Model, concat_space, action_space, self.se).to(self.device)
        self.critic_2 = Critic(cfg.Model, concat_space, action_space, self.se).to(self.device)
        self.target_critic_1 = Critic(cfg.Model, concat_space, action_space, self.se).to(self.device)
        self.target_critic_2 = Critic(cfg.Model, concat_space, action_space, self.se).to(self.device)
        self.models = {"actor": self.actor, "target_actor": self.target_actor,
                       "critic_1": self.critic_1, "critic_2": self.critic_2,
                       "target_critic_1": self.target_critic_1, "target_critic_2": self.target_critic_2}

        # checkpoint models
        self.checkpoint_modules["actor"] = self.actor
        self.checkpoint_modules["target_actor"] = self.target_actor
        self.checkpoint_modules["critic_1"] = self.critic_1
        self.checkpoint_modules["critic_2"] = self.critic_2
        self.checkpoint_modules["target_critic_1"] = self.target_critic_1
        self.checkpoint_modules["target_critic_2"] = self.target_critic_2

        self.rofunc_logger.module(f"Actor model: {self.actor}")
        self.rofunc_logger.module(f"Critic model x4: {self.critic_1}")

        # ---- Create tensors in memory ----
        if hasattr(cfg.Model, "state_encoder"):
            # With a state encoder configured, states are stored as
            # (channels, size, size) tensors and their dimensions are kept.
            img_channel = int(self.cfg.Model.state_encoder.inp_channels)
            img_size = int(self.cfg.Model.state_encoder.image_size)
            state_tensor_size = (img_channel, img_size, img_size)
            kd = True
        else:
            state_tensor_size = self.observation_space
            kd = False
        self.memory.create_tensor(name="states", size=state_tensor_size, dtype=torch.float32, keep_dimensions=kd)
        self.memory.create_tensor(name="next_states", size=state_tensor_size, dtype=torch.float32,
                                  keep_dimensions=kd)
        self.memory.create_tensor(name="actions", size=self.action_space, dtype=torch.float32)
        self.memory.create_tensor(name="rewards", size=1, dtype=torch.float32)
        self.memory.create_tensor(name="terminated", size=1, dtype=torch.bool)
        # Order matters: update_net() unpacks the sampled batch in this order.
        self._tensors_names = ["states", "actions", "rewards", "next_states", "terminated"]

        # ---- Get hyper-parameters from config ----
        self._discount = self.cfg.Agent.discount
        self._gradient_steps = self.cfg.Agent.gradient_steps
        self._polyak = self.cfg.Agent.polyak
        self._batch_size = self.cfg.Agent.batch_size
        self._lr_a = self.cfg.Agent.lr_a
        self._lr_c = self.cfg.Agent.lr_c
        self._lr_scheduler = self.cfg.get("Agent", {}).get("lr_scheduler", KLAdaptiveRL)
        self._lr_scheduler_kwargs = self.cfg.get("Agent", {}).get("lr_scheduler_kwargs", {'kl_threshold': 0.008})
        self._adam_eps = self.cfg.Agent.adam_eps
        self._grad_norm_clip = self.cfg.Agent.grad_norm_clip
        self._kl_threshold = self.cfg.Agent.kl_threshold
        self._policy_update_delay = self.cfg.Agent.policy_update_delay

        self._smooth_regularization_noise = GaussianNoise(0, 0.1, device=device)
        self._smooth_regularization_clip = self.cfg.Agent.smooth_regularization_clip

        self._rewards_shaper = self.cfg.get("Agent", {}).get("rewards_shaper", lambda rewards: rewards * 0.01)
        # Stored as a class here; presumably instantiated with the kwargs below
        # by BaseAgent._set_up() -- confirm against the base class.
        self._state_preprocessor = RunningStandardScaler
        self._state_preprocessor_kwargs = self.cfg.get("Agent", {}).get("state_preprocessor_kwargs",
                                                                        {"size": observation_space,
                                                                         "device": device})

        # ---- Misc variables ----
        self._current_log_prob = None
        self._current_next_states = None
        self._critic_update_times = 0

        # clip noise bounds
        # Default to None so update_net() can skip clamping instead of failing
        # with AttributeError when no action space was supplied.
        self.clip_actions_min = None
        self.clip_actions_max = None
        if action_space is not None:
            self.clip_actions_min = torch.tensor(action_space.low, device=self.device)
            self.clip_actions_max = torch.tensor(action_space.high, device=self.device)

        self._set_up()

    def _set_up(self):
        """
        Set up optimizer, learning rate scheduler and state/value preprocessors
        """
        # Set up optimizer and learning rate scheduler.
        # Both critics share a single optimizer.
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self._lr_a, eps=self._adam_eps)
        self.critic_optimizer = torch.optim.Adam(
            itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()),
            lr=self._lr_c, eps=self._adam_eps)
        if self._lr_scheduler is not None:
            self.actor_scheduler = self._lr_scheduler(self.actor_optimizer, **self._lr_scheduler_kwargs)
            self.critic_scheduler = self._lr_scheduler(self.critic_optimizer, **self._lr_scheduler_kwargs)
        self.checkpoint_modules["actor_optimizer"] = self.actor_optimizer
        self.checkpoint_modules["critic_optimizer"] = self.critic_optimizer

        # freeze target networks with respect to optimizers (update via .update_parameters())
        self.target_actor.freeze_parameters(True)
        self.target_critic_1.freeze_parameters(True)
        self.target_critic_2.freeze_parameters(True)

        # update target networks (hard update)
        self.target_actor.update_parameters(self.actor, polyak=1)
        self.target_critic_1.update_parameters(self.critic_1, polyak=1)
        self.target_critic_2.update_parameters(self.critic_2, polyak=1)

        # set up preprocessors
        super()._set_up()

    def act(self, states: torch.Tensor, deterministic: bool = False):
        """
        Compute actions for the given states with the actor network.

        :param states: States to act on; passed through the state preprocessor first
        :param deterministic: If True, the returned actions are detached from the autograd graph
        :return: Tuple ``(actions, None)``
        """
        # The actor returns a tuple, so unpack before touching the tensor.
        # Calling ``.detach()`` on the tuple itself raises AttributeError.
        actions, _ = self.actor(self._state_preprocessor(states))
        if deterministic:
            # TODO: check whether evaluation needs anything beyond detaching
            actions = actions.detach()
        return actions, None

    def store_transition(self, states: torch.Tensor, actions: torch.Tensor, next_states: torch.Tensor,
                         rewards: torch.Tensor, terminated: torch.Tensor, truncated: torch.Tensor,
                         infos: torch.Tensor):
        """
        Forward the transition to the base class, shape the rewards and add the samples to memory.

        :param states: Observations the actions were computed from
        :param actions: Actions taken
        :param next_states: Observations after taking the actions
        :param rewards: Rewards received (shaped by ``_rewards_shaper`` before storage)
        :param terminated: Termination flags
        :param truncated: Truncation flags
        :param infos: Additional environment information
        """
        super().store_transition(states=states, actions=actions, next_states=next_states, rewards=rewards,
                                 terminated=terminated, truncated=truncated, infos=infos)

        # reward shaping
        if self._rewards_shaper is not None:
            rewards = self._rewards_shaper(rewards)

        # store transition in memory
        self.memory.add_samples(states=states, actions=actions, rewards=rewards, next_states=next_states,
                                terminated=terminated, truncated=truncated)

    def update_net(self):
        """
        Update the network

        Samples one batch from memory and runs ``gradient_steps`` iterations on it.
        Each iteration updates both critics; every ``policy_update_delay`` critic
        updates it also updates the actor and soft-updates the target networks.

        :return:
        """
        # sample a batch from memory
        sampled_states, sampled_actions, sampled_rewards, sampled_next_states, sampled_dones = \
            self.memory.sample(names=self._tensors_names, batch_size=self._batch_size)[0]

        # learning epochs
        for gradient_step in range(self._gradient_steps):
            # Keep the raw batch untouched and bind the preprocessed tensors to
            # loop-local names; rebinding the sampled tensors would feed already
            # preprocessed data through the preprocessor again on later steps.
            # ``train`` is only True on the first gradient step.
            states = self._state_preprocessor(sampled_states, train=not gradient_step)
            next_states = self._state_preprocessor(sampled_next_states)

            with torch.no_grad():
                # target policy smoothing
                next_actions, _ = self.target_actor(next_states)
                noises = torch.clamp(self._smooth_regularization_noise.sample(next_actions.shape),
                                     min=-self._smooth_regularization_clip,
                                     max=self._smooth_regularization_clip)
                next_actions.add_(noises)
                # Bounds are None when no action space was given at construction.
                if self.clip_actions_min is not None and self.clip_actions_max is not None:
                    next_actions.clamp_(min=self.clip_actions_min, max=self.clip_actions_max)

                # compute target values
                target_q1_values = self.target_critic_1(next_states, next_actions)
                target_q2_values = self.target_critic_2(next_states, next_actions)
                target_q_values = torch.min(target_q1_values, target_q2_values)
                target_values = sampled_rewards + self._discount * sampled_dones.logical_not() * target_q_values

            # compute critic loss
            critic_1_values = self.critic_1(states, sampled_actions)
            critic_2_values = self.critic_2(states, sampled_actions)

            critic_loss = (F.mse_loss(critic_1_values, target_values) +
                           F.mse_loss(critic_2_values, target_values)) / 2

            # optimization step (critic)
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            if self._grad_norm_clip > 0:
                nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()),
                                         self._grad_norm_clip)
            self.critic_optimizer.step()

            self._critic_update_times += 1
            if self._critic_update_times % self._policy_update_delay == 0:
                # compute actor loss
                # Separate names keep ``critic_*_values`` (critic values for the
                # sampled actions) intact, so the Q statistics tracked below mean
                # the same thing on every step.
                actions, _ = self.actor(states)
                actor_q1_values = self.critic_1(states, actions)
                actor_q2_values = self.critic_2(states, actions)

                policy_loss = - torch.min(actor_q1_values, actor_q2_values).mean()

                # optimization step (actor)
                self.actor_optimizer.zero_grad()
                policy_loss.backward()
                if self._grad_norm_clip > 0:
                    nn.utils.clip_grad_norm_(self.actor.parameters(), self._grad_norm_clip)
                self.actor_optimizer.step()

                # update target networks
                self.target_actor.update_parameters(self.actor, polyak=self._polyak)
                self.target_critic_1.update_parameters(self.critic_1, polyak=self._polyak)
                self.target_critic_2.update_parameters(self.critic_2, polyak=self._polyak)

                self.track_data("Loss / Actor loss", policy_loss.item())

            # update learning rate
            if self._lr_scheduler:
                self.actor_scheduler.step()
                self.critic_scheduler.step()

            # record data
            self.track_data("Loss / Critic loss", critic_loss.item())

            self.track_data("Q-network / Q1 (max)", torch.max(critic_1_values).item())
            self.track_data("Q-network / Q1 (min)", torch.min(critic_1_values).item())
            self.track_data("Q-network / Q1 (mean)", torch.mean(critic_1_values).item())

            self.track_data("Q-network / Q2 (max)", torch.max(critic_2_values).item())
            self.track_data("Q-network / Q2 (min)", torch.min(critic_2_values).item())
            self.track_data("Q-network / Q2 (mean)", torch.mean(critic_2_values).item())

            self.track_data("Target / Target (max)", torch.max(target_values).item())
            self.track_data("Target / Target (min)", torch.min(target_values).item())
            self.track_data("Target / Target (mean)", torch.mean(target_values).item())

            if self._lr_scheduler:
                self.track_data("Learning / Actor learning rate", self.actor_scheduler.get_last_lr()[0])
                self.track_data("Learning / Critic learning rate", self.critic_scheduler.get_last_lr()[0])