@@ -50,22 +50,46 @@ class DecagonLayer(torch.nn.Module): | |||||
self.next_layer_repr = None | self.next_layer_repr = None | ||||
self.build() | self.build() | ||||
def build_family(self, fam): | |||||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
if len(rels) == 0: | |||||
continue | |||||
convolutions = [] | |||||
def build_fam_one_node_type(self, fam): | |||||
convolutions = [] | |||||
for r in fam.relation_types: | |||||
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||||
self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||||
self.keep_prob, self.rel_activation) | |||||
convolutions.append(conv) | |||||
self.next_layer_repr[fam.node_type_row].append( | |||||
Convolutions(fam.node_type_column, convolutions)) | |||||
def build_fam_two_node_types(self, fam) -> None: | |||||
convolutions_row = [] | |||||
convolutions_column = [] | |||||
for r in fam.relation_types: | |||||
if r.adjacency_matrix is not None: | |||||
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||||
self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||||
self.keep_prob, self.rel_activation) | |||||
convolutions_row.append(conv) | |||||
for r in rels: | |||||
conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
self.output_dim[node_type_row], r.adjacency_matrix, | |||||
if r.adjacency_matrix_backward is not None: | |||||
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row], | |||||
self.output_dim[fam.node_type_column], r.adjacency_matrix_backward, | |||||
self.keep_prob, self.rel_activation) | self.keep_prob, self.rel_activation) | ||||
convolutions_column.append(conv) | |||||
self.next_layer_repr[fam.node_type_row].append( | |||||
Convolutions(fam.node_type_column, convolutions_row)) | |||||
convolutions.append(conv) | |||||
self.next_layer_repr[fam.node_type_column].append( | |||||
Convolutions(fam.node_type_row, convolutions_column)) | |||||
self.next_layer_repr[node_type_row].append( | |||||
Convolutions(node_type_column, convolutions)) | |||||
def build_family(self, fam) -> None: | |||||
if fam.node_type_row == fam.node_type_column: | |||||
self.build_fam_one_node_type(fam) | |||||
else: | |||||
self.build_fam_two_node_types(fam) | |||||
def build(self): | def build(self): | ||||
self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | ||||
@@ -49,12 +49,12 @@ class RelationTypeBase(object): | |||||
node_type_row: int | node_type_row: int | ||||
node_type_column: int | node_type_column: int | ||||
adjacency_matrix: torch.Tensor | adjacency_matrix: torch.Tensor | ||||
two_way: bool | |||||
adjacency_matrix_backward: torch.Tensor | |||||
@dataclass | @dataclass | ||||
class RelationType(RelationTypeBase): | class RelationType(RelationTypeBase): | ||||
hints: Dict[str, Any] = field(default_factory=dict) | |||||
pass | |||||
@dataclass | @dataclass | ||||
@@ -69,7 +69,7 @@ class RelationFamilyBase(object): | |||||
@dataclass | @dataclass | ||||
class RelationFamily(RelationFamilyBase): | class RelationFamily(RelationFamilyBase): | ||||
relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||||
relation_types: List[RelationType] = None | |||||
def __post_init__(self) -> None: | def __post_init__(self) -> None: | ||||
if not self.is_symmetric and \ | if not self.is_symmetric and \ | ||||
@@ -77,18 +77,18 @@ class RelationFamily(RelationFamilyBase): | |||||
self.decoder_class != BilinearDecoder: | self.decoder_class != BilinearDecoder: | ||||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | ||||
self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||||
(self.node_type_column, self.node_type_row): [] } | |||||
self.relation_types = [] | |||||
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | |||||
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | |||||
two_way: bool = True) -> None: | |||||
def add_relation_type(self, | |||||
name: str, node_type_row: int, node_type_column: int, | |||||
adjacency_matrix: torch.Tensor, | |||||
adjacency_matrix_backward: torch.Tensor = None) -> None: | |||||
name = str(name) | name = str(name) | ||||
node_type_row = int(node_type_row) | node_type_row = int(node_type_row) | ||||
node_type_column = int(node_type_column) | node_type_column = int(node_type_column) | ||||
if (node_type_row, node_type_column) not in self.relation_types: | |||||
if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): | |||||
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | ||||
if node_type_row < 0 or node_type_row >= len(self.data.node_types): | if node_type_row < 0 or node_type_row >= len(self.data.node_types): | ||||
@@ -97,21 +97,33 @@ class RelationFamily(RelationFamilyBase): | |||||
if node_type_column < 0 or node_type_column >= len(self.data.node_types): | if node_type_column < 0 or node_type_column >= len(self.data.node_types): | ||||
raise ValueError('node_type_column outside of the valid range of node types') | raise ValueError('node_type_column outside of the valid range of node types') | ||||
if not isinstance(adjacency_matrix, torch.Tensor): | |||||
if adjacency_matrix is None and adjacency_matrix_backward is None: | |||||
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||||
if adjacency_matrix is not None and \ | |||||
not isinstance(adjacency_matrix, torch.Tensor): | |||||
raise ValueError('adjacency_matrix must be a torch.Tensor') | raise ValueError('adjacency_matrix must be a torch.Tensor') | ||||
# if isinstance(adjacency_matrix_backward, str) and \ | |||||
# adjacency_matrix_backward == 'symmetric': | |||||
# if self.is_symmetric: | |||||
# adjacency_matrix_backward = None | |||||
# else: | |||||
# adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
if adjacency_matrix_backward is not None \ | if adjacency_matrix_backward is not None \ | ||||
and not isinstance(adjacency_matrix_backward, torch.Tensor): | and not isinstance(adjacency_matrix_backward, torch.Tensor): | ||||
raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | ||||
if adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||||
self.data.node_types[node_type_column].count): | |||||
if adjacency_matrix is not None and \ | |||||
adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||||
self.data.node_types[node_type_column].count): | |||||
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | ||||
if adjacency_matrix_backward is not None and \ | if adjacency_matrix_backward is not None and \ | ||||
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | ||||
self.data.node_types[node_type_row].count): | self.data.node_types[node_type_row].count): | ||||
raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') | |||||
raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') | |||||
if node_type_row == node_type_column and \ | if node_type_row == node_type_column and \ | ||||
adjacency_matrix_backward is not None: | adjacency_matrix_backward is not None: | ||||
@@ -125,19 +137,12 @@ class RelationFamily(RelationFamilyBase): | |||||
adjacency_matrix.transpose(0, 1))): | adjacency_matrix.transpose(0, 1))): | ||||
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | ||||
two_way = bool(two_way) | |||||
if node_type_row != node_type_column and two_way: | |||||
print('%d != %d' % (node_type_row, node_type_column)) | |||||
if adjacency_matrix_backward is None: | |||||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
self.relation_types[node_type_column, node_type_row].append( | |||||
RelationType(name, node_type_column, node_type_row, | |||||
adjacency_matrix_backward, two_way, { 'display': False })) | |||||
if self.is_symmetric and node_type_row != node_type_column: | |||||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
self.relation_types[node_type_row, node_type_column].append( | |||||
RelationType(name, node_type_row, node_type_column, | |||||
adjacency_matrix, two_way)) | |||||
self.relation_types.append(RelationType(name, | |||||
node_type_row, node_type_column, | |||||
adjacency_matrix, adjacency_matrix_backward)) | |||||
def node_name(self, index): | def node_name(self, index): | ||||
return self.data.node_types[index].name | return self.data.node_types[index].name | ||||
@@ -145,24 +150,26 @@ class RelationFamily(RelationFamilyBase): | |||||
def __repr__(self): | def __repr__(self): | ||||
s = 'Relation family %s' % self.name | s = 'Relation family %s' % self.name | ||||
for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||||
for r in rels: | |||||
if 'display' in r.hints and not r.hints['display']: | |||||
continue | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||||
(self.node_name(node_type_row), self.node_name(node_type_column))) | |||||
for r in self.relation_types: | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
if (r.adjacency_matrix is not None \ | |||||
and r.adjacency_matrix_backward is not None) \ | |||||
or self.is_symmetric \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
self.node_name(self.node_type_column))) | |||||
return s | return s | ||||
def repr_indented(self): | def repr_indented(self): | ||||
s = ' - %s' % self.name | s = ' - %s' % self.name | ||||
for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||||
for r in rels: | |||||
if 'display' in r.hints and not r.hints['display']: | |||||
continue | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||||
(self.node_name(node_type_row), self.node_name(node_type_column))) | |||||
for r in self.relation_types: | |||||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
if (r.adjacency_matrix is not None \ | |||||
and r.adjacency_matrix_backward is not None) \ | |||||
or self.is_symmetric \ | |||||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
self.node_name(self.node_type_column))) | |||||
return s | return s | ||||
@@ -22,7 +22,6 @@ class DecodeLayer(torch.nn.Module): | |||||
input_dim: List[int], | input_dim: List[int], | ||||
data: PreparedData, | data: PreparedData, | ||||
keep_prob: float = 1., | keep_prob: float = 1., | ||||
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, | |||||
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | ||||
**kwargs) -> None: | **kwargs) -> None: | ||||
@@ -37,26 +36,25 @@ class DecodeLayer(torch.nn.Module): | |||||
if not isinstance(data, PreparedData): | if not isinstance(data, PreparedData): | ||||
raise TypeError('data must be an instance of PreparedData') | raise TypeError('data must be an instance of PreparedData') | ||||
if not isinstance(decoder_class, type) and \ | |||||
not isinstance(decoder_class, dict): | |||||
raise TypeError('decoder_class must be a Type or a Dict') | |||||
if not isinstance(decoder_class, dict): | |||||
decoder_class = { k: decoder_class \ | |||||
for k in data.relation_types.keys() } | |||||
self.input_dim = input_dim | self.input_dim = input_dim | ||||
self.output_dim = 1 | self.output_dim = 1 | ||||
self.data = data | self.data = data | ||||
self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
self.decoder_class = decoder_class | |||||
self.activation = activation | self.activation = activation | ||||
self.decoders = None | self.decoders = None | ||||
self.build() | self.build() | ||||
def build(self) -> None: | def build(self) -> None: | ||||
self.decoders = {} | |||||
self.decoders = [] | |||||
for fam in self.data.relation_families: | |||||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
for r in rels: | |||||
pass | |||||
dec = fam.decoder_class() | |||||
self.decoders.append(dec) | |||||
for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | ||||
if len(rels) == 0: | if len(rels) == 0: | ||||
@@ -39,7 +39,7 @@ class PreparedRelationType(RelationTypeBase): | |||||
@dataclass | @dataclass | ||||
class PreparedRelationFamily(RelationFamilyBase): | class PreparedRelationFamily(RelationFamilyBase): | ||||
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
relation_types: List[PreparedRelationType] | |||||
@dataclass | @dataclass | ||||
@@ -110,36 +110,81 @@ def prepare_adj_mat(adj_mat: torch.Tensor, | |||||
return adj_mat_train, edges_pos, edges_neg | return adj_mat_train, edges_pos, edges_neg | ||||
def prepare_relation_type(r: RelationType, | |||||
def prep_rel_one_node_type(r: RelationType, | |||||
ratios: TrainValTest) -> PreparedRelationType: | ratios: TrainValTest) -> PreparedRelationType: | ||||
adj_mat = r.adjacency_matrix | |||||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
print('adj_mat_train:', adj_mat_train) | |||||
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||||
adj_mat_train, None, edges_pos, edges_neg) | |||||
def prep_rel_two_node_types_sym(r: RelationType, | |||||
ratios: TrainValTest) -> PreparedRelationType: | |||||
adj_mat = r.adjacency_matrix | |||||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
return PreparedRelationType(r.name, r.node_type_row, | |||||
r.node_type_column, | |||||
norm_adj_mat_two_node_types(adj_mat_train), | |||||
norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | |||||
edges_pos, edges_neg) | |||||
def prep_rel_two_node_types_asym(r: RelationType, | |||||
ratios: TrainValTest) -> PreparedRelationType: | |||||
if r.adjacency_matrix is not None: | |||||
adj_mat_train, edges_pos, edges_neg =\ | |||||
prepare_adj_mat(r.adjacency_matrix, ratios) | |||||
else: | |||||
adj_mat_train, edges_pos, edges_neg = \ | |||||
None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||||
if r.adjacency_matrix_backward is not None: | |||||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
prepare_adj_mat(r.adjacency_matrix_backward, ratios) | |||||
else: | |||||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||||
edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) | |||||
edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) | |||||
return PreparedRelationType(r.name, r.node_type_row, | |||||
r.node_type_column, | |||||
norm_adj_mat_two_node_types(adj_mat_train), | |||||
norm_adj_mat_two_node_types(adj_mat_back_train), | |||||
edges_pos, edges_neg) | |||||
def prepare_relation_type(r: RelationType, | |||||
ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: | |||||
if not isinstance(r, RelationType): | if not isinstance(r, RelationType): | ||||
raise ValueError('r must be a RelationType') | raise ValueError('r must be a RelationType') | ||||
if not isinstance(ratios, TrainValTest): | if not isinstance(ratios, TrainValTest): | ||||
raise ValueError('ratios must be a TrainValTest') | raise ValueError('ratios must be a TrainValTest') | ||||
adj_mat = r.adjacency_matrix | |||||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
print('adj_mat_train:', adj_mat_train) | |||||
if r.node_type_row == r.node_type_column: | if r.node_type_row == r.node_type_column: | ||||
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||||
return prep_rel_one_node_type(r, ratios) | |||||
elif is_symmetric: | |||||
return prep_rel_two_node_types_sym(r, ratios) | |||||
else: | else: | ||||
adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||||
adj_mat_train, r.two_way, edges_pos, edges_neg) | |||||
return prep_rel_two_node_types_asym(r, ratios) | |||||
def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | ||||
relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||||
(fam.node_type_column, fam.node_type_row): [] } | |||||
relation_types = [] | |||||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
for r in rels: | |||||
relation_types[node_type_row, node_type_column].append( | |||||
prepare_relation_type(r, ratios)) | |||||
for r in fam.relation_types: | |||||
relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) | |||||
return PreparedRelationFamily(fam.data, fam.name, | return PreparedRelationFamily(fam.data, fam.name, | ||||
fam.node_type_row, fam.node_type_column, | fam.node_type_row, fam.node_type_column, | ||||
@@ -103,7 +103,7 @@ def test_decagon_layer_04(): | |||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
multi_dgca = MultiDGCA([10], 32, | multi_dgca = MultiDGCA([10], 32, | ||||
[r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||||
[r.adjacency_matrix for r in fam.relation_types], | |||||
keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
d_layer = DecagonLayer(in_layer.output_dim, 32, d, | d_layer = DecagonLayer(in_layer.output_dim, 32, d, | ||||
@@ -147,7 +147,7 @@ def test_decagon_layer_05(): | |||||
in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
multi_dgca = MultiDGCA([100, 100], 32, | multi_dgca = MultiDGCA([100, 100], 32, | ||||
[r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||||
[r.adjacency_matrix for r in fam.relation_types], | |||||
keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | ||||
@@ -107,7 +107,7 @@ def test_prepare_relation_type_01(): | |||||
adj_mat = (torch.rand((10, 10)) > .5) | adj_mat = (torch.rand((10, 10)) > .5) | ||||
r = RelationType('Test', 0, 0, adj_mat, True) | r = RelationType('Test', 0, 0, adj_mat, True) | ||||
ratios = TrainValTest(.8, .1, .1) | ratios = TrainValTest(.8, .1, .1) | ||||
_ = prepare_relation_type(r, ratios) | |||||
_ = prepare_relation_type(r, ratios, False) | |||||