|
|
@@ -30,9 +30,13 @@ from collections import defaultdict |
|
|
|
|
|
|
|
|
|
|
|
class Layer(torch.nn.Module):
|
|
|
|
def __init__(self, output_dim: Union[int, List[int]], **kwargs) -> None:
|
|
|
|
def __init__(self,
|
|
|
|
output_dim: Union[int, List[int]],
|
|
|
|
is_sparse: bool,
|
|
|
|
**kwargs) -> None:
|
|
|
|
super().__init__(**kwargs)
|
|
|
|
self.output_dim = output_dim
|
|
|
|
self.is_sparse = is_sparse
|
|
|
|
|
|
|
|
|
|
|
|
class InputLayer(Layer):
|
|
|
@@ -42,7 +46,7 @@ class InputLayer(Layer): |
|
|
|
if not isinstance(output_dim, list):
|
|
|
|
output_dim = [output_dim,] * len(data.node_types)
|
|
|
|
|
|
|
|
super().__init__(output_dim, **kwargs)
|
|
|
|
super().__init__(output_dim, is_sparse=False, **kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.node_reps = None
|
|
|
|
self.build()
|
|
|
@@ -67,6 +71,34 @@ class InputLayer(Layer): |
|
|
|
return s.strip()
|
|
|
|
|
|
|
|
|
|
|
|
class OneHotInputLayer(Layer):
|
|
|
|
def __init__(self, data: Data, **kwargs) -> None:
|
|
|
|
output_dim = [ a.count for a in data.node_types ]
|
|
|
|
super().__init__(output_dim, is_sparse=True, **kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.node_reps = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
def build(self) -> None:
|
|
|
|
self.node_reps = []
|
|
|
|
for i, nt in enumerate(self.data.node_types):
|
|
|
|
reps = torch.eye(nt.count).to_sparse()
|
|
|
|
reps = torch.nn.Parameter(reps)
|
|
|
|
self.register_parameter('node_reps[%d]' % i, reps)
|
|
|
|
self.node_reps.append(reps)
|
|
|
|
|
|
|
|
def forward(self) -> List[torch.nn.Parameter]:
|
|
|
|
return self.node_reps
|
|
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
|
s = ''
|
|
|
|
s += 'One-hot GNN input layer\n'
|
|
|
|
s += ' # of node types: %d\n' % len(self.data.node_types)
|
|
|
|
for nt in self.data.node_types:
|
|
|
|
s += ' - %s (%d)\n' % (nt.name, nt.count)
|
|
|
|
return s.strip()
|
|
|
|
|
|
|
|
|
|
|
|
class DecagonLayer(Layer):
|
|
|
|
def __init__(self,
|
|
|
|
data: Data,
|
|
|
@@ -78,7 +110,7 @@ class DecagonLayer(Layer): |
|
|
|
**kwargs):
|
|
|
|
if not isinstance(output_dim, list):
|
|
|
|
output_dim = [ output_dim ] * len(data.node_types)
|
|
|
|
super().__init__(output_dim, **kwargs)
|
|
|
|
super().__init__(output_dim, is_sparse=False, **kwargs)
|
|
|
|
self.data = data
|
|
|
|
self.previous_layer = previous_layer
|
|
|
|
self.input_dim = previous_layer.output_dim
|
|
|
@@ -98,6 +130,9 @@ class DecagonLayer(Layer): |
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
|
self.next_layer_repr[nt_row].append((conv, nt_col))
|
|
|
|
|
|
|
|
if nt_row == nt_col:
|
|
|
|
continue
|
|
|
|
|
|
|
|
conv = SparseDropoutGraphConvActivation(self.input_dim[nt_row],
|
|
|
|
self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
|
|
|
|
self.keep_prob, self.rel_activation)
|
|
|
@@ -105,9 +140,16 @@ class DecagonLayer(Layer): |
|
|
|
|
|
|
|
def __call__(self):
|
|
|
|
prev_layer_repr = self.previous_layer()
|
|
|
|
next_layer_repr = self.next_layer_repr
|
|
|
|
next_layer_repr = [None] * len(self.data.node_types)
|
|
|
|
print('next_layer_repr:', next_layer_repr)
|
|
|
|
for i in range(len(self.data.node_types)):
|
|
|
|
next_layer_repr[i] = map(lambda conv, neighbor_type: \
|
|
|
|
conv(prev_layer_repr[neighbor_type]), next_layer_repr[i])
|
|
|
|
next_layer_repr = list(map(sum, next_layer_repr))
|
|
|
|
next_layer_repr[i] = [
|
|
|
|
conv(prev_layer_repr[neighbor_type]) \
|
|
|
|
for (conv, neighbor_type) in \
|
|
|
|
self.next_layer_repr[i]
|
|
|
|
]
|
|
|
|
next_layer_repr[i] = sum(next_layer_repr[i])
|
|
|
|
|
|
|
|
print('next_layer_repr:', next_layer_repr)
|
|
|
|
# next_layer_repr = list(map(sum, next_layer_repr))
|
|
|
|
return next_layer_repr
|