rofunc.learning.RofuncRL.processors.schedulers

1.  Module Contents

1.1.  Classes

KLAdaptiveRL

1.2.  API

class rofunc.learning.RofuncRL.processors.schedulers.KLAdaptiveRL(optimizer: torch.optim.Optimizer, kl_threshold: float = 0.008, min_lr: float = 1e-06, max_lr: float = 0.01, kl_factor: float = 2, lr_factor: float = 1.5, last_epoch: int = -1, verbose: bool = False)

Bases: torch.optim.lr_scheduler._LRScheduler

Initialization

Adaptive KL scheduler

Adjusts the learning rate according to the KL divergence between policy updates. The implementation is adapted from the rl_games library (Denys88/rl_games).

Note

This scheduler is currently supported only by PPO. Applying it to any other agent leaves the learning rate unchanged.

Example:

>>> scheduler = KLAdaptiveRL(optimizer, kl_threshold=0.01)
>>> for epoch in range(100):
...     train(...)
...     validate(...)
...     kl_divergence = ...
...     scheduler.step(kl_divergence)
Parameters:
  • optimizer (torch.optim.Optimizer) – Wrapped optimizer

  • kl_threshold (float, optional) – Threshold for KL divergence (default: 0.008)

  • min_lr (float, optional) – Lower bound for learning rate (default: 1e-6)

  • max_lr (float, optional) – Upper bound for learning rate (default: 1e-2)

  • kl_factor (float, optional) – Factor by which the KL divergence threshold is scaled (default: 2)

  • lr_factor (float, optional) – Factor by which the learning rate is scaled (default: 1.5)

  • last_epoch (int, optional) – Index of the last epoch (default: -1)

  • verbose (bool, optional) – Verbose mode (default: False)
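
The parameters above interact roughly as sketched below. This follows the adaptive rule used in rl_games, from which the scheduler is adapted; the function name `adapt_lr` is hypothetical, and the exact comparisons and clamping inside rofunc are an assumption, so check the source before relying on the details.

```python
def adapt_lr(lr, kl, kl_threshold=0.008, min_lr=1e-6, max_lr=1e-2,
             kl_factor=2.0, lr_factor=1.5):
    """Return the next learning rate for a measured KL divergence.

    Hypothetical helper, not part of rofunc. Assumes an rl_games-style rule:
    shrink the rate when KL is well above the threshold, grow it when KL is
    well below, and leave it alone otherwise.
    """
    if kl is None:
        # No KL value supplied: no adjustment.
        return lr
    if kl > kl_threshold * kl_factor:
        # Policy moved too far: decrease, but not below min_lr.
        return max(lr / lr_factor, min_lr)
    if kl < kl_threshold / kl_factor:
        # Policy barely moved: increase, but not above max_lr.
        return min(lr * lr_factor, max_lr)
    # KL is inside the tolerated band: keep the current rate.
    return lr
```

Under this assumption and with the default values, the learning rate stays unchanged while the KL divergence lies between 0.004 and 0.016.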

step(kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[int] = None) → None

Step scheduler

Example:

>>> kl = torch.distributions.kl_divergence(p, q)
>>> kl
tensor([0.0332, 0.0500, 0.0383,  ..., 0.0076, 0.0240, 0.0164])
>>> scheduler.step(kl.mean())

>>> kl = 0.0046
>>> scheduler.step(kl)
Parameters:
  • kl (torch.Tensor, float, None, optional) – KL divergence (default: None). If None, no adjustment is made. If a tensor, it must contain exactly one element

  • epoch (int, optional) – Epoch (default: None)
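
Because a tensor argument must contain a single element, per-sample KL values have to be reduced (for example with `kl.mean()`) before calling `step`. A minimal sketch of a guard that enforces this is shown here; `as_scalar_kl` is a hypothetical helper, not part of rofunc, and it assumes only that tensor-like values expose `numel()` and `item()` as `torch.Tensor` does.

```python
def as_scalar_kl(kl):
    """Convert a KL value to a Python float, or return None if none was given.

    Hypothetical helper: accepts None, a Python number, or any tensor-like
    object exposing numel() and item() (as torch.Tensor does).
    """
    if kl is None:
        # Mirrors the documented behaviour: None means "no adjustment".
        return None
    if hasattr(kl, "numel") and hasattr(kl, "item"):
        if kl.numel() != 1:
            raise ValueError(
                "KL tensor must contain exactly one element; "
                "reduce it first, e.g. with kl.mean()"
            )
        return kl.item()
    # Plain Python numbers pass through as floats.
    return float(kl)
```

A caller could then write `scheduler.step(as_scalar_kl(kl))` and get an explicit error when an unreduced tensor is passed.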