Source code for rofunc.learning.RofuncRL.state_encoders.visual_encoders

# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

from rofunc.learning.RofuncRL.models.utils import activation_func
from rofunc.learning.RofuncRL.models.utils import build_mlp, build_cnn, init_layers
from rofunc.learning.RofuncRL.state_encoders.base_encoders import BaseEncoder


[docs]class CNNEncoder(BaseEncoder): def __init__(self, cfg): super().__init__(cfg) self.cnn_kernel_size = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_kernel_size'] self.cnn_stride = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_stride'] self.cnn_padding = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_padding'] self.cnn_dilation = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_dilation'] self.cnn_hidden_dims = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_hidden_dims'] self.cnn_activation = activation_func(self.cfg_dict[self.cfg_name]['cnn_args']['cnn_activation']) self.cnn_pooling = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_pooling'] self.cnn_pooling_args = self.cfg_dict[self.cfg_name]['cnn_args']['cnn_pooling_args'] self.mlp_inp_dims = self.cfg_dict[self.cfg_name]['cnn_args']['mlp_inp_dims'] self.mlp_hidden_dims = self.cfg_dict[self.cfg_name]['cnn_args']['mlp_hidden_dims'] self.mlp_activation = activation_func(self.cfg_dict[self.cfg_name]['cnn_args']['mlp_activation']) self.backbone_net = build_cnn(dims=[self.input_dim, *self.cnn_hidden_dims], kernel_size=self.cnn_kernel_size, stride=self.cnn_stride, padding=self.cnn_padding, dilation=self.cnn_dilation, hidden_activation=self.cnn_activation, pooling=self.cnn_pooling, pooling_args=self.cnn_pooling_args) self.flatten = nn.Flatten(start_dim=1) self.output_net = build_mlp(dims=[self.mlp_inp_dims, *self.mlp_hidden_dims, self.output_dim], hidden_activation=self.mlp_activation) self.checkpoint_modules = {'backbone_net': self.backbone_net, 'output_net': self.output_net} if self.cfg.use_init: init_layers(self.backbone_net, gain=1.0, init_type='kaiming_uniform') init_layers(self.output_net, gain=1.0) self.freeze_net_list = [self.backbone_net, self.output_net] self.set_up()
[docs] def forward(self, inputs): x = self.backbone_net(inputs) x = self.flatten(x) x = self.output_net(x) return x
[docs]class ResnetEncoder(BaseEncoder): def __init__(self, cfg): super().__init__(cfg) self.sub_type = self.cfg_dict[self.cfg_name]['resnet_args']['sub_type'] assert self.sub_type in ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] self.use_pretrained = self.cfg_dict[self.cfg_name]['use_pretrained'] self.backbone_net = torch.hub.load('pytorch/vision', self.sub_type, num_classes=self.output_dim, pretrained=self.use_pretrained) self.freeze_net_list = [self.backbone_net] self.set_up()
[docs] def forward(self, inputs): x = self.backbone_net(inputs) return x
[docs]class ViTEncoder(BaseEncoder): def __init__(self, cfg): super().__init__(cfg) self.sub_type = self.cfg_dict[self.cfg_name]['vit_args']['sub_type'] assert self.sub_type in ['vit_b_16', 'vit_b_32', 'vit_h_14', 'vit_l_16', 'vit_l_32'] self.use_pretrained = self.cfg_dict[self.cfg_name]['use_pretrained'] self.backbone_net = torch.hub.load('pytorch/vision', self.sub_type, num_classes=self.output_dim, pretrained=self.use_pretrained) self.freeze_net_list = [self.backbone_net] self.set_up()
[docs] def forward(self, inputs): x = self.backbone_net(inputs) return x
if __name__ == '__main__': from omegaconf import DictConfig import rofunc as rf cfg = DictConfig({'use_init': True, 'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'use_pretrained': True, 'freeze': False, 'model_ckpt': 'test.ckpt', 'cnn_args': { 'cnn_structure': ['conv', 'relu', 'conv', 'relu', 'pool'], 'cnn_kernel_size': [8, 4], 'cnn_stride': 1, 'cnn_padding': 1, 'cnn_dilation': 1, 'cnn_hidden_dims': [32, 64], 'cnn_activation': 'relu', 'cnn_pooling': None, # ['max', 'avg'] 'cnn_pooling_args': { 'cnn_pooling_kernel_size': 2, 'cnn_pooling_stride': 2, 'cnn_pooling_padding': 0, 'cnn_pooling_dilation': 1}, 'mlp_inp_dims': 215296, 'mlp_hidden_dims': [512], 'mlp_activation': 'relu', }}}) model = CNNEncoder(cfg=cfg).to('cuda:0') print(model) # cfg = DictConfig( # {'use_init': True, # 'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'resnet_args': {'sub_type': 'resnet18'}, # 'use_pretrained': False}}) # model = ResnetEncoder(cfg=cfg) # print(model) # cfg = DictConfig({'use_init': True, # 'state_encoder': {'inp_channels': 4, 'out_channels': 512, 'vit_args': {'sub_type': 'vit_b_16'}, # 'use_pretrained': False}}) # model = ViTEncoder(cfg=cfg) # print(model) inp_latent_vector = torch.randn(32, 4, 64, 64).to('cuda:0') # [B, C, H, W] gt_latent_vector = torch.randn(32, 512).to('cuda:0') optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) for i in range(10000): out_latent_vector = model(inp_latent_vector) # predicted_latent_vector = out_latent_vector.last_hidden_state[:, -1:, :] # print(out_latent_vector) loss = nn.MSELoss()(out_latent_vector, gt_latent_vector) optimizer.zero_grad() loss.backward() optimizer.step() if i % 100 == 0: model.save_ckpt('test.ckpt') rf.logger.beauty_print('Save ckpt') print(loss)