Source code for rofunc.learning.RofuncRL.processors.schedulers

#  Copyright (C) 2024, Junjia Liu
#
#  This file is part of Rofunc.
#
#  Rofunc is licensed under the GNU General Public License v3.0.
#  You may use, distribute, and modify this code under the terms of the GPL-3.0.
#
#  Additional Terms for Commercial Use:
#  Commercial use requires sharing 50% of net profits with the copyright holder.
#  Financial reports and regular payments must be provided as agreed in writing.
#  Non-compliance results in revocation of commercial rights.
#
#  For more details, see <https://www.gnu.org/licenses/>.
#  Contact: skylark0924@gmail.com

from typing import Union, Optional

import torch
from torch.optim.lr_scheduler import _LRScheduler


[docs]class KLAdaptiveRL(_LRScheduler): def __init__(self, optimizer: torch.optim.Optimizer, kl_threshold: float = 0.008, min_lr: float = 1e-6, max_lr: float = 1e-2, kl_factor: float = 2, lr_factor: float = 1.5, last_epoch: int = -1, verbose: bool = False) -> None: """Adaptive KL scheduler Adjusts the learning rate according to the KL divergence. The implementation is adapted from the rl_games library (https://github.com/Denys88/rl_games/blob/master/rl_games/common/schedulers.py) .. note:: This scheduler is only available for PPO at the moment. Applying it to other agents will not change the learning rate Example:: >>> scheduler = KLAdaptiveRL(optimizer, kl_threshold=0.01) >>> for epoch in range(100): >>> train(...) >>> validate(...) >>> kl_divergence = ... >>> scheduler.step(kl_divergence) :param optimizer: Wrapped optimizer :type optimizer: torch.optim.Optimizer :param kl_threshold: Threshold for KL divergence (default: ``0.008``) :type kl_threshold: float, optional :param min_lr: Lower bound for learning rate (default: ``1e-6``) :type min_lr: float, optional :param max_lr: Upper bound for learning rate (default: ``1e-2``) :type max_lr: float, optional :param kl_factor: The number used to modify the KL divergence threshold (default: ``2``) :type kl_factor: float, optional :param lr_factor: The number used to modify the learning rate (default: ``1.5``) :type lr_factor: float, optional :param last_epoch: The index of last epoch (default: ``-1``) :type last_epoch: int, optional :param verbose: Verbose mode (default: ``False``) :type verbose: bool, optional """ super().__init__(optimizer, last_epoch, verbose) self.kl_threshold = kl_threshold self.min_lr = min_lr self.max_lr = max_lr self._kl_factor = kl_factor self._lr_factor = lr_factor self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
[docs] def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[int] = None) -> None: """ Step scheduler Example:: >>> kl = torch.distributions.kl_divergence(p, q) >>> kl tensor([0.0332, 0.0500, 0.0383, ..., 0.0076, 0.0240, 0.0164]) >>> scheduler.step(kl.mean()) >>> kl = 0.0046 >>> scheduler.step(kl) :param kl: KL divergence (default: None) If None, no adjustment is made. If tensor, the number of elements must be 1 :type kl: torch.Tensor, float, None, optional :param epoch: Epoch (default: None) :type epoch: int, optional """ if kl is not None: for group in self.optimizer.param_groups: if kl > self.kl_threshold * self._kl_factor: group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr) elif kl < self.kl_threshold / self._kl_factor: group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr) self._last_lr = [group['lr'] for group in self.optimizer.param_groups]