# Source code for rofunc.learning.RofuncRL.models.critic_models

# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gym
import gymnasium
import torch
import torch.nn as nn
from omegaconf import DictConfig
from torch import Tensor
from typing import Union, Tuple, Optional, List

from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.models.utils import build_mlp, init_layers, activation_func, get_space_dim
from rofunc.learning.RofuncRL.state_encoders.base_encoders import EmptyEncoder


class BaseCritic(nn.Module):
    """Common base class for critic networks.

    Resolves the input state dimension from the observation space and the
    (optional) state encoder, reads the MLP hidden sizes and activation from
    ``cfg[cfg_name]``, and holds non-trainable normalisation statistics for
    states and values. Subclasses (e.g. ``Critic``) build ``self.backbone_net``.
    """

    def __init__(self, cfg: DictConfig,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space, List]],
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                 state_encoder: Optional[nn.Module] = None,
                 cfg_name: str = 'critic'):
        """
        :param cfg: Configuration; ``cfg[cfg_name]`` must provide ``mlp_hidden_dims`` and ``mlp_activation``
        :param observation_space: Observation space (or its size). When a tuple/list is given together with a
            non-empty state encoder, the state dimension is the encoder's ``output_dim`` plus the dimension of
            ``observation_space[1:]`` (presumably the encoder consumes the first element -- confirm with callers)
        :param action_space: Action space (or its size); its dimension is stored in ``self.action_dim``
        :param state_encoder: Module applied to the raw state. ``None`` (the default) creates a fresh
            ``EmptyEncoder``, in which case the state dimension is taken from ``observation_space``
        :param cfg_name: Key of the critic sub-configuration inside ``cfg``
        """
        super().__init__()
        self.cfg = cfg
        self.action_dim = get_space_dim(action_space)

        cfg_dict = omegaconf_to_dict(cfg)
        self.mlp_hidden_dims = cfg_dict[cfg_name]['mlp_hidden_dims']
        self.mlp_activation = activation_func(cfg_dict[cfg_name]['mlp_activation'])

        # State encoder: build the default per instance instead of sharing a single EmptyEncoder object that a
        # default argument would create once at class-definition time.
        self.state_encoder = EmptyEncoder() if state_encoder is None else state_encoder
        if isinstance(self.state_encoder, EmptyEncoder):
            self.state_dim = get_space_dim(observation_space)
        else:
            self.state_dim = self.state_encoder.output_dim
            if isinstance(observation_space, (tuple, list)):
                self.state_dim += get_space_dim(observation_space[1:])

        # Built by subclasses.
        self.backbone_net = None

        # Normalisation statistics. They are registered as non-trainable parameters, so they are included in
        # self.parameters() (and hence copied/averaged by update_parameters). They start as the identity
        # transform (mean 0, std 1).
        self.state_avg = nn.Parameter(torch.zeros((self.state_dim,)), requires_grad=False)
        self.state_std = nn.Parameter(torch.ones((self.state_dim,)), requires_grad=False)
        self.value_avg = nn.Parameter(torch.zeros((1,)), requires_grad=False)
        self.value_std = nn.Parameter(torch.ones((1,)), requires_grad=False)

    def _normalization_stats(self) -> Tuple[nn.Parameter, ...]:
        """Return the normalisation statistics that must never receive gradients."""
        return self.state_avg, self.state_std, self.value_avg, self.value_std

    def state_norm(self, state: Tensor) -> Tensor:
        """Standardise ``state`` with the stored mean and standard deviation."""
        return (state - self.state_avg) / self.state_std  # todo state_norm

    def value_re_norm(self, value: Tensor) -> Tensor:
        """Map a normalised ``value`` back to the original scale (inverse of value standardisation)."""
        return value * self.value_std + self.value_avg  # todo value_norm

    def freeze_parameters(self, freeze: bool = True) -> None:
        """
        Freeze or unfreeze internal parameters

        The normalisation statistics (``state_avg``, ``state_std``, ``value_avg``, ``value_std``) are always
        left frozen, since they are created with ``requires_grad=False`` and are not meant to be trained.

        :param freeze: freeze (True) or unfreeze (False)
        """
        for param in self.parameters():
            param.requires_grad = not freeze
        # Unfreezing must not turn the fixed statistics into trainable weights.
        for stat in self._normalization_stats():
            stat.requires_grad = False

    def update_parameters(self, model: torch.nn.Module, polyak: float = 1) -> None:
        """
        Update internal parameters by hard or soft (polyak averaging) update

        - Hard update: :math:`\\theta = \\theta_{net}`
        - Soft (polyak averaging) update: :math:`\\theta = (1 - \\rho) \\theta + \\rho \\theta_{net}`

        Parameters are paired positionally via ``zip``, so ``model`` is expected to have the same
        parameter layout as ``self``.

        :param model: Model used to update the internal parameters
        :param polyak: Polyak hyperparameter between 0 and 1 (default: ``1``).
            A hard update is performed when its value is 1
        """
        with torch.no_grad():
            # hard update
            if polyak == 1:
                for param, src in zip(self.parameters(), model.parameters()):
                    param.data.copy_(src.data)
            # soft update (use in-place operations to avoid creating new parameters)
            else:
                for param, src in zip(self.parameters(), model.parameters()):
                    param.data.mul_(1 - polyak)
                    param.data.add_(polyak * src.data)
class Critic(BaseCritic):
    """MLP critic that outputs one scalar per input row.

    The (encoded) state is passed through ``backbone_net`` (an MLP whose hidden
    sizes are ``mlp_hidden_dims``) and then through ``value_net``, a linear
    layer with a single output.
    """

    def __init__(self, cfg: DictConfig,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space, List]],
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                 state_encoder: Optional[nn.Module] = None,
                 cfg_name: str = 'critic'):
        """
        :param cfg: Configuration; ``cfg[cfg_name]`` provides ``mlp_hidden_dims`` / ``mlp_activation`` and
            ``cfg.use_init`` toggles the custom layer initialisation
        :param observation_space: Observation space (or its size)
        :param action_space: Action space (or its size)
        :param state_encoder: Module applied to the raw state; ``None`` (the default) creates a fresh
            ``EmptyEncoder`` for this critic
        :param cfg_name: Key of the critic sub-configuration inside ``cfg``
        """
        if state_encoder is None:
            # One encoder per critic, rather than a module shared through a default argument.
            state_encoder = EmptyEncoder()
        super().__init__(cfg, observation_space, action_space, state_encoder, cfg_name)

        self.backbone_net = build_mlp(dims=[self.state_dim, *self.mlp_hidden_dims],
                                      hidden_activation=self.mlp_activation)
        self.value_net = nn.Linear(self.mlp_hidden_dims[-1], 1)

        if self.cfg.use_init:
            init_layers(self.backbone_net, gain=1.0)
            init_layers(self.value_net, gain=1.0)

    def forward(self, state: Tensor, action: Optional[Tensor] = None) -> Tensor:
        """
        Compute the critic value.

        :param state: Raw state; it is first passed through ``self.state_encoder``
        :param action: Optional action, concatenated to the encoded state along dim 1
            (assumes 2-D ``(batch, features)`` tensors -- TODO confirm)
        :return: Output of ``value_net`` (last dimension of size 1)

        NOTE(review): ``backbone_net`` is built with input width ``self.state_dim`` only (no ``action_dim``),
        so passing ``action`` will likely raise a shape mismatch -- confirm whether any caller uses it.
        """
        state = self.state_encoder(state)
        if action is not None:
            state = torch.cat((state, action), dim=1)
        features = self.backbone_net(state)
        return self.value_net(features)