Source code for rofunc.learning.RofuncRL.state_encoders.base_encoders

#  Copyright (C) 2024, Junjia Liu
# 
#  This file is part of Rofunc.
# 
#  Rofunc is licensed under the GNU General Public License v3.0.
#  You may use, distribute, and modify this code under the terms of the GPL-3.0.
# 
#  Additional Terms for Commercial Use:
#  Commercial use requires sharing 50% of net profits with the copyright holder.
#  Financial reports and regular payments must be provided as agreed in writing.
#  Non-compliance results in revocation of commercial rights.
# 
#  For more details, see <https://www.gnu.org/licenses/>.
#  Contact: skylark0924@gmail.com
import torch
import torch.nn as nn
from omegaconf import DictConfig

import rofunc as rf
from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.models.base_models import BaseMLP


class EmptyEncoder(nn.Module):
    """Identity state encoder.

    Holds no parameters and returns whatever it is given, unchanged.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (the very same object, not a copy)."""
        passthrough = x
        return passthrough
class BaseEncoder(nn.Module):
    """Base class for configurable state encoders.

    Settings are read from the ``cfg[cfg_name]`` section of the config:
    ``inp_channels``, ``out_channels``, ``use_pretrained``, ``freeze``,
    ``model_ckpt`` and the optional ``model_module_name``.

    This base class does not define the two attributes below, so subclasses
    must provide them before the corresponding features are used:

    * ``self.freeze_net_list``: iterable of objects exposing ``parameters()``.
      It is read by :meth:`freeze_network`.
    * ``self.checkpoint_modules``: mapping ``name -> module/object``. It is
      read by :meth:`save_ckpt` and :meth:`load_ckpt`.
    """

    def __init__(self, cfg: DictConfig, cfg_name: str = 'state_encoder'):
        super(BaseEncoder, self).__init__()
        self.cfg = cfg
        # NOTE(review): assumes omegaconf_to_dict returns a plain dict.
        self.cfg_dict = omegaconf_to_dict(self.cfg)
        self.cfg_name = cfg_name

        encoder_cfg = self.cfg_dict[cfg_name]
        self.input_dim = encoder_cfg['inp_channels']
        self.output_dim = encoder_cfg['out_channels']
        self.use_pretrained = encoder_cfg['use_pretrained']
        self.freeze = encoder_cfg['freeze']
        self.model_ckpt = encoder_cfg['model_ckpt']
        # Optional: key selecting a sub-entry of the loaded checkpoint.
        self.model_module_name = encoder_cfg.get('model_module_name', None)

    def set_up(self):
        """Apply the configured freezing, then load pretrained weights if requested."""
        if self.freeze:
            self.freeze_network()
        if self.use_pretrained:
            self.pre_trained_mode()

    def freeze_network(self):
        """Set ``requires_grad = False`` on every parameter of ``self.freeze_net_list``."""
        for net in self.freeze_net_list:
            for param in net.parameters():
                param.requires_grad = False
        rf.logger.beauty_print("Freeze state encoder", type="info")

    def pre_trained_mode(self):
        """Load the checkpoint at ``self.model_ckpt`` when ``use_pretrained`` is set.

        Raises:
            ValueError: if a pretrained model is requested but ``model_ckpt`` is None.
        """
        if not self.use_pretrained:
            return
        if self.model_ckpt is None:
            raise ValueError("Cannot use a pretrained encoder without a checkpoint (model_ckpt is None)")
        self.load_ckpt(self.model_ckpt)

    def _get_internal_value(self, module):
        """Return ``module.state_dict()`` when available, otherwise ``module`` itself."""
        return module.state_dict() if hasattr(module, "state_dict") else module

    def save_ckpt(self, path: str):
        """Save ``{name: state}`` for every entry of ``self.checkpoint_modules`` to ``path``."""
        modules = {name: self._get_internal_value(module)
                   for name, module in self.checkpoint_modules.items()}
        torch.save(modules, path)

    def load_ckpt(self, path: str):
        """Load module states from the checkpoint at ``path``.

        If ``model_module_name`` is set, only that entry of the checkpoint is
        used. Each ``name -> data`` pair whose name is in
        ``self.checkpoint_modules`` is loaded via ``load_state_dict``, and the
        module is then switched to ``eval()`` mode. Names without a matching
        module are skipped.

        Raises:
            NotImplementedError: if a matched module has no ``load_state_dict``.
        """
        import warnings

        # torch.load unpickles the file: only load checkpoints from trusted sources.
        modules = torch.load(path)
        if self.model_module_name is not None:
            modules = modules[self.model_module_name]

        # isinstance also accepts OrderedDict (e.g. a raw state_dict).
        if not isinstance(modules, dict):
            warnings.warn(f"Checkpoint {path} is not a dict of module states "
                          f"(got {type(modules).__name__}); nothing was loaded")
            return

        for name, data in modules.items():
            module = self.checkpoint_modules.get(name, None)
            if module is None:
                continue
            if not hasattr(module, "load_state_dict"):
                raise NotImplementedError
            module.load_state_dict(data)
            if hasattr(module, "eval"):
                module.eval()
        rf.logger.beauty_print(f"Loaded pretrained state encoder model from {path}", type="info")
class MLPEncoder(BaseMLP):
    """State encoder implemented by :class:`BaseMLP`.

    Adds no behavior of its own; construction is forwarded unchanged to
    ``BaseMLP``.
    """

    def __init__(self, cfg, cfg_name):
        super(MLPEncoder, self).__init__(cfg, cfg_name)