# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import datetime
import multiprocessing
import os
import random
from typing import Union, Optional
import gym
import gymnasium
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from tensorboard import program
from torch.utils.tensorboard import SummaryWriter
import rofunc as rf
from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.processors.normalizers import Normalization
from rofunc.learning.utils.env_wrappers import wrap_env
from rofunc.utils.logger.beauty_logger import BeautyLogger
from rofunc.utils.oslab.internet import reserve_sock_addr
class BaseTrainer:
    """
    Base trainer for RofuncRL.

    The constructor sets up the experiment directory, the Rofunc logger,
    optional Weights & Biases syncing, a TensorBoard writer/server, the wrapped
    environment(s) and a state normalizer. The class also provides generic
    ``train`` / ``eval`` / ``inference`` loops built around ``self.agent``.

    ``self.agent`` is initialised to ``None`` here; it has to be assigned
    (presumably by a subclass) before any of the loops are run.
    """

    def __init__(self,
                 cfg: "DictConfig",
                 env: "Union[gym.Env, gymnasium.Env]",
                 device: Optional[Union[str, torch.device]] = None,
                 env_name: Optional[str] = None,
                 inference: bool = False):
        """
        :param cfg: full configuration; ``cfg.train.Trainer`` holds the trainer options
        :param env: environment instance, wrapped via ``wrap_env``
        :param device: torch device (or its name); when ``None``, ``cuda:0`` is used
                       if available, otherwise ``cpu``
        :param env_name: environment name, only used to build the experiment name
        :param inference: if ``True``, the experiment name gets an ``_inference``
                          suffix and Weights & Biases is not set up
        """
        self.cfg = cfg
        self.cfg_trainer = cfg.train.Trainer
        self.agent = None
        if device is None:
            self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)
        self.env_name = env_name
        self.inference_flag = inference

        # ---- Experiment log directory ----
        # Empty/None config values fall back to "<cwd>/runs" and a timestamp.
        directory = self.cfg_trainer.experiment_directory or os.path.join(os.getcwd(), "runs")
        exp_name = self.cfg_trainer.experiment_name or \
            datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f")
        exp_name = f"RofuncRL_{self.__class__.__name__}_{env_name}_{exp_name}"
        if inference:
            exp_name += "_inference"
        self.exp_dir = os.path.join(directory, exp_name)
        rf.oslab.create_dir(self.exp_dir, local_verbose=True)

        # ---- Rofunc logger ----
        self.rofunc_logger = BeautyLogger(self.exp_dir, verbose=self.cfg_trainer.rofunc_logger_kwargs.verbose)
        self.rofunc_logger.info(f"Trainer configurations:\n{OmegaConf.to_yaml(self.cfg.train)}")

        # ---- Weights & Biases ----
        # Kept before the SummaryWriter is created, as in the original ordering.
        self.setup_wandb()

        # ---- TensorBoard ----
        # main entry to log data for consumption and visualization by TensorBoard
        self.write_interval = self.cfg_trainer.write_interval
        self.writer = SummaryWriter(log_dir=self.exp_dir)
        tb = program.TensorBoard()
        # Find a free port
        with reserve_sock_addr() as (_host, port):
            argv = ['tensorboard', f"--logdir={self.exp_dir}", f"--port={port}"]
            tb_extra_args = os.getenv('TB_EXTRA_ARGS', "")
            if tb_extra_args:
                # split() (not split(' ')) so repeated spaces do not yield empty arguments
                argv += tb_extra_args.split()
            tb.configure(argv)
        # Launch TensorBoard
        url = tb.launch()
        self.rofunc_logger.info(f"Tensorboard listening on {url}")

        # ---- Misc variables ----
        self.maximum_steps = self.cfg_trainer.get("maximum_steps", int(1e6))
        self.start_learning_steps = self.cfg_trainer.get("start_learning_steps", 0)
        self.random_steps = self.cfg_trainer.get("random_steps", 0)
        self.rollouts = self.cfg_trainer.get("rollouts", 16)
        self.max_episode_steps = self.cfg_trainer.get("max_episode_steps", 250)
        self._step = 0
        self._rollout = 0
        self._update_times = 0
        self.start_time = None

        # ---- Evaluation and inference configurations ----
        self.eval_flag = self.cfg_trainer.get("eval_flag", False)
        self.eval_freq = self.cfg_trainer.get("eval_freq", 5 * self.max_episode_steps)
        self.eval_steps = self.cfg_trainer.get("eval_steps", 1000)
        self.eval_env_seed = self.cfg_trainer.get("eval_env_seed", random.randint(0, 10000))
        self.use_eval_thread = self.cfg_trainer.get("use_eval_thread", False)
        if self.eval_flag:
            assert self.eval_steps % self.max_episode_steps == 0, \
                f"eval_steps ({self.eval_steps}) must be a multiple of max_episode_steps ({self.max_episode_steps})."
        self.inference_steps = self.cfg_trainer.get("inference_steps", 1000)
        self.total_rew_mean = -1e4
        # NOTE(review): nothing in this class updates eval_rew_mean; it is only displayed.
        self.eval_rew_mean = 0

        # ---- Environment ----
        # TODO: check whether setting `env.device = self.device` is necessary
        self.env = wrap_env(env, logger=self.rofunc_logger, seed=self.cfg_trainer.seed)
        # NOTE(review): the evaluation wrapper is built around the very same `env` object.
        self.eval_env = wrap_env(env, logger=self.rofunc_logger, seed=self.eval_env_seed) \
            if self.eval_flag else None
        self.rofunc_logger.info(f"Environment:\n "
                                f" action_space: {self.env.action_space.shape}\n "
                                f" observation_space: {self.env.observation_space.shape}\n "
                                f" num_envs: {self.env.num_envs}")
        if hasattr(self.env._env, "cfg"):
            self.rofunc_logger.info(f"Task configurations:\n{self.env._env.cfg}")

        # ---- Normalization ----
        # Use the resolved device so the normalizer matches self.device even when
        # the `device` argument was left as None.
        self.state_norm = Normalization(shape=self.env.observation_space, device=self.device)

    def setup_wandb(self):
        """
        Initialise Weights & Biases syncing of the TensorBoard logs.

        Does nothing unless ``cfg.train.Trainer.wandb`` is truthy and the trainer
        is not in inference mode. ``wandb`` is imported lazily, so it is only
        required when enabled.
        """
        trainer_cfg = self.cfg.train.get("Trainer", {})
        if not trainer_cfg.get("wandb", False) or self.inference_flag is not False:
            return

        import wandb

        # Read the optional kwargs without mutating them; this also works when
        # the key is absent (plain dict default) or explicitly null.
        wandb_kwargs = trainer_cfg.get("wandb_kwargs", {}) or {}
        run_name = wandb_kwargs.get("name", None)
        if run_name is None:
            # default run name: last component of the experiment directory
            run_name = os.path.basename(self.exp_dir)
        project = wandb_kwargs.get("project", "RofuncRL")

        # The `hydra` node is removed from a copy before the config is exported.
        cfg_copy = self.cfg.copy()
        del cfg_copy.hydra
        config = OmegaConf.to_container(cfg_copy, resolve=True, throw_on_missing=True)

        wandb.tensorboard.patch(root_logdir=self.exp_dir, pytorch=True)
        # init Weights & Biases
        wandb.init(config=config, project=project, name=run_name, sync_tensorboard=True)

    def get_action(self, states):
        """
        Select actions for the current states.

        While ``self._step < self.random_steps`` one action per environment is
        sampled from the action space; afterwards the agent's ``act`` is used.

        :param states: current observations passed to ``self.agent.act``
        :return: actions
        """
        if self._step < self.random_steps:
            samples = [self.env.action_space.sample() for _ in range(self.env.num_envs)]
            return torch.tensor(samples).to(self.device)
        actions, _ = self.agent.act(states)
        return actions

    def train(self):
        """
        Main training loop.

        - Reset the environment
        - For each of ``maximum_steps`` steps:
            - Pre-interaction
            - Obtain action from agent
            - Interact with environment
            - Store transition
            - Post-interaction
            - Reset the environment if any sub-environment terminated/truncated
        - Close the environment and the TensorBoard writer
        """
        # reset env
        states, infos = self.env.reset()

        # The progress bar is kept on self so post_interaction can update its postfix.
        with tqdm.trange(self.maximum_steps, ncols=80, colour='green') as self.t_bar:
            for _ in self.t_bar:
                self.pre_interaction()

                with torch.no_grad():
                    # Obtain action from agent
                    actions = self.get_action(states)

                    # Interact with environment
                    next_states, rewards, terminated, truncated, infos = self.env.step(actions)
                    next_states, rewards, terminated, truncated, infos = self.agent.multi_gpu_transfer(
                        next_states, rewards, terminated, truncated, infos)

                with torch.no_grad():
                    # Store transition
                    self.agent.store_transition(states=states, actions=actions, next_states=next_states,
                                                rewards=rewards, terminated=terminated, truncated=truncated,
                                                infos=infos)

                self.post_interaction()
                self._step += 1

                with torch.no_grad():
                    # Reset the environment
                    if terminated.any() or truncated.any():
                        states, infos = self.env.reset()
                    else:
                        states = next_states.clone()

        # close the environment
        self.env.close()
        # close the logger
        self.writer.close()
        self.rofunc_logger.info('Training complete.')

    def pre_interaction(self):
        """Hook called before each environment interaction; no-op in the base class."""
        pass

    def post_interaction(self):
        """
        Base post-interaction function

        - Every ``write_interval`` steps: update the best checkpoint, write to
          TensorBoard and refresh the progress-bar message
        - Save periodic checkpoints
        - Run evaluation every ``eval_freq`` steps (when enabled)

        A ``write_interval`` or ``checkpoint_interval`` of 0 (or below) disables
        the corresponding action.
        """
        agent = self.agent

        # Update best models and tensorboard.
        # The `> 0` guard is evaluated first so an interval of 0 cannot divide by zero.
        if self.write_interval > 0 and self._step % self.write_interval == 0:
            # update best models
            self.total_rew_mean = np.mean(agent.tracking_data.get("Reward / Total reward (mean)", -1e4))
            best = agent.checkpoint_best_modules
            if self.total_rew_mean > best["reward"]:
                best["timestep"] = self._step
                best["reward"] = self.total_rew_mean
                best["saved"] = False
                best["modules"] = {name: copy.deepcopy(agent._get_internal_value(module))
                                   for name, module in agent.checkpoint_modules.items()}
                agent.save_ckpt(os.path.join(agent.checkpoint_dir, "best_ckpt.pth"))

            # Update tensorboard
            self.write_tensorboard()

            # Update tqdm bar message
            best_reward = agent.checkpoint_best_modules['reward']
            if self.eval_flag:
                post_str = f"Rew/Best/Eval: {self.total_rew_mean:.2f}/{best_reward:.2f}/{self.eval_rew_mean:.2f}"
            else:
                post_str = f"Rew/Best: {self.total_rew_mean:.2f}/{best_reward:.2f}"
            self.t_bar.set_postfix_str(post_str)
            self.rofunc_logger.info(f"Step: {self._step}, {post_str}", local_verbose=False)

        # Save checkpoints (guard before modulo, see above)
        ckpt_interval = agent.checkpoint_interval
        if ckpt_interval is not None and ckpt_interval > 0 and self._step > 1 \
                and (self._step + 1) % ckpt_interval == 0:
            agent.save_ckpt(os.path.join(agent.checkpoint_dir, f"ckpt_{self._step + 1}.pth"))

        # Evaluate per self.eval_freq steps
        if self.eval_flag:
            next_step = self._step + 1
            if not next_step % self.eval_freq and next_step > self.start_learning_steps:
                self.rofunc_logger.info(f'Evaluate at step {next_step}.', local_verbose=False)
                if self.use_eval_thread:
                    # Runs evaluation in a separate process; it is joined right
                    # away, so training still waits for it to finish.
                    self.rofunc_logger.info('Start evaluation thread.', local_verbose=False)
                    eval_thread = multiprocessing.Process(target=self.eval)
                    eval_thread.start()
                    eval_thread.join()
                else:
                    self.eval()

    def write_tensorboard(self):
        """
        Flush the agent's tracked data to TensorBoard at the current step.

        Keys ending in ``(min)`` / ``(max)`` are reduced with min / max, all
        others with the mean. The agent's tracking containers are cleared
        afterwards.
        """
        for key, values in self.agent.tracking_data.items():
            if key.endswith("(min)"):
                reduced = np.min(values)
            elif key.endswith("(max)"):
                reduced = np.max(values)
            else:
                reduced = np.mean(values)
            self.writer.add_scalar(key, reduced, self._step)

        # reset data containers for next iteration
        self.agent.track_rewards.clear()
        self.agent.track_timesteps.clear()
        self.agent.tracking_data.clear()

    def eval(self):
        """
        Run ``eval_steps`` deterministic agent steps in ``self.eval_env``.

        NOTE(review): the evaluation environment is closed at the end of every
        call, and it wraps the same ``env`` object as the training environment
        (see ``__init__``) -- confirm that later evaluations and the ongoing
        training are unaffected by this.
        """
        # reset env
        states, infos = self.eval_env.reset()

        for _ in tqdm.trange(self.eval_steps):
            with torch.no_grad():
                # Obtain action from agent
                actions, _ = self.agent.act(states, deterministic=True)  # TODO: check

                # Interact with environment
                next_states, rewards, terminated, truncated, infos = self.eval_env.step(actions)

                # Reset the environment
                if terminated.any() or truncated.any():
                    states, infos = self.eval_env.reset()
                else:
                    states = next_states.clone()

        # close the environment
        self.eval_env.close()
        self.rofunc_logger.info('Evaluation complete.')

    def inference(self):
        """
        Run ``inference_steps`` deterministic agent steps in ``self.env``,
        calling ``pre_interaction`` before each step, then close the environment.
        """
        # reset env
        states, infos = self.env.reset()

        for _ in tqdm.trange(self.inference_steps):
            self.pre_interaction()
            with torch.no_grad():
                # Obtain action from agent
                actions, _ = self.agent.act(states, deterministic=True)  # TODO: check

                # Interact with environment
                next_states, rewards, terminated, truncated, infos = self.env.step(actions)

                # Reset the environment
                if terminated.any() or truncated.any():
                    states, infos = self.env.reset()
                else:
                    states = next_states.clone()

        # close the environment
        self.env.close()
        self.rofunc_logger.info('Inference complete.')