|
@@ -133,9 +133,10 @@ class Model(torch.nn.Module): |
|
|
List[torch.Tensor]:
|
|
|
List[torch.Tensor]:
|
|
|
|
|
|
|
|
|
cur_layer_repr = in_layer_repr
|
|
|
cur_layer_repr = in_layer_repr
|
|
|
next_layer_repr = [ None ] * len(self.data.vertex_types)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(len(self.layer_dimensions) - 1):
|
|
|
for i in range(len(self.layer_dimensions) - 1):
|
|
|
|
|
|
next_layer_repr = [ [] for _ in range(len(self.data.vertex_types)) ]
|
|
|
|
|
|
|
|
|
for _, et in self.data.edge_types.items():
|
|
|
for _, et in self.data.edge_types.items():
|
|
|
vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
|
|
|
vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
|
|
|
adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
|
|
|
adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
|
|
@@ -158,13 +159,18 @@ class Model(torch.nn.Module): |
|
|
self.data.vertex_types[vt_row].count,
|
|
|
self.data.vertex_types[vt_row].count,
|
|
|
self.layer_dimensions[i + 1])
|
|
|
self.layer_dimensions[i + 1])
|
|
|
|
|
|
|
|
|
print('b, Layer:', i, 'x.shape:', x.shape)
|
|
|
|
|
|
|
|
|
print('b, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
|
|
|
|
|
|
|
|
|
x = x.sum(dim=0)
|
|
|
x = x.sum(dim=0)
|
|
|
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
|
|
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
|
|
x = self.conv_activation(x)
|
|
|
|
|
|
|
|
|
# x = self.rel_activation(x)
|
|
|
|
|
|
print('c, Layer:', i, 'vt_row:', vt_row, 'x.shape:', x.shape)
|
|
|
|
|
|
|
|
|
|
|
|
next_layer_repr[vt_row].append(x)
|
|
|
|
|
|
|
|
|
|
|
|
next_layer_repr = [ self.conv_activation(sum(x)) \
|
|
|
|
|
|
for x in next_layer_repr ]
|
|
|
|
|
|
|
|
|
next_layer_repr[vt_row] = x
|
|
|
|
|
|
cur_layer_repr = next_layer_repr
|
|
|
cur_layer_repr = next_layer_repr
|
|
|
return next_layer_repr
|
|
|
return next_layer_repr
|
|
|
|
|
|
|
|
|