# Source code for rofunc.learning.RofuncRL.processors.schedulers
# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional
import torch
from torch.optim.lr_scheduler import _LRScheduler
class KLAdaptiveRL(_LRScheduler):
    """Learning-rate scheduler driven by the KL divergence passed to :meth:`step`."""

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 kl_threshold: float = 0.008,
                 min_lr: float = 1e-6,
                 max_lr: float = 1e-2,
                 kl_factor: float = 2,
                 lr_factor: float = 1.5,
                 last_epoch: int = -1,
                 verbose: bool = False) -> None:
        """Adaptive KL scheduler

        Adjusts the learning rate according to the KL divergence.
        The implementation is adapted from the rl_games library
        (https://github.com/Denys88/rl_games/blob/master/rl_games/common/schedulers.py)

        The learning rate of every parameter group is divided by ``lr_factor``
        (but never below ``min_lr``) when the KL divergence exceeds
        ``kl_threshold * kl_factor``, and multiplied by ``lr_factor`` (but never
        above ``max_lr``) when it falls below ``kl_threshold / kl_factor``.
        Inside that band the learning rate is left untouched.

        .. note::

            This scheduler is only available for PPO at the moment.
            Applying it to other agents will not change the learning rate

        Example::

            >>> scheduler = KLAdaptiveRL(optimizer, kl_threshold=0.01)
            >>> for epoch in range(100):
            >>>     train(...)
            >>>     validate(...)
            >>>     kl_divergence = ...
            >>>     scheduler.step(kl_divergence)

        :param optimizer: Wrapped optimizer
        :type optimizer: torch.optim.Optimizer
        :param kl_threshold: Threshold for KL divergence (default: ``0.008``)
        :type kl_threshold: float, optional
        :param min_lr: Lower bound for learning rate (default: ``1e-6``)
        :type min_lr: float, optional
        :param max_lr: Upper bound for learning rate (default: ``1e-2``)
        :type max_lr: float, optional
        :param kl_factor: The number used to modify the KL divergence threshold (default: ``2``)
        :type kl_factor: float, optional
        :param lr_factor: The number used to modify the learning rate (default: ``1.5``)
        :type lr_factor: float, optional
        :param last_epoch: The index of last epoch (default: ``-1``)
        :type last_epoch: int, optional
        :param verbose: Verbose mode (default: ``False``).
                        Stored on the instance only; this scheduler never prints
        :type verbose: bool, optional
        """
        # ``verbose`` is deliberately NOT forwarded to the base class: the base
        # scheduler deprecated that argument (PyTorch 2.2) and newer releases
        # no longer accept it. Since ``step`` is fully overridden here, the
        # base class never consults the flag anyway.
        super().__init__(optimizer, last_epoch)
        self.verbose = verbose

        self.kl_threshold = kl_threshold
        self.min_lr = min_lr
        self.max_lr = max_lr
        self._kl_factor = kl_factor
        self._lr_factor = lr_factor

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[int] = None) -> None:
        """
        Step scheduler

        Example::

            >>> kl = torch.distributions.kl_divergence(p, q)
            >>> kl
            tensor([0.0332, 0.0500, 0.0383,  ..., 0.0076, 0.0240, 0.0164])
            >>> scheduler.step(kl.mean())

            >>> kl = 0.0046
            >>> scheduler.step(kl)

        :param kl: KL divergence (default: None)
                   If None, no adjustment is made.
                   If tensor, the number of elements must be 1
        :type kl: torch.Tensor, float, None, optional
        :param epoch: Epoch (default: None).
                      Accepted for signature compatibility with the base scheduler; it is ignored
        :type epoch: int, optional
        """
        # The base-class constructor calls ``step()`` without arguments before
        # the attributes below exist, so this guard must come first.
        if kl is None:
            return

        # Convert once: avoids one tensor comparison (and a possible device
        # synchronisation) per parameter group. A tensor with more than one
        # element cannot be converted to a scalar and raises an error here.
        kl_value = float(kl)
        upper_limit = self.kl_threshold * self._kl_factor
        lower_limit = self.kl_threshold / self._kl_factor

        if kl_value > upper_limit:
            # Policy moved too far: shrink the learning rate, clamped at min_lr
            for group in self.optimizer.param_groups:
                group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr)
        elif kl_value < lower_limit:
            # Policy barely moved: grow the learning rate, clamped at max_lr
            for group in self.optimizer.param_groups:
                group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]