diff --git a/src/icosagon/convlayer.py b/src/icosagon/convlayer.py index 465f2a5..cdaced6 100644 --- a/src/icosagon/convlayer.py +++ b/src/icosagon/convlayer.py @@ -50,22 +50,46 @@ class DecagonLayer(torch.nn.Module): self.next_layer_repr = None self.build() - def build_family(self, fam): - for (node_type_row, node_type_column), rels in fam.relation_types.items(): - if len(rels) == 0: - continue - - convolutions = [] + def build_fam_one_node_type(self, fam): + convolutions = [] + + for r in fam.relation_types: + conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], + self.output_dim[fam.node_type_row], r.adjacency_matrix, + self.keep_prob, self.rel_activation) + convolutions.append(conv) + + self.next_layer_repr[fam.node_type_row].append( + Convolutions(fam.node_type_column, convolutions)) + + def build_fam_two_node_types(self, fam) -> None: + convolutions_row = [] + convolutions_column = [] + + for r in fam.relation_types: + if r.adjacency_matrix is not None: + conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], + self.output_dim[fam.node_type_row], r.adjacency_matrix, + self.keep_prob, self.rel_activation) + convolutions_row.append(conv) - for r in rels: - conv = DropoutGraphConvActivation(self.input_dim[node_type_column], - self.output_dim[node_type_row], r.adjacency_matrix, + if r.adjacency_matrix_backward is not None: + conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row], + self.output_dim[fam.node_type_column], r.adjacency_matrix_backward, self.keep_prob, self.rel_activation) + convolutions_column.append(conv) + + self.next_layer_repr[fam.node_type_row].append( + Convolutions(fam.node_type_column, convolutions_row)) - convolutions.append(conv) + self.next_layer_repr[fam.node_type_column].append( + Convolutions(fam.node_type_row, convolutions_column)) - self.next_layer_repr[node_type_row].append( - Convolutions(node_type_column, convolutions)) + def build_family(self, fam) -> None: + if fam.node_type_row == fam.node_type_column: + self.build_fam_one_node_type(fam) + else: + self.build_fam_two_node_types(fam) def build(self): self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] diff --git a/src/icosagon/data.py b/src/icosagon/data.py index 6fc91cc..53bd769 100644 --- a/src/icosagon/data.py +++ b/src/icosagon/data.py @@ -49,12 +49,12 @@ class RelationTypeBase(object): node_type_row: int node_type_column: int adjacency_matrix: torch.Tensor - two_way: bool + adjacency_matrix_backward: torch.Tensor @dataclass class RelationType(RelationTypeBase): - hints: Dict[str, Any] = field(default_factory=dict) + pass @dataclass @@ -69,7 +69,7 @@ class RelationFamilyBase(object): @dataclass class RelationFamily(RelationFamilyBase): - relation_types: Dict[Tuple[int, int], List[RelationType]] = None + relation_types: List[RelationType] = None def __post_init__(self) -> None: if not self.is_symmetric and \ @@ -77,18 +77,18 @@ class RelationFamily(RelationFamilyBase): self.decoder_class != BilinearDecoder: raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') - self.relation_types = { (self.node_type_row, self.node_type_column): [], - (self.node_type_column, self.node_type_row): [] } + self.relation_types = [] - def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, - adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, - two_way: bool = True) -> None: + def add_relation_type(self, + name: str, node_type_row: int, node_type_column: int, + adjacency_matrix: torch.Tensor, + adjacency_matrix_backward: torch.Tensor = None) -> None: name = str(name) node_type_row = int(node_type_row) node_type_column = int(node_type_column) - if (node_type_row, node_type_column) not in self.relation_types: + if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') if node_type_row < 0 or node_type_row >= len(self.data.node_types): @@ -97,21 +97,33 @@ class RelationFamily(RelationFamilyBase): if node_type_column < 0 or node_type_column >= len(self.data.node_types): raise ValueError('node_type_column outside of the valid range of node types') - if not isinstance(adjacency_matrix, torch.Tensor): + if adjacency_matrix is None and adjacency_matrix_backward is None: + raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') + + if adjacency_matrix is not None and \ + not isinstance(adjacency_matrix, torch.Tensor): raise ValueError('adjacency_matrix must be a torch.Tensor') + # if isinstance(adjacency_matrix_backward, str) and \ + # adjacency_matrix_backward == 'symmetric': + # if self.is_symmetric: + # adjacency_matrix_backward = None + # else: + # adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) + if adjacency_matrix_backward is not None \ and not isinstance(adjacency_matrix_backward, torch.Tensor): raise ValueError('adjacency_matrix_backward must be a torch.Tensor') - if adjacency_matrix.shape != (self.data.node_types[node_type_row].count, - self.data.node_types[node_type_column].count): + if adjacency_matrix is not None and \ + adjacency_matrix.shape != (self.data.node_types[node_type_row].count, + self.data.node_types[node_type_column].count): raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') if adjacency_matrix_backward is not None and \ adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, self.data.node_types[node_type_row].count): - raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') + raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') if node_type_row == node_type_column and \ adjacency_matrix_backward is not None: @@ -125,19 +137,12 @@ class RelationFamily(RelationFamilyBase): adjacency_matrix.transpose(0, 1))): raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') - two_way = bool(two_way) - - if node_type_row != node_type_column and two_way: - print('%d != %d' % (node_type_row, node_type_column)) - if adjacency_matrix_backward is None: - adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) - self.relation_types[node_type_column, node_type_row].append( - RelationType(name, node_type_column, node_type_row, - adjacency_matrix_backward, two_way, { 'display': False })) + if self.is_symmetric and node_type_row != node_type_column: + adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) - self.relation_types[node_type_row, node_type_column].append( - RelationType(name, node_type_row, node_type_column, - adjacency_matrix, two_way)) + self.relation_types.append(RelationType(name, + node_type_row, node_type_column, + adjacency_matrix, adjacency_matrix_backward)) def node_name(self, index): return self.data.node_types[index].name @@ -145,24 +150,26 @@ class RelationFamily(RelationFamilyBase): def __repr__(self): s = 'Relation family %s' % self.name - for (node_type_row, node_type_column), rels in self.relation_types.items(): - for r in rels: - if 'display' in r.hints and not r.hints['display']: - continue - s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ - (self.node_name(node_type_row), self.node_name(node_type_column))) + for r in self.relation_types: + s += '\n - %s%s' % (r.name, ' (two-way)' \ + if (r.adjacency_matrix is not None \ + and r.adjacency_matrix_backward is not None) \ + or self.is_symmetric \ + else '%s <- %s' % (self.node_name(self.node_type_row), + self.node_name(self.node_type_column))) return s def repr_indented(self): s = ' - %s' % self.name - for (node_type_row, node_type_column), rels in self.relation_types.items(): - for r in rels: - if 'display' in r.hints and not r.hints['display']: - continue - s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ - (self.node_name(node_type_row), self.node_name(node_type_column))) + for r in self.relation_types: + s += '\n - %s%s' % (r.name, ' (two-way)' \ + if (r.adjacency_matrix is not None \ + and r.adjacency_matrix_backward is not None) \ + or self.is_symmetric \ + else '%s <- %s' % (self.node_name(self.node_type_row), + self.node_name(self.node_type_column))) return s diff --git a/src/icosagon/declayer.py b/src/icosagon/declayer.py index 9047025..7840ab5 100644 --- a/src/icosagon/declayer.py +++ b/src/icosagon/declayer.py @@ -22,7 +22,6 @@ class DecodeLayer(torch.nn.Module): input_dim: List[int], data: PreparedData, keep_prob: float = 1., - decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, **kwargs) -> None: @@ -37,26 +36,25 @@ class DecodeLayer(torch.nn.Module): if not isinstance(data, PreparedData): raise TypeError('data must be an instance of PreparedData') - if not isinstance(decoder_class, type) and \ - not isinstance(decoder_class, dict): - raise TypeError('decoder_class must be a Type or a Dict') - - if not isinstance(decoder_class, dict): - decoder_class = { k: decoder_class \ - for k in data.relation_types.keys() } - self.input_dim = input_dim self.output_dim = 1 self.data = data self.keep_prob = keep_prob - self.decoder_class = decoder_class self.activation = activation self.decoders = None self.build() def build(self) -> None: - self.decoders = {} + self.decoders = [] + + for fam in self.data.relation_families: + for (node_type_row, node_type_column), rels in fam.relation_types.items(): + for r in rels: + pass + + dec = fam.decoder_class() + self.decoders.append(dec) for (node_type_row, node_type_column), rels in self.data.relation_types.items(): if len(rels) == 0: diff --git a/src/icosagon/trainprep.py b/src/icosagon/trainprep.py index d1dcdb3..5c7843d 100644 --- a/src/icosagon/trainprep.py +++ b/src/icosagon/trainprep.py @@ -39,7 +39,7 @@ class PreparedRelationType(RelationTypeBase): @dataclass class PreparedRelationFamily(RelationFamilyBase): - relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] + relation_types: List[PreparedRelationType] @dataclass @@ -110,36 +110,81 @@ def prepare_adj_mat(adj_mat: torch.Tensor, return adj_mat_train, edges_pos, edges_neg -def prepare_relation_type(r: RelationType, +def prep_rel_one_node_type(r: RelationType, ratios: TrainValTest) -> PreparedRelationType: + adj_mat = r.adjacency_matrix + adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) + + print('adj_mat_train:', adj_mat_train) + adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) + + return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, + adj_mat_train, None, edges_pos, edges_neg) + + +def prep_rel_two_node_types_sym(r: RelationType, + ratios: TrainValTest) -> PreparedRelationType: + + adj_mat = r.adjacency_matrix + adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) + + return PreparedRelationType(r.name, r.node_type_row, + r.node_type_column, + norm_adj_mat_two_node_types(adj_mat_train), + norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), + edges_pos, edges_neg) + + +def prep_rel_two_node_types_asym(r: RelationType, + ratios: TrainValTest) -> PreparedRelationType: + + if r.adjacency_matrix is not None: + adj_mat_train, edges_pos, edges_neg =\ + prepare_adj_mat(r.adjacency_matrix, ratios) + else: + adj_mat_train, edges_pos, edges_neg = \ + None, torch.zeros((0, 2)), torch.zeros((0, 2)) + + if r.adjacency_matrix_backward is not None: + adj_mat_back_train, edges_back_pos, edges_back_neg = \ + prepare_adj_mat(r.adjacency_matrix_backward, ratios) + else: + adj_mat_back_train, edges_back_pos, edges_back_neg = \ + None, torch.zeros((0, 2)), torch.zeros((0, 2)) + + edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) + edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) + + return PreparedRelationType(r.name, r.node_type_row, + r.node_type_column, + norm_adj_mat_two_node_types(adj_mat_train), + norm_adj_mat_two_node_types(adj_mat_back_train), + edges_pos, edges_neg) + + +def prepare_relation_type(r: RelationType, + ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: + if not isinstance(r, RelationType): raise ValueError('r must be a RelationType') if not isinstance(ratios, TrainValTest): raise ValueError('ratios must be a TrainValTest') - adj_mat = r.adjacency_matrix - adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) - - print('adj_mat_train:', adj_mat_train) if r.node_type_row == r.node_type_column: - adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) + return prep_rel_one_node_type(r, ratios) + elif is_symmetric: + return prep_rel_two_node_types_sym(r, ratios) else: - adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) - - return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, - adj_mat_train, r.two_way, edges_pos, edges_neg) + return prep_rel_two_node_types_asym(r, ratios) def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: - relation_types = { (fam.node_type_row, fam.node_type_column): [], - (fam.node_type_column, fam.node_type_row): [] } + relation_types = [] - for (node_type_row, node_type_column), rels in fam.relation_types.items(): - for r in rels: - relation_types[node_type_row, node_type_column].append( - prepare_relation_type(r, ratios)) + for r in fam.relation_types: + relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) return PreparedRelationFamily(fam.data, fam.name, fam.node_type_row, fam.node_type_column, diff --git a/tests/icosagon/test_convlayer.py b/tests/icosagon/test_convlayer.py index cc0458f..11e8ffc 100644 --- a/tests/icosagon/test_convlayer.py +++ b/tests/icosagon/test_convlayer.py @@ -103,7 +103,7 @@ def test_decagon_layer_04(): in_layer = OneHotInputLayer(d) multi_dgca = MultiDGCA([10], 32, - [r.adjacency_matrix for r in fam.relation_types[0, 0]], + [r.adjacency_matrix for r in fam.relation_types], keep_prob=1., activation=lambda x: x) d_layer = DecagonLayer(in_layer.output_dim, 32, d, @@ -147,7 +147,7 @@ def test_decagon_layer_05(): in_layer = OneHotInputLayer(d) multi_dgca = MultiDGCA([100, 100], 32, - [r.adjacency_matrix for r in fam.relation_types[0, 0]], + [r.adjacency_matrix for r in fam.relation_types], keep_prob=1., activation=lambda x: x) d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, diff --git a/tests/icosagon/test_trainprep.py b/tests/icosagon/test_trainprep.py index 712d8c5..b3a2d1f 100644 --- a/tests/icosagon/test_trainprep.py +++ b/tests/icosagon/test_trainprep.py @@ -107,7 +107,7 @@ def test_prepare_relation_type_01(): adj_mat = (torch.rand((10, 10)) > .5) r = RelationType('Test', 0, 0, adj_mat, True) ratios = TrainValTest(.8, .1, .1) - _ = prepare_relation_type(r, ratios) + _ = prepare_relation_type(r, ratios, False)