
Adjust DecagonLayer to normalize per edge type.

master
Stanislaw Adaszewski 4 years ago
parent
commit
52be4061b3
2 changed files with 37 additions and 23 deletions
  1. +20 -11  src/decagon_pytorch/layer.py
  2. +17 -12  tests/decagon_pytorch/test_layer.py

src/decagon_pytorch/layer.py  +20 -11

@@ -124,11 +124,14 @@ class DecagonLayer(Layer):
         self.next_layer_repr = defaultdict(list)
         for (nt_row, nt_col), relation_types in self.data.relation_types.items():
+            row_convs = []
+            col_convs = []
+
             for rel in relation_types:
                 conv = DropoutGraphConvActivation(self.input_dim[nt_col],
                     self.output_dim[nt_row], rel.adjacency_matrix,
                     self.keep_prob, self.rel_activation)
-                self.next_layer_repr[nt_row].append((conv, nt_col))
+                row_convs.append(conv)
 
                 if nt_row == nt_col:
                     continue
@@ -136,21 +139,27 @@ class DecagonLayer(Layer):
                 conv = DropoutGraphConvActivation(self.input_dim[nt_row],
                     self.output_dim[nt_col], rel.adjacency_matrix.transpose(0, 1),
                     self.keep_prob, self.rel_activation)
-                self.next_layer_repr[nt_col].append((conv, nt_row))
+                col_convs.append(conv)
+
+            self.next_layer_repr[nt_row].append((row_convs, nt_col))
+            if nt_row == nt_col:
+                continue
+            self.next_layer_repr[nt_col].append((col_convs, nt_row))
 
     def __call__(self):
         prev_layer_repr = self.previous_layer()
-        next_layer_repr = [None] * len(self.data.node_types)
+        next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]
         print('next_layer_repr:', next_layer_repr)
         for i in range(len(self.data.node_types)):
-            next_layer_repr[i] = [
-                conv(prev_layer_repr[neighbor_type]) \
-                    for (conv, neighbor_type) in \
-                        self.next_layer_repr[i]
-            ]
-            next_layer_repr[i] = sum(next_layer_repr[i])
-            next_layer_repr[i] = torch.nn.functional.normalize(next_layer_repr[i], p=2, dim=1)
+            for convs, neighbor_type in self.next_layer_repr[i]:
+                convs = [ conv(prev_layer_repr[neighbor_type]) \
+                    for conv in convs ]
+                convs = sum(convs)
+                convs = torch.nn.functional.normalize(convs, p=2, dim=1)
+                next_layer_repr[i].append(convs)
+            next_layer_repr[i] = sum(next_layer_repr[i])
             next_layer_repr[i] = self.layer_activation(next_layer_repr[i])
         print('next_layer_repr:', next_layer_repr)
         # next_layer_repr = list(map(sum, next_layer_repr))
         return next_layer_repr
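
Read in isolation, the change to __call__ does the following for every target node type: sum the convolution outputs within each (row, column) node-type group, L2-normalize that partial sum, and only then combine the groups and apply the layer activation. Below is a minimal, self-contained sketch of that aggregation pattern; the function name aggregate_per_edge_type, the groups argument, and the tensor shapes are illustrative assumptions, not part of the repository's code.

import torch

def aggregate_per_edge_type(groups, layer_activation=torch.relu):
    # groups: list of lists of tensors; each inner list holds the outputs of
    # all convolutions that share one (row node type, column node type) pair.
    per_type = []
    for conv_outputs in groups:
        h = sum(conv_outputs)                             # sum within one edge type
        h = torch.nn.functional.normalize(h, p=2, dim=1)  # normalize per edge type
        per_type.append(h)
    out = sum(per_type)                                   # combine across edge types
    return layer_activation(out)

# Tiny usage example: two edge types feeding one target node type.
g1 = [torch.randn(5, 32), torch.randn(5, 32)]   # two relations of the first edge type
g2 = [torch.randn(5, 32)]                       # one relation of the second edge type
print(aggregate_per_edge_type([g1, g2]).shape)  # torch.Size([5, 32])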

tests/decagon_pytorch/test_layer.py  +17 -12

@@ -136,16 +136,21 @@ def test_decagon_layer_03():
     x = torch.tensor([-1, 0, 0.5, 1])
     assert (d_layer.layer_activation(x) == torch.nn.functional.relu(x)).all()
 
     assert len(d_layer.next_layer_repr) == 2
-    assert len(d_layer.next_layer_repr[0]) == 2
-    assert len(d_layer.next_layer_repr[1]) == 4
-    assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
-        d_layer.next_layer_repr[0]))
-    assert all(map(lambda a: isinstance(a[0], DropoutGraphConvActivation),
-        d_layer.next_layer_repr[1]))
-    assert all(map(lambda a: a[0].output_dim == 32,
-        d_layer.next_layer_repr[0]))
-    assert all(map(lambda a: a[0].output_dim == 32,
-        d_layer.next_layer_repr[1]))
+    for i in range(2):
+        assert len(d_layer.next_layer_repr[i]) == 2
+        assert isinstance(d_layer.next_layer_repr[i], list)
+        assert isinstance(d_layer.next_layer_repr[i][0], tuple)
+        assert isinstance(d_layer.next_layer_repr[i][0][0], list)
+        assert isinstance(d_layer.next_layer_repr[i][0][1], int)
+        assert all([
+            isinstance(dgca, DropoutGraphConvActivation) \
+                for dgca in d_layer.next_layer_repr[i][0][0]
+        ])
+        assert all([
+            dgca.output_dim == 32 \
+                for dgca in d_layer.next_layer_repr[i][0][0]
+        ])
 
 
 def test_decagon_layer_04():
@@ -166,10 +171,10 @@ def test_decagon_layer_04():
         keep_prob=1., rel_activation=lambda x: x,
         layer_activation=lambda x: x)
 
-    assert isinstance(d_layer.next_layer_repr[0][0][0],
+    assert isinstance(d_layer.next_layer_repr[0][0][0][0],
         DropoutGraphConvActivation)
 
-    weight = d_layer.next_layer_repr[0][0][0].graph_conv.weight
+    weight = d_layer.next_layer_repr[0][0][0][0].graph_conv.weight
     assert isinstance(weight, torch.Tensor)
 
     assert len(multi_dgca.sparse_dgca) == 1
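
The rewritten assertions reflect that next_layer_repr now nests one level deeper: each entry of next_layer_repr[i] is a (list_of_convolutions, neighbor_type) tuple instead of a (convolution, neighbor_type) tuple, so reaching an individual convolution takes one extra index. A small illustration of the new shape, using a hypothetical FakeConv stand-in rather than a real DropoutGraphConvActivation:

class FakeConv:            # hypothetical placeholder for DropoutGraphConvActivation
    output_dim = 32

# One list per target node type; each entry pairs a *list* of convolutions
# with the index of the neighbor node type they consume.
next_layer_repr = [
    [([FakeConv(), FakeConv()], 1)],   # node type 0: two relations toward neighbor type 1
    [([FakeConv()], 0)],               # node type 1: one relation toward neighbor type 0
]

convs, neighbor_type = next_layer_repr[0][0]
first_conv = convs[0]                  # was next_layer_repr[0][0][0]; now next_layer_repr[0][0][0][0]
assert first_conv.output_dim == 32 and neighbor_type == 1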

