| @@ -50,22 +50,46 @@ class DecagonLayer(torch.nn.Module): | |||||
| self.next_layer_repr = None | self.next_layer_repr = None | ||||
| self.build() | self.build() | ||||
| def build_family(self, fam): | |||||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
| if len(rels) == 0: | |||||
| continue | |||||
| convolutions = [] | |||||
| def build_fam_one_node_type(self, fam): | |||||
| convolutions = [] | |||||
| for r in fam.relation_types: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||||
| self.keep_prob, self.rel_activation) | |||||
| convolutions.append(conv) | |||||
| self.next_layer_repr[fam.node_type_row].append( | |||||
| Convolutions(fam.node_type_column, convolutions)) | |||||
| def build_fam_two_node_types(self, fam) -> None: | |||||
| convolutions_row = [] | |||||
| convolutions_column = [] | |||||
| for r in fam.relation_types: | |||||
| if r.adjacency_matrix is not None: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||||
| self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||||
| self.keep_prob, self.rel_activation) | |||||
| convolutions_row.append(conv) | |||||
| for r in rels: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||||
| if r.adjacency_matrix_backward is not None: | |||||
| conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row], | |||||
| self.output_dim[fam.node_type_column], r.adjacency_matrix_backward, | |||||
| self.keep_prob, self.rel_activation) | self.keep_prob, self.rel_activation) | ||||
| convolutions_column.append(conv) | |||||
| self.next_layer_repr[fam.node_type_row].append( | |||||
| Convolutions(fam.node_type_column, convolutions_row)) | |||||
| convolutions.append(conv) | |||||
| self.next_layer_repr[fam.node_type_column].append( | |||||
| Convolutions(fam.node_type_row, convolutions_column)) | |||||
| self.next_layer_repr[node_type_row].append( | |||||
| Convolutions(node_type_column, convolutions)) | |||||
| def build_family(self, fam) -> None: | |||||
| if fam.node_type_row == fam.node_type_column: | |||||
| self.build_fam_one_node_type(fam) | |||||
| else: | |||||
| self.build_fam_two_node_types(fam) | |||||
| def build(self): | def build(self): | ||||
| self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | ||||
| @@ -49,12 +49,12 @@ class RelationTypeBase(object): | |||||
| node_type_row: int | node_type_row: int | ||||
| node_type_column: int | node_type_column: int | ||||
| adjacency_matrix: torch.Tensor | adjacency_matrix: torch.Tensor | ||||
| two_way: bool | |||||
| adjacency_matrix_backward: torch.Tensor | |||||
| @dataclass | @dataclass | ||||
| class RelationType(RelationTypeBase): | class RelationType(RelationTypeBase): | ||||
| hints: Dict[str, Any] = field(default_factory=dict) | |||||
| pass | |||||
| @dataclass | @dataclass | ||||
| @@ -69,7 +69,7 @@ class RelationFamilyBase(object): | |||||
| @dataclass | @dataclass | ||||
| class RelationFamily(RelationFamilyBase): | class RelationFamily(RelationFamilyBase): | ||||
| relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||||
| relation_types: List[RelationType] = None | |||||
| def __post_init__(self) -> None: | def __post_init__(self) -> None: | ||||
| if not self.is_symmetric and \ | if not self.is_symmetric and \ | ||||
| @@ -77,18 +77,18 @@ class RelationFamily(RelationFamilyBase): | |||||
| self.decoder_class != BilinearDecoder: | self.decoder_class != BilinearDecoder: | ||||
| raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | ||||
| self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||||
| (self.node_type_column, self.node_type_row): [] } | |||||
| self.relation_types = [] | |||||
| def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | |||||
| adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | |||||
| two_way: bool = True) -> None: | |||||
| def add_relation_type(self, | |||||
| name: str, node_type_row: int, node_type_column: int, | |||||
| adjacency_matrix: torch.Tensor, | |||||
| adjacency_matrix_backward: torch.Tensor = None) -> None: | |||||
| name = str(name) | name = str(name) | ||||
| node_type_row = int(node_type_row) | node_type_row = int(node_type_row) | ||||
| node_type_column = int(node_type_column) | node_type_column = int(node_type_column) | ||||
| if (node_type_row, node_type_column) not in self.relation_types: | |||||
| if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): | |||||
| raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | ||||
| if node_type_row < 0 or node_type_row >= len(self.data.node_types): | if node_type_row < 0 or node_type_row >= len(self.data.node_types): | ||||
| @@ -97,21 +97,33 @@ class RelationFamily(RelationFamilyBase): | |||||
| if node_type_column < 0 or node_type_column >= len(self.data.node_types): | if node_type_column < 0 or node_type_column >= len(self.data.node_types): | ||||
| raise ValueError('node_type_column outside of the valid range of node types') | raise ValueError('node_type_column outside of the valid range of node types') | ||||
| if not isinstance(adjacency_matrix, torch.Tensor): | |||||
| if adjacency_matrix is None and adjacency_matrix_backward is None: | |||||
| raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||||
| if adjacency_matrix is not None and \ | |||||
| not isinstance(adjacency_matrix, torch.Tensor): | |||||
| raise ValueError('adjacency_matrix must be a torch.Tensor') | raise ValueError('adjacency_matrix must be a torch.Tensor') | ||||
| # if isinstance(adjacency_matrix_backward, str) and \ | |||||
| # adjacency_matrix_backward == 'symmetric': | |||||
| # if self.is_symmetric: | |||||
| # adjacency_matrix_backward = None | |||||
| # else: | |||||
| # adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
| if adjacency_matrix_backward is not None \ | if adjacency_matrix_backward is not None \ | ||||
| and not isinstance(adjacency_matrix_backward, torch.Tensor): | and not isinstance(adjacency_matrix_backward, torch.Tensor): | ||||
| raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | ||||
| if adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||||
| self.data.node_types[node_type_column].count): | |||||
| if adjacency_matrix is not None and \ | |||||
| adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||||
| self.data.node_types[node_type_column].count): | |||||
| raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | ||||
| if adjacency_matrix_backward is not None and \ | if adjacency_matrix_backward is not None and \ | ||||
| adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | ||||
| self.data.node_types[node_type_row].count): | self.data.node_types[node_type_row].count): | ||||
| raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') | |||||
| raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') | |||||
| if node_type_row == node_type_column and \ | if node_type_row == node_type_column and \ | ||||
| adjacency_matrix_backward is not None: | adjacency_matrix_backward is not None: | ||||
| @@ -125,19 +137,12 @@ class RelationFamily(RelationFamilyBase): | |||||
| adjacency_matrix.transpose(0, 1))): | adjacency_matrix.transpose(0, 1))): | ||||
| raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | ||||
| two_way = bool(two_way) | |||||
| if node_type_row != node_type_column and two_way: | |||||
| print('%d != %d' % (node_type_row, node_type_column)) | |||||
| if adjacency_matrix_backward is None: | |||||
| adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
| self.relation_types[node_type_column, node_type_row].append( | |||||
| RelationType(name, node_type_column, node_type_row, | |||||
| adjacency_matrix_backward, two_way, { 'display': False })) | |||||
| if self.is_symmetric and node_type_row != node_type_column: | |||||
| adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||||
| self.relation_types[node_type_row, node_type_column].append( | |||||
| RelationType(name, node_type_row, node_type_column, | |||||
| adjacency_matrix, two_way)) | |||||
| self.relation_types.append(RelationType(name, | |||||
| node_type_row, node_type_column, | |||||
| adjacency_matrix, adjacency_matrix_backward)) | |||||
| def node_name(self, index): | def node_name(self, index): | ||||
| return self.data.node_types[index].name | return self.data.node_types[index].name | ||||
| @@ -145,24 +150,26 @@ class RelationFamily(RelationFamilyBase): | |||||
| def __repr__(self): | def __repr__(self): | ||||
| s = 'Relation family %s' % self.name | s = 'Relation family %s' % self.name | ||||
| for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||||
| for r in rels: | |||||
| if 'display' in r.hints and not r.hints['display']: | |||||
| continue | |||||
| s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||||
| (self.node_name(node_type_row), self.node_name(node_type_column))) | |||||
| for r in self.relation_types: | |||||
| s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
| if (r.adjacency_matrix is not None \ | |||||
| and r.adjacency_matrix_backward is not None) \ | |||||
| or self.is_symmetric \ | |||||
| else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
| self.node_name(self.node_type_column))) | |||||
| return s | return s | ||||
| def repr_indented(self): | def repr_indented(self): | ||||
| s = ' - %s' % self.name | s = ' - %s' % self.name | ||||
| for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||||
| for r in rels: | |||||
| if 'display' in r.hints and not r.hints['display']: | |||||
| continue | |||||
| s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||||
| (self.node_name(node_type_row), self.node_name(node_type_column))) | |||||
| for r in self.relation_types: | |||||
| s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||||
| if (r.adjacency_matrix is not None \ | |||||
| and r.adjacency_matrix_backward is not None) \ | |||||
| or self.is_symmetric \ | |||||
| else '%s <- %s' % (self.node_name(self.node_type_row), | |||||
| self.node_name(self.node_type_column))) | |||||
| return s | return s | ||||
| @@ -22,7 +22,6 @@ class DecodeLayer(torch.nn.Module): | |||||
| input_dim: List[int], | input_dim: List[int], | ||||
| data: PreparedData, | data: PreparedData, | ||||
| keep_prob: float = 1., | keep_prob: float = 1., | ||||
| decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, | |||||
| activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | ||||
| **kwargs) -> None: | **kwargs) -> None: | ||||
| @@ -37,26 +36,25 @@ class DecodeLayer(torch.nn.Module): | |||||
| if not isinstance(data, PreparedData): | if not isinstance(data, PreparedData): | ||||
| raise TypeError('data must be an instance of PreparedData') | raise TypeError('data must be an instance of PreparedData') | ||||
| if not isinstance(decoder_class, type) and \ | |||||
| not isinstance(decoder_class, dict): | |||||
| raise TypeError('decoder_class must be a Type or a Dict') | |||||
| if not isinstance(decoder_class, dict): | |||||
| decoder_class = { k: decoder_class \ | |||||
| for k in data.relation_types.keys() } | |||||
| self.input_dim = input_dim | self.input_dim = input_dim | ||||
| self.output_dim = 1 | self.output_dim = 1 | ||||
| self.data = data | self.data = data | ||||
| self.keep_prob = keep_prob | self.keep_prob = keep_prob | ||||
| self.decoder_class = decoder_class | |||||
| self.activation = activation | self.activation = activation | ||||
| self.decoders = None | self.decoders = None | ||||
| self.build() | self.build() | ||||
| def build(self) -> None: | def build(self) -> None: | ||||
| self.decoders = {} | |||||
| self.decoders = [] | |||||
| for fam in self.data.relation_families: | |||||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
| for r in rels: | |||||
| pass | |||||
| dec = fam.decoder_class() | |||||
| self.decoders.append(dec) | |||||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | ||||
| if len(rels) == 0: | if len(rels) == 0: | ||||
| @@ -39,7 +39,7 @@ class PreparedRelationType(RelationTypeBase): | |||||
| @dataclass | @dataclass | ||||
| class PreparedRelationFamily(RelationFamilyBase): | class PreparedRelationFamily(RelationFamilyBase): | ||||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||||
| relation_types: List[PreparedRelationType] | |||||
| @dataclass | @dataclass | ||||
| @@ -110,36 +110,81 @@ def prepare_adj_mat(adj_mat: torch.Tensor, | |||||
| return adj_mat_train, edges_pos, edges_neg | return adj_mat_train, edges_pos, edges_neg | ||||
| def prepare_relation_type(r: RelationType, | |||||
| def prep_rel_one_node_type(r: RelationType, | |||||
| ratios: TrainValTest) -> PreparedRelationType: | ratios: TrainValTest) -> PreparedRelationType: | ||||
| adj_mat = r.adjacency_matrix | |||||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
| print('adj_mat_train:', adj_mat_train) | |||||
| adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||||
| adj_mat_train, None, edges_pos, edges_neg) | |||||
| def prep_rel_two_node_types_sym(r: RelationType, | |||||
| ratios: TrainValTest) -> PreparedRelationType: | |||||
| adj_mat = r.adjacency_matrix | |||||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
| return PreparedRelationType(r.name, r.node_type_row, | |||||
| r.node_type_column, | |||||
| norm_adj_mat_two_node_types(adj_mat_train), | |||||
| norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | |||||
| edges_pos, edges_neg) | |||||
| def prep_rel_two_node_types_asym(r: RelationType, | |||||
| ratios: TrainValTest) -> PreparedRelationType: | |||||
| if r.adjacency_matrix is not None: | |||||
| adj_mat_train, edges_pos, edges_neg =\ | |||||
| prepare_adj_mat(r.adjacency_matrix, ratios) | |||||
| else: | |||||
| adj_mat_train, edges_pos, edges_neg = \ | |||||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||||
| if r.adjacency_matrix_backward is not None: | |||||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
| prepare_adj_mat(r.adjacency_matrix_backward, ratios) | |||||
| else: | |||||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||||
| edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) | |||||
| edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) | |||||
| return PreparedRelationType(r.name, r.node_type_row, | |||||
| r.node_type_column, | |||||
| norm_adj_mat_two_node_types(adj_mat_train), | |||||
| norm_adj_mat_two_node_types(adj_mat_back_train), | |||||
| edges_pos, edges_neg) | |||||
| def prepare_relation_type(r: RelationType, | |||||
| ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: | |||||
| if not isinstance(r, RelationType): | if not isinstance(r, RelationType): | ||||
| raise ValueError('r must be a RelationType') | raise ValueError('r must be a RelationType') | ||||
| if not isinstance(ratios, TrainValTest): | if not isinstance(ratios, TrainValTest): | ||||
| raise ValueError('ratios must be a TrainValTest') | raise ValueError('ratios must be a TrainValTest') | ||||
| adj_mat = r.adjacency_matrix | |||||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||||
| print('adj_mat_train:', adj_mat_train) | |||||
| if r.node_type_row == r.node_type_column: | if r.node_type_row == r.node_type_column: | ||||
| adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||||
| return prep_rel_one_node_type(r, ratios) | |||||
| elif is_symmetric: | |||||
| return prep_rel_two_node_types_sym(r, ratios) | |||||
| else: | else: | ||||
| adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||||
| adj_mat_train, r.two_way, edges_pos, edges_neg) | |||||
| return prep_rel_two_node_types_asym(r, ratios) | |||||
| def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | ||||
| relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||||
| (fam.node_type_column, fam.node_type_row): [] } | |||||
| relation_types = [] | |||||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||||
| for r in rels: | |||||
| relation_types[node_type_row, node_type_column].append( | |||||
| prepare_relation_type(r, ratios)) | |||||
| for r in fam.relation_types: | |||||
| relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) | |||||
| return PreparedRelationFamily(fam.data, fam.name, | return PreparedRelationFamily(fam.data, fam.name, | ||||
| fam.node_type_row, fam.node_type_column, | fam.node_type_row, fam.node_type_column, | ||||
| @@ -103,7 +103,7 @@ def test_decagon_layer_04(): | |||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| multi_dgca = MultiDGCA([10], 32, | multi_dgca = MultiDGCA([10], 32, | ||||
| [r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||||
| [r.adjacency_matrix for r in fam.relation_types], | |||||
| keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
| d_layer = DecagonLayer(in_layer.output_dim, 32, d, | d_layer = DecagonLayer(in_layer.output_dim, 32, d, | ||||
| @@ -147,7 +147,7 @@ def test_decagon_layer_05(): | |||||
| in_layer = OneHotInputLayer(d) | in_layer = OneHotInputLayer(d) | ||||
| multi_dgca = MultiDGCA([100, 100], 32, | multi_dgca = MultiDGCA([100, 100], 32, | ||||
| [r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||||
| [r.adjacency_matrix for r in fam.relation_types], | |||||
| keep_prob=1., activation=lambda x: x) | keep_prob=1., activation=lambda x: x) | ||||
| d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | ||||
| @@ -107,7 +107,7 @@ def test_prepare_relation_type_01(): | |||||
| adj_mat = (torch.rand((10, 10)) > .5) | adj_mat = (torch.rand((10, 10)) > .5) | ||||
| r = RelationType('Test', 0, 0, adj_mat, True) | r = RelationType('Test', 0, 0, adj_mat, True) | ||||
| ratios = TrainValTest(.8, .1, .1) | ratios = TrainValTest(.8, .1, .1) | ||||
| _ = prepare_relation_type(r, ratios) | |||||
| _ = prepare_relation_type(r, ratios, False) | |||||