# Source code for rofunc.learning.RofuncRL.state_encoders.graph_encoders

#  Copyright (C) 2024, Junjia Liu
# 
#  This file is part of Rofunc.
# 
#  Rofunc is licensed under the GNU General Public License v3.0.
#  You may use, distribute, and modify this code under the terms of the GPL-3.0.
# 
#  Additional Terms for Commercial Use:
#  Commercial use requires sharing 50% of net profits with the copyright holder.
#  Financial reports and regular payments must be provided as agreed in writing.
#  Non-compliance results in revocation of commercial rights.
# 
#  For more details, see <https://www.gnu.org/licenses/>.
#  Contact: skylark0924@gmail.com

import torch.nn as nn

from .base_encoders import BaseEncoder


class HomoGraphEncoder(BaseEncoder):
    """
    Encode a homogeneous DGL graph into one feature vector per graph.

    Two graph-attention layers (``dgl.nn.pytorch.GATConv``) are applied to the
    node features, the resulting node embeddings are averaged per graph with
    ``dgl.mean_nodes``, and the average is passed through a ``Linear`` + ``ReLU``
    head.

    :param in_dim: size of each input node feature
    :param hidden_dim: size of the hidden node embeddings and of the output;
        also forwarded to ``BaseEncoder.__init__``
    :param num_heads: number of attention heads of the first GAT layer
        (defaults to 3, the value that used to be hard-coded)
    """

    def __init__(self, in_dim, hidden_dim, num_heads=3):
        # Kept as a local import, as in the original code -- presumably so that
        # dgl stays an optional dependency of the package (confirm).
        import dgl.nn.pytorch as dglnn

        super().__init__(hidden_dim)

        # First layer: multi-head attention; its heads are concatenated in
        # ``forward``, hence the ``hidden_dim * num_heads`` input of layer two.
        self.conv1 = dglnn.GATConv(in_dim, hidden_dim, num_heads=num_heads)
        # Second layer: a single head that brings the width back to hidden_dim.
        self.conv2 = dglnn.GATConv(hidden_dim * num_heads, hidden_dim, 1)
        self.linear = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())

    def forward(self, g, inputs):
        """
        Compute the graph-level embedding of ``g``.

        :param g: DGL graph (or batched graph) the attention layers run on
        :param inputs: node features fed to the first GAT layer
            (presumably shaped ``(num_nodes, in_dim)`` -- confirm against callers)
        :return: the linear head applied to the per-graph mean of the node
            embeddings
        """
        import dgl
        import torch.nn.functional as F

        # Graph attention followed by ELU; the per-head outputs
        # (nodes, heads, hidden_dim) are concatenated along the feature axis.
        h = self.conv1(g, inputs)
        h = h.flatten(start_dim=1)
        h = F.elu(h)

        h = self.conv2(g, h)
        # Drop only the (size-1) head axis. A bare ``squeeze()`` would also
        # remove the node axis of a single-node graph or a feature axis of
        # size 1, breaking the ``ndata`` assignment / linear head below.
        h = h.squeeze(1)

        # local_scope keeps the temporary 'h' node field from leaking into g.
        with g.local_scope():
            g.ndata['h'] = h
            # Mean readout: average node embeddings into one vector per graph.
            hg = dgl.mean_nodes(g, 'h')

        return self.linear(hg)


class HeteroGraphEncoder(BaseEncoder):
    """
    Encoder stub for heterogeneous graphs.

    As written, the class adds no layers and no ``forward`` of its own; its
    constructor only runs the ``BaseEncoder`` constructor without arguments.

    NOTE(review): ``HomoGraphEncoder`` passes ``hidden_dim`` to the same base
    constructor -- confirm that a no-argument call is actually accepted.
    """

    def __init__(self):
        super().__init__()