# Source code for rofunc.learning.RofuncRL.state_encoders.base_encoders
# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from omegaconf import DictConfig
import rofunc as rf
from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.models.base_models import BaseMLP
class EmptyEncoder(nn.Module):
    """Identity state encoder.

    Holds no parameters; ``forward`` hands back exactly the object it was
    given, so it can stand in wherever an encoder module is expected but no
    encoding is wanted.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself, untouched (no copy, no transformation)."""
        return x
class BaseEncoder(nn.Module):
    """Base class for state encoders configured from an OmegaConf config section.

    The section ``cfg[cfg_name]`` is read for the keys ``inp_channels``,
    ``out_channels``, ``use_pretrained``, ``freeze``, ``model_ckpt`` and the
    optional ``model_module_name``.

    This class does NOT define the two attributes its helpers rely on;
    subclasses are expected to provide them before calling :meth:`set_up`:

    * ``freeze_net_list`` -- iterable of objects exposing ``.parameters()``,
      read by :meth:`freeze_network`.
    * ``checkpoint_modules`` -- mapping ``name -> module (or plain value)``,
      read by :meth:`save_ckpt` and :meth:`load_ckpt`.
    """

    def __init__(self, cfg: "DictConfig", cfg_name: str = 'state_encoder'):
        """
        Args:
            cfg: Full config; converted to a plain dict via ``omegaconf_to_dict``.
            cfg_name: Key of the encoder section inside ``cfg``.

        Raises:
            KeyError: (presumably, for the plain dict produced by
                ``omegaconf_to_dict``) if the section or a required key is missing.
        """
        super().__init__()
        self.cfg = cfg
        self.cfg_dict = omegaconf_to_dict(self.cfg)
        self.cfg_name = cfg_name

        # Look the encoder section up once instead of re-indexing per field.
        encoder_cfg = self.cfg_dict[cfg_name]
        self.input_dim = encoder_cfg['inp_channels']
        self.output_dim = encoder_cfg['out_channels']
        self.use_pretrained = encoder_cfg['use_pretrained']
        self.freeze = encoder_cfg['freeze']
        self.model_ckpt = encoder_cfg['model_ckpt']
        # Optional: selects a sub-entry of the loaded checkpoint in load_ckpt.
        self.model_module_name = encoder_cfg.get('model_module_name')

    def set_up(self):
        """Apply the ``freeze`` and ``use_pretrained`` options from the config.

        Freezing runs first, then pretrained weights are loaded. The called
        helpers read ``freeze_net_list`` / ``checkpoint_modules``, so a
        subclass should call this only once those attributes exist.
        """
        if self.freeze:
            self.freeze_network()
        if self.use_pretrained:
            self.pre_trained_mode()

    def freeze_network(self):
        """Set ``requires_grad = False`` on every parameter of ``self.freeze_net_list``."""
        for net in self.freeze_net_list:
            for param in net.parameters():
                param.requires_grad = False
        rf.logger.beauty_print("Freeze state encoder", type="info")

    def pre_trained_mode(self):
        """Load ``self.model_ckpt`` when ``use_pretrained`` is enabled.

        Does nothing if ``use_pretrained`` is falsy.

        Raises:
            ValueError: if ``use_pretrained`` is enabled but ``model_ckpt`` is None.
        """
        if not self.use_pretrained:
            return
        if self.model_ckpt is None:
            raise ValueError(
                "Cannot load a pretrained state encoder without a checkpoint "
                "(use_pretrained is set but model_ckpt is None)"
            )
        self.load_ckpt(self.model_ckpt)

    def _get_internal_value(self, module):
        """Return ``module.state_dict()`` when available, else ``module`` itself."""
        return module.state_dict() if hasattr(module, "state_dict") else module

    def save_ckpt(self, path: str):
        """Save ``{name: state}`` for every entry of ``self.checkpoint_modules`` to ``path``."""
        modules = {
            name: self._get_internal_value(module)
            for name, module in self.checkpoint_modules.items()
        }
        torch.save(modules, path)

    def load_ckpt(self, path: str, map_location=None):
        """Load encoder weights from the checkpoint at ``path``.

        If ``self.model_module_name`` is set, that entry of the loaded object
        is used. The result must be a plain ``dict`` mapping names to states.
        Each entry is loaded into the module of the same name in
        ``self.checkpoint_modules``, which is then switched to eval mode.
        Entries with no matching local name are ignored.

        Args:
            path: Checkpoint file to read.
            map_location: Forwarded to ``torch.load``. For example, ``"cpu"``
                loads a GPU-saved checkpoint on a CPU-only machine. The default
                ``None`` keeps ``torch.load``'s default behaviour.

        Raises:
            NotImplementedError: if the (selected) checkpoint payload is not a
                plain dict, or if a matching local module has no
                ``load_state_dict``.
        """
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        modules = torch.load(path, map_location=map_location)
        if self.model_module_name is not None:
            modules = modules[self.model_module_name]
        # Exact type check (not isinstance) deliberately kept from the original.
        if type(modules) is not dict:
            raise NotImplementedError(
                f"Unsupported checkpoint payload of type {type(modules).__name__}; "
                "expected a dict mapping module names to states"
            )
        for name, data in modules.items():
            module = self.checkpoint_modules.get(name)
            if module is None:
                continue
            if not hasattr(module, "load_state_dict"):
                raise NotImplementedError(
                    f"Checkpoint module '{name}' does not support load_state_dict"
                )
            module.load_state_dict(data)
            if hasattr(module, "eval"):
                module.eval()
        # Report the file actually read (previously always self.model_ckpt).
        rf.logger.beauty_print(f"Loaded pretrained state encoder model from {path}", type="info")
class MLPEncoder(BaseMLP):
    """State encoder backed by ``BaseMLP``.

    Adds no behaviour of its own: construction is delegated to ``BaseMLP``
    with the same ``cfg`` and ``cfg_name`` arguments.
    """

    def __init__(self, cfg, cfg_name):
        super(MLPEncoder, self).__init__(cfg, cfg_name)