rofunc.learning.RofuncRL.models.misc_models

1. Module Contents

1.1. Classes

ASEDiscEnc

DTrans

1.2. API

class rofunc.learning.RofuncRL.models.misc_models.ASEDiscEnc(cfg: omegaconf.DictConfig, input_dim: int, enc_output_dim: int, disc_output_dim: int, cfg_name: str)

Bases: rofunc.learning.RofuncRL.models.base_models.BaseMLP

forward(x: torch.Tensor) → torch.Tensor
get_enc(x: torch.Tensor) → torch.Tensor
get_disc(x: torch.Tensor) → torch.Tensor
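The methods above are listed without docstrings. As a hedged illustration of the pattern the name and signature suggest (an ASE-style discriminator and skill encoder sharing one MLP trunk, with `enc_output_dim` and `disc_output_dim` sizing two separate output heads), here is a minimal stand-alone PyTorch sketch. The class name `DiscEncSketch`, the hidden sizes, the ReLU activations, the unit-norm encoder output, and `forward()` returning the discriminator output are all assumptions; the sketch does not use rofunc's `BaseMLP` or its `cfg` handling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiscEncSketch(nn.Module):
    """Hypothetical stand-in for ASEDiscEnc: one shared MLP trunk, two linear heads."""

    def __init__(self, input_dim: int, enc_output_dim: int, disc_output_dim: int,
                 hidden_dims=(256, 128)):
        super().__init__()
        layers, last = [], input_dim
        for h in hidden_dims:
            layers += [nn.Linear(last, h), nn.ReLU()]
            last = h
        self.trunk = nn.Sequential(*layers)                # shared feature extractor
        self.enc_head = nn.Linear(last, enc_output_dim)    # skill-latent encoder head
        self.disc_head = nn.Linear(last, disc_output_dim)  # discriminator head

    def get_enc(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: latents are projected onto the unit hypersphere, as in ASE.
        return F.normalize(self.enc_head(self.trunk(x)), dim=-1)

    def get_disc(self, x: torch.Tensor) -> torch.Tensor:
        # Raw (unsquashed) discriminator output.
        return self.disc_head(self.trunk(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: forward() returns the discriminator output.
        return self.get_disc(x)


model = DiscEncSketch(input_dim=10, enc_output_dim=64, disc_output_dim=1)
obs = torch.randn(32, 10)  # a batch of 32 illustrative 10-dim observations
enc = model.get_enc(obs)   # [32, 64], unit-norm rows
disc = model(obs)          # [32, 1]
```

Sharing the trunk means the encoder and discriminator heads are trained on the same features; whether rofunc's class does exactly this should be checked against its source.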
class rofunc.learning.RofuncRL.models.misc_models.DTrans(cfg: omegaconf.DictConfig, observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]], action_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]], state_encoder: Optional[torch.nn.Module] = EmptyEncoder(), cfg_name='actor')

Bases: torch.nn.Module
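`observation_space` and `action_space` accept an int, a tuple of ints, or a gym/gymnasium `Space`. The sketch below shows one plausible way to normalize those alternatives to a shape tuple; `space_to_shape` is a hypothetical helper, not part of rofunc, and it only handles spaces with a non-empty `.shape` (such as `Box`). How `DTrans` itself resolves these arguments, including `None`, is not documented on this page.

```python
import math
from types import SimpleNamespace
from typing import Any, Optional, Tuple


def space_to_shape(space: Any) -> Optional[Tuple[int, ...]]:
    """Hypothetical helper: normalize the accepted space types to a shape tuple."""
    if space is None:
        return None  # the parameter is Optional; its meaning is left to the caller
    if isinstance(space, int):
        return (space,)
    if isinstance(space, tuple):
        return tuple(int(d) for d in space)
    # gym/gymnasium Box spaces expose a non-empty `.shape` tuple.
    shape = getattr(space, "shape", None)
    if not shape:
        # e.g. Discrete spaces have an empty shape and would need `.n` instead
        raise TypeError(f"unsupported space: {space!r}")
    return tuple(int(d) for d in shape)


box_like = SimpleNamespace(shape=(3, 64, 64))  # duck-typed stand-in for a gym Box
flat_dim = math.prod(space_to_shape(box_like))  # size of the flattened observation
```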

forward(states, actions, rewards, returns_to_go, timesteps, attention_mask=None)
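Parameters:
  • states – [B, T, C, H, W] for image observations, or [B, T, Value] for flat state vectors

  • actions – [B, T, Action]

  • rewards – [B, T, 1]

  • returns_to_go – [B, T, 1]

  • timesteps – [B, T]

  • attention_mask – [B, T] (optional, defaults to None)

  Here B is the batch size, T the sequence length, C, H and W the image channels, height and width, and Value and Action the state and action dimensions.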
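To make the documented shapes concrete, the following sketch builds a dummy batch in plain PyTorch without calling `DTrans`. It assumes the Decision Transformer convention that the return-to-go at step t is the sum of rewards from t to the end of the trajectory, that shorter trajectories are left-padded with 0 marking padding in `attention_mask`, and that an image `state_encoder` would be applied frame by frame. The sizes and the small convolutional `frame_encoder` are illustrative assumptions, not rofunc defaults.

```python
import torch

B, T, state_dim, act_dim = 4, 20, 17, 6  # illustrative sizes

# Attention mask: 1 for real steps, 0 for (left) padding.
lengths = torch.tensor([20, 15, 10, 5])
attention_mask = (torch.arange(T).unsqueeze(0) >= (T - lengths).unsqueeze(1)).long()  # [B, T]

states = torch.randn(B, T, state_dim)                           # [B, T, Value]
actions = torch.randn(B, T, act_dim)                            # [B, T, Action]
rewards = torch.randn(B, T, 1) * attention_mask.unsqueeze(-1)   # [B, T, 1], zero on padding

# Returns-to-go R_t = sum_{t' >= t} r_{t'}: reverse, cumulative-sum over time, reverse back.
returns_to_go = torch.flip(torch.cumsum(torch.flip(rewards, dims=[1]), dim=1), dims=[1])  # [B, T, 1]

timesteps = torch.arange(T).unsqueeze(0).expand(B, T)  # [B, T] integer step indices

# Image states [B, T, C, H, W]: fold B and T together so a 2D encoder sees single frames,
# then unfold back to a per-step feature sequence.
images = torch.randn(B, T, 3, 64, 64)
frame_encoder = torch.nn.Sequential(  # hypothetical encoder: 64x64 -> 8 x 16 x 16
    torch.nn.Conv2d(3, 8, kernel_size=4, stride=4),
    torch.nn.Flatten(),
)
features = frame_encoder(images.reshape(B * T, 3, 64, 64)).reshape(B, T, -1)  # [B, T, 2048]
```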

Returns: