# Source code for rofunc.learning.RofuncRL.processors.noises
# Copyright (C) 2024, Junjia Liu
#
# This file is part of Rofunc.
#
# Rofunc is licensed under the GNU General Public License v3.0.
# You may use, distribute, and modify this code under the terms of the GPL-3.0.
#
# Additional Terms for Commercial Use:
# Commercial use requires sharing 50% of net profits with the copyright holder.
# Financial reports and regular payments must be provided as agreed in writing.
# Non-compliance results in revocation of commercial rights.
#
# For more details, see <https://www.gnu.org/licenses/>.
# Contact: skylark0924@gmail.com
from typing import Optional, Union, Tuple
import torch
from torch.distributions import Normal
class Noise:
    """
    Base class representing a noise

    Subclasses provide the actual sampling logic by overriding :meth:`sample`;
    :meth:`sample_like` is implemented here on top of it.
    """

    def __init__(self, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        Store the device associated with this noise

        :param device: Device given as a string or a ``torch.device`` (default: ``None``).
                       If None, ``"cuda:0"`` is used when CUDA is available, otherwise ``"cpu"``
        """
        if device is None:
            # No explicit choice: prefer the first CUDA device, fall back to the CPU
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # torch.device() accepts both strings and existing torch.device objects
        self.device = torch.device(device)

    def sample_like(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Sample a noise with the same size (shape) as the input tensor

        Only the size of ``tensor`` is read; it is forwarded to :meth:`sample`.

        :param tensor: Input tensor used to determine output tensor size (shape)
        :return: Sampled noise
        """
        return self.sample(tensor.size())

    def sample(self, size: Union[Tuple[int, ...], torch.Size]) -> torch.Tensor:
        """
        Noise sampling method to be implemented by the inheriting classes

        :param size: Shape of the sampled tensor
        :raises NotImplementedError: Always, since the base class defines no sampling rule
        """
        raise NotImplementedError("The sampling method (.sample()) is not implemented")
class GaussianNoise(Noise):
    """
    Class representing a Gaussian noise

    Wraps a ``torch.distributions.Normal`` built once at construction time.
    """

    def __init__(self, mean: float, std: float, device: Optional[Union[str, torch.device]] = None) -> None:
        """
        Build the underlying normal distribution

        :param mean: Mean of the normal distribution
        :param std: Standard deviation of the normal distribution
        :param device: Device on which a torch tensor is or will be allocated (default: ``None``).
                       If None, the device will be either ``"cuda:0"`` if available or ``"cpu"``
        """
        super().__init__(device)
        # Both parameters are materialised as float32 tensors on the resolved device
        self.distribution = Normal(
            loc=torch.tensor(mean, device=self.device, dtype=torch.float32),
            scale=torch.tensor(std, device=self.device, dtype=torch.float32),
        )

    def sample(self, size: Union[Tuple[int, ...], torch.Size]) -> torch.Tensor:
        """
        Sample a Gaussian noise

        :param size: Shape of the sampled tensor
        :return: Tensor drawn from the distribution configured in the constructor
        """
        return self.distribution.sample(size)