# Source code for rofunc.learning.RofuncRL.models.misc_models

# Copyright 2023, Junjia LIU, jjliu@mae.cuhk.edu.hk
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union, Tuple, Optional

import gym
import gymnasium
import torch
import torch.nn as nn
import transformers
from omegaconf import DictConfig
from torch import Tensor

from rofunc.config.utils import omegaconf_to_dict
from rofunc.learning.RofuncRL.models.base_models import BaseMLP
from rofunc.learning.RofuncRL.models.utils import init_layers, get_space_dim
from rofunc.learning.RofuncRL.state_encoders.base_encoders import EmptyEncoder


class ASEDiscEnc(BaseMLP):
    """
    MLP with one shared trunk and two output heads.

    The discriminator head is the ``output_net`` inherited from ``BaseMLP``
    (aliased as ``disc_layer``); the encoder head is an extra linear layer
    (``encoder_layer``) attached to the same ``backbone_net`` features.
    """

    def __init__(self, cfg: DictConfig, input_dim: int, enc_output_dim: int, disc_output_dim: int, cfg_name: str):
        """
        :param cfg: configuration forwarded to ``BaseMLP``; ``cfg.use_init`` is read here
        :param input_dim: input feature size forwarded to ``BaseMLP``
        :param enc_output_dim: output size of the encoder head
        :param disc_output_dim: output size of the discriminator head (passed to ``BaseMLP`` as its output dim)
        :param cfg_name: name of the config section, forwarded to ``BaseMLP``
        """
        super().__init__(cfg, input_dim, disc_output_dim, cfg_name)

        # The encoder head consumes the trunk output, so its input width is the
        # last entry of the hidden-layer sizes (assumes mlp_hidden_dims is set by BaseMLP).
        trunk_width = self.mlp_hidden_dims[-1]
        self.encoder_layer = nn.Linear(trunk_width, enc_output_dim)

        # The discriminator head is not a new layer: it reuses the parent's output network.
        self.disc_layer = self.output_net

        if self.cfg.use_init:
            init_layers(self.encoder_layer, gain=1.0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Delegate to ``BaseMLP.forward``.

        Per the original author's note this is the same as ``self.get_disc(x)``.

        :param x: input tensor
        :return: output of the parent forward pass
        """
        return super().forward(x)

    def get_enc(self, x: Tensor) -> Tensor:
        """
        Run the shared trunk followed by the encoder head.

        :param x: input tensor
        :return: encoder head output
        """
        trunk_features = self.backbone_net(x)
        return self.encoder_layer(trunk_features)

    def get_disc(self, x: Tensor) -> Tensor:
        """
        Run the shared trunk followed by the discriminator head.

        :param x: input tensor
        :return: discriminator head output
        """
        trunk_features = self.backbone_net(x)
        return self.disc_layer(trunk_features)
class DTrans(nn.Module):
    """
    GPT-2 based sequence model over interleaved (return-to-go, state, action) tokens.

    Each modality is embedded by its own linear layer, a learned timestep
    embedding is added, and the three token streams are interleaved as
    ``(R_1, s_1, a_1, R_2, s_2, a_2, ...)`` before being fed to a
    ``transformers.GPT2Model``. Linear heads then predict the state, action
    and return from the transformer outputs.
    """

    def __init__(self, cfg: DictConfig,
                 observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                 action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]],
                 state_encoder: Optional[nn.Module] = None,
                 cfg_name='actor'):
        """
        :param cfg: configuration; the section ``cfg[cfg_name]`` must provide ``n_embd``,
            ``max_episode_steps``, ``n_layer``, ``n_head``, ``activation_function`` and ``dropout``;
            ``cfg.use_action_out_tanh`` is read from the top level
        :param observation_space: observation space, used for the state dimension when no
            real state encoder is supplied
        :param action_space: action space, used for the action dimension
        :param state_encoder: optional module applied to the states before embedding; it must
            expose ``output_dim``. ``None`` (default) creates a fresh ``EmptyEncoder``
        :param cfg_name: name of the config section to read the hyper-parameters from
        """
        super().__init__()

        self.cfg = cfg
        self.cfg_dict = omegaconf_to_dict(self.cfg)
        self.action_dim = get_space_dim(action_space)
        self.n_embd = self.cfg_dict[cfg_name]["n_embd"]
        self.max_ep_len = self.cfg_dict[cfg_name]["max_episode_steps"]

        # state encoder
        # Build the default encoder per instance instead of in the signature, where a single
        # module object would be created at definition time and shared by every DTrans.
        self.state_encoder = EmptyEncoder() if state_encoder is None else state_encoder
        if isinstance(self.state_encoder, EmptyEncoder):
            self.state_dim = get_space_dim(observation_space)
        else:
            self.state_dim = self.state_encoder.output_dim

        gpt_config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=self.n_embd,
            n_layer=self.cfg_dict[cfg_name]["n_layer"],
            n_head=self.cfg_dict[cfg_name]["n_head"],
            n_inner=self.n_embd * 4,
            activation_function=self.cfg_dict[cfg_name]["activation_function"],
            resid_pdrop=self.cfg_dict[cfg_name]["dropout"],
            attn_pdrop=self.cfg_dict[cfg_name]["dropout"],
            n_positions=1024
        )

        # timesteps index this table directly, so they must lie in [0, max_ep_len)
        self.embed_timestep = nn.Embedding(self.max_ep_len, self.n_embd)
        self.embed_return = torch.nn.Linear(1, self.n_embd)
        self.embed_state = torch.nn.Linear(self.state_dim, self.n_embd)
        self.embed_action = torch.nn.Linear(self.action_dim, self.n_embd)
        self.embed_ln = nn.LayerNorm(self.n_embd)

        self.backbone_net = transformers.GPT2Model(gpt_config)

        # note: we don't predict states or returns for the paper
        self.predict_state = torch.nn.Linear(self.n_embd, self.state_dim)
        # NOTE(review): use_action_out_tanh is read from the top-level cfg, not from
        # cfg[cfg_name] like the other hyper-parameters -- confirm this is intended.
        self.predict_action = nn.Sequential(
            *([nn.Linear(self.n_embd, self.action_dim)] + ([nn.Tanh()] if self.cfg.use_action_out_tanh else []))
        )
        self.predict_return = torch.nn.Linear(self.n_embd, 1)

    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
        """
        :param states: [B, T, C, H, W] or [B, T, Value]
        :param actions: [B, T, Action]
        :param rewards: [B, T, 1] (accepted for interface compatibility; not used in the computation)
        :param returns_to_go: [B, T, 1]
        :param timesteps: [B, T]
        :param attention_mask: [B, T], 1 where a step may be attended to and 0 otherwise;
            defaults to all ones, created on the same device as ``states``
        :return: tuple ``(state_preds, action_preds, return_preds)`` with shapes
            [B, T, state_dim], [B, T, action_dim] and [B, T, 1]
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not
            # Created on the inputs' device; a CPU mask would not match GPU embeddings.
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=states.device)

        # state encoder
        states = self.state_encoder(states)

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)  # [B, T, n_embd]
        action_embeddings = self.embed_action(actions)  # [B, T, n_embd]
        returns_embeddings = self.embed_return(returns_to_go)  # [B, T, n_embd]
        time_embeddings = self.embed_timestep(timesteps)  # [B, T, n_embd]

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        returns_embeddings = returns_embeddings + time_embeddings

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        # [B, 3, T, n_embd] -> [B, T, 3, n_embd] -> [B, T*3, n_embd]
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3 * seq_length, self.n_embd)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.backbone_net(inputs_embeds=stacked_inputs,
                                                attention_mask=stacked_attention_mask)
        x = transformer_outputs['last_hidden_state']

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.n_embd).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
        state_preds = self.predict_state(x[:, 2])  # predict next state given state and action
        action_preds = self.predict_action(x[:, 1])  # predict next action given state

        return state_preds, action_preds, return_preds