IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Get test_decode_layer_01 to pass.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
e8122c3321
共有 3 个文件被更改,包括 99 次插入53 次删除
  1. +54
    -36
      src/icosagon/declayer.py
  2. +20
    -10
      src/icosagon/trainprep.py
  3. +25
    -7
      tests/icosagon/test_declayer.py

+ 54
- 36
src/icosagon/declayer.py 查看文件

@@ -15,6 +15,25 @@ from typing import Type, \
Dict, \
Tuple
from .decode import DEDICOMDecoder
from dataclasses import dataclass
@dataclass
class RelationPredictions(object):
edges_pos: TrainValTest
edges_neg: TrainValTest
edges_back_pos: TrainValTest
edges_back_neg: TrainValTest
@dataclass
class RelationFamilyPredictions(object):
relation_types: List[RelationPredictions]
@dataclass
class Predictions(object):
relation_families: List[RelationFamilyPredictions]
class DecodeLayer(torch.nn.Module):
@@ -30,13 +49,16 @@ class DecodeLayer(torch.nn.Module):
if not isinstance(input_dim, list):
raise TypeError('input_dim must be a List')
if len(input_dim) != len(data.node_types):
raise ValueError('input_dim must have length equal to num_node_types')
if not all([ a == input_dim[0] for a in input_dim ]):
raise ValueError('All elements of input_dim must have the same value')
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')
self.input_dim = input_dim
self.input_dim = input_dim[0]
self.output_dim = 1
self.data = data
self.keep_prob = keep_prob
@@ -47,42 +69,38 @@ class DecodeLayer(torch.nn.Module):
def build(self) -> None:
self.decoders = []
for fam in self.data.relation_families:
for (node_type_row, node_type_column), rels in fam.relation_types.items():
for r in rels:
pass
dec = fam.decoder_class()
dec = fam.decoder_class(self.input_dim, len(fam.relation_types),
self.keep_prob, self.activation)
self.decoders.append(dec)
for (node_type_row, node_type_column), rels in self.data.relation_types.items():
if len(rels) == 0:
continue
if isinstance(self.decoder_class, dict):
if (node_type_row, node_type_column) in self.decoder_class:
decoder_class = self.decoder_class[node_type_row, node_type_column]
elif (node_type_column, node_type_row) in self.decoder_class:
decoder_class = self.decoder_class[node_type_column, node_type_row]
else:
raise KeyError('Decoder not specified for edge type: %s -- %s' % (
self.data.node_types[node_type_row].name,
self.data.node_types[node_type_column].name))
else:
decoder_class = self.decoder_class
self.decoders[node_type_row, node_type_column] = \
decoder_class(self.input_dim[node_type_row],
num_relation_types = len(rels),
keep_prob = self.keep_prob,
activation = self.activation)
def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]:
res = {}
for (node_type_row, node_type_column), dec in self.decoders.items():
inputs_row = last_layer_repr[node_type_row]
inputs_column = last_layer_repr[node_type_column]
pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ]
res[node_type_row, node_type_column] = pred_adj_matrices
def _get_tvt(self, r, edge_list_attr_names, row, column, k, last_layer_repr, dec):
pred = []
for p in edge_list_attr_names:
tvt = []
for t in ['train', 'val', 'test']:
# print('r:', r)
edges = getattr(getattr(r, p), t)
inputs_row = last_layer_repr[row][edges[:, 0]]
inputs_column = last_layer_repr[column][edges[:, 1]]
tvt.append(dec(inputs_row, inputs_column, k))
tvt = TrainValTest(*tvt)
pred.append(tvt)
return pred
def forward(self, last_layer_repr: List[torch.Tensor]) -> List[List[torch.Tensor]]:
res = []
for i, fam in enumerate(self.data.relation_families):
fam_pred = []
for k, r in enumerate(fam.relation_types):
pred = []
pred += self._get_tvt(r, ['edges_pos', 'edges_neg'],
r.node_type_row, r.node_type_column, k, last_layer_repr, self.decoders[i])
pred += self._get_tvt(r, ['edges_back_pos', 'edges_back_neg'],
r.node_type_column, r.node_type_row, k, last_layer_repr, self.decoders[i])
pred = RelationPredictions(*pred)
fam_pred.append(pred)
fam_pred = RelationFamilyPredictions(fam_pred)
res.append(fam_pred)
res = Predictions(res)
return res

+ 20
- 10
src/icosagon/trainprep.py 查看文件

@@ -35,6 +35,8 @@ class TrainValTest(object):
class PreparedRelationType(RelationTypeBase):
edges_pos: TrainValTest
edges_neg: TrainValTest
edges_back_pos: TrainValTest
edges_back_neg: TrainValTest
@dataclass
@@ -48,6 +50,10 @@ class PreparedData(object):
relation_families: List[PreparedRelationFamily]
def _empty_edge_list_tvt() -> TrainValTest:
return TrainValTest(*[ torch.zeros((0, 2), dtype=torch.long) for _ in range(3) ])
def train_val_test_split_edges(edges: torch.Tensor,
ratios: TrainValTest) -> TrainValTest:
@@ -115,12 +121,15 @@ def prep_rel_one_node_type(r: RelationType,
adj_mat = r.adjacency_matrix
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
adj_mat_back_train, edges_back_pos, edges_back_neg = \
None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
print('adj_mat_train:', adj_mat_train)
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
adj_mat_train, None, edges_pos, edges_neg)
adj_mat_train, adj_mat_back_train, edges_pos, edges_neg,
edges_back_pos, edges_back_neg)
def prep_rel_two_node_types_sym(r: RelationType,
@@ -128,12 +137,14 @@ def prep_rel_two_node_types_sym(r: RelationType,
adj_mat = r.adjacency_matrix
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
edges_back_pos, edges_back_neg = \
_empty_edge_list_tvt(), _empty_edge_list_tvt()
return PreparedRelationType(r.name, r.node_type_row,
r.node_type_column,
norm_adj_mat_two_node_types(adj_mat_train),
norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
edges_pos, edges_neg)
edges_pos, edges_neg, edges_back_pos, edges_back_neg)
def prep_rel_two_node_types_asym(r: RelationType,
@@ -144,23 +155,20 @@ def prep_rel_two_node_types_asym(r: RelationType,
prepare_adj_mat(r.adjacency_matrix, ratios)
else:
adj_mat_train, edges_pos, edges_neg = \
None, torch.zeros((0, 2)), torch.zeros((0, 2))
None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
if r.adjacency_matrix_backward is not None:
adj_mat_back_train, edges_back_pos, edges_back_neg = \
prepare_adj_mat(r.adjacency_matrix_backward, ratios)
else:
adj_mat_back_train, edges_back_pos, edges_back_neg = \
None, torch.zeros((0, 2)), torch.zeros((0, 2))
edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0)
edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0)
None, _empty_edge_list_tvt(), _empty_edge_list_tvt()
return PreparedRelationType(r.name, r.node_type_row,
r.node_type_column,
norm_adj_mat_two_node_types(adj_mat_train),
norm_adj_mat_two_node_types(adj_mat_back_train),
edges_pos, edges_neg)
edges_pos, edges_neg, edges_back_pos, edges_back_neg)
def prepare_relation_type(r: RelationType,
@@ -180,7 +188,9 @@ def prepare_relation_type(r: RelationType,
return prep_rel_two_node_types_asym(r, ratios)
def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily:
def prepare_relation_family(fam: RelationFamily,
ratios: TrainValTest) -> PreparedRelationFamily:
relation_types = []
for r in fam.relation_types:
@@ -196,7 +206,7 @@ def prepare_training(data: Data, ratios: TrainValTest) -> PreparedData:
if not isinstance(data, Data):
raise ValueError('data must be of class Data')
relation_families = [ prepare_relation_family(fam) \
relation_families = [ prepare_relation_family(fam, ratios) \
for fam in data.relation_families ]
return PreparedData(data.node_types, relation_families)

+ 25
- 7
tests/icosagon/test_declayer.py 查看文件

@@ -6,7 +6,10 @@
from icosagon.input import OneHotInputLayer
from icosagon.convlayer import DecagonLayer
from icosagon.declayer import DecodeLayer
from icosagon.declayer import DecodeLayer, \
Predictions, \
RelationFamilyPredictions, \
RelationPredictions
from icosagon.decode import DEDICOMDecoder
from icosagon.data import Data
from icosagon.trainprep import prepare_training, \
@@ -17,21 +20,36 @@ import torch
def test_decode_layer_01():
d = Data()
d.add_node_type('Dummy', 100)
fam = d.add_relation_family('Dummy-Dummy', 0, 0, False)
fam.add_relation_type('Dummy Relation 1', 0, 0,
torch.rand((100, 100), dtype=torch.float32).round().to_sparse())
prep_d = prepare_training(d, TrainValTest(.8, .1, .1))
in_layer = OneHotInputLayer(d)
d_layer = DecagonLayer(in_layer.output_dim, 32, d)
seq = torch.nn.Sequential(in_layer, d_layer)
last_layer_repr = seq(None)
dec = DecodeLayer(input_dim=d_layer.output_dim, data=prep_d, keep_prob=1.,
decoder_class=DEDICOMDecoder, activation=lambda x: x)
pred_adj_matrices = dec(last_layer_repr)
assert isinstance(pred_adj_matrices, dict)
assert len(pred_adj_matrices) == 1
assert isinstance(pred_adj_matrices[0, 0], list)
assert len(pred_adj_matrices[0, 0]) == 1
activation=lambda x: x)
pred = dec(last_layer_repr)
assert isinstance(pred, Predictions)
assert isinstance(pred.relation_families, list)
assert len(pred.relation_families) == 1
assert isinstance(pred.relation_families[0], RelationFamilyPredictions)
assert isinstance(pred.relation_families[0].relation_types, list)
assert len(pred.relation_families[0].relation_types) == 1
assert isinstance(pred.relation_families[0].relation_types[0], RelationPredictions)
tmp = pred.relation_families[0].relation_types[0]
assert isinstance(tmp.edges_pos, TrainValTest)
assert isinstance(tmp.edges_neg, TrainValTest)
assert isinstance(tmp.edges_back_pos, TrainValTest)
assert isinstance(tmp.edges_back_neg, TrainValTest)
def test_decode_layer_02():


正在加载...
取消
保存