| @@ -50,22 +50,46 @@ class DecagonLayer(torch.nn.Module): | |||
| self.next_layer_repr = None | |||
| self.build() | |||
| def build_family(self, fam): | |||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
| if len(rels) == 0: | |||
| continue | |||
| convolutions = [] | |||
| def build_fam_one_node_type(self, fam): | |||
| convolutions = [] | |||
| for r in fam.relation_types: | |||
| conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||
| self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||
| self.keep_prob, self.rel_activation) | |||
| convolutions.append(conv) | |||
| self.next_layer_repr[fam.node_type_row].append( | |||
| Convolutions(fam.node_type_column, convolutions)) | |||
| def build_fam_two_node_types(self, fam) -> None: | |||
| convolutions_row = [] | |||
| convolutions_column = [] | |||
| for r in fam.relation_types: | |||
| if r.adjacency_matrix is not None: | |||
| conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||
| self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||
| self.keep_prob, self.rel_activation) | |||
| convolutions_row.append(conv) | |||
| for r in rels: | |||
| conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||
| self.output_dim[node_type_row], r.adjacency_matrix, | |||
| if r.adjacency_matrix_backward is not None: | |||
| conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row], | |||
| self.output_dim[fam.node_type_column], r.adjacency_matrix_backward, | |||
| self.keep_prob, self.rel_activation) | |||
| convolutions_column.append(conv) | |||
| self.next_layer_repr[fam.node_type_row].append( | |||
| Convolutions(fam.node_type_column, convolutions_row)) | |||
| convolutions.append(conv) | |||
| self.next_layer_repr[fam.node_type_column].append( | |||
| Convolutions(fam.node_type_row, convolutions_column)) | |||
| self.next_layer_repr[node_type_row].append( | |||
| Convolutions(node_type_column, convolutions)) | |||
| def build_family(self, fam) -> None: | |||
| if fam.node_type_row == fam.node_type_column: | |||
| self.build_fam_one_node_type(fam) | |||
| else: | |||
| self.build_fam_two_node_types(fam) | |||
| def build(self): | |||
| self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||
| @@ -49,12 +49,12 @@ class RelationTypeBase(object): | |||
| node_type_row: int | |||
| node_type_column: int | |||
| adjacency_matrix: torch.Tensor | |||
| two_way: bool | |||
| adjacency_matrix_backward: torch.Tensor | |||
| @dataclass | |||
| class RelationType(RelationTypeBase): | |||
| hints: Dict[str, Any] = field(default_factory=dict) | |||
| pass | |||
| @dataclass | |||
| @@ -69,7 +69,7 @@ class RelationFamilyBase(object): | |||
| @dataclass | |||
| class RelationFamily(RelationFamilyBase): | |||
| relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||
| relation_types: List[RelationType] = None | |||
| def __post_init__(self) -> None: | |||
| if not self.is_symmetric and \ | |||
| @@ -77,18 +77,18 @@ class RelationFamily(RelationFamilyBase): | |||
| self.decoder_class != BilinearDecoder: | |||
| raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||
| self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||
| (self.node_type_column, self.node_type_row): [] } | |||
| self.relation_types = [] | |||
| def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | |||
| adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | |||
| two_way: bool = True) -> None: | |||
| def add_relation_type(self, | |||
| name: str, node_type_row: int, node_type_column: int, | |||
| adjacency_matrix: torch.Tensor, | |||
| adjacency_matrix_backward: torch.Tensor = None) -> None: | |||
| name = str(name) | |||
| node_type_row = int(node_type_row) | |||
| node_type_column = int(node_type_column) | |||
| if (node_type_row, node_type_column) not in self.relation_types: | |||
| if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): | |||
| raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | |||
| if node_type_row < 0 or node_type_row >= len(self.data.node_types): | |||
| @@ -97,21 +97,33 @@ class RelationFamily(RelationFamilyBase): | |||
| if node_type_column < 0 or node_type_column >= len(self.data.node_types): | |||
| raise ValueError('node_type_column outside of the valid range of node types') | |||
| if not isinstance(adjacency_matrix, torch.Tensor): | |||
| if adjacency_matrix is None and adjacency_matrix_backward is None: | |||
| raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||
| if adjacency_matrix is not None and \ | |||
| not isinstance(adjacency_matrix, torch.Tensor): | |||
| raise ValueError('adjacency_matrix must be a torch.Tensor') | |||
| # if isinstance(adjacency_matrix_backward, str) and \ | |||
| # adjacency_matrix_backward == 'symmetric': | |||
| # if self.is_symmetric: | |||
| # adjacency_matrix_backward = None | |||
| # else: | |||
| # adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
| if adjacency_matrix_backward is not None \ | |||
| and not isinstance(adjacency_matrix_backward, torch.Tensor): | |||
| raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | |||
| if adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||
| self.data.node_types[node_type_column].count): | |||
| if adjacency_matrix is not None and \ | |||
| adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||
| self.data.node_types[node_type_column].count): | |||
| raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | |||
| if adjacency_matrix_backward is not None and \ | |||
| adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | |||
| self.data.node_types[node_type_row].count): | |||
| raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') | |||
| raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') | |||
| if node_type_row == node_type_column and \ | |||
| adjacency_matrix_backward is not None: | |||
| @@ -125,19 +137,12 @@ class RelationFamily(RelationFamilyBase): | |||
| adjacency_matrix.transpose(0, 1))): | |||
| raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | |||
| two_way = bool(two_way) | |||
| if node_type_row != node_type_column and two_way: | |||
| print('%d != %d' % (node_type_row, node_type_column)) | |||
| if adjacency_matrix_backward is None: | |||
| adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
| self.relation_types[node_type_column, node_type_row].append( | |||
| RelationType(name, node_type_column, node_type_row, | |||
| adjacency_matrix_backward, two_way, { 'display': False })) | |||
| if self.is_symmetric and node_type_row != node_type_column: | |||
| adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
| self.relation_types[node_type_row, node_type_column].append( | |||
| RelationType(name, node_type_row, node_type_column, | |||
| adjacency_matrix, two_way)) | |||
| self.relation_types.append(RelationType(name, | |||
| node_type_row, node_type_column, | |||
| adjacency_matrix, adjacency_matrix_backward)) | |||
| def node_name(self, index): | |||
| return self.data.node_types[index].name | |||
| @@ -145,24 +150,26 @@ class RelationFamily(RelationFamilyBase): | |||
| def __repr__(self): | |||
| s = 'Relation family %s' % self.name | |||
| for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||
| for r in rels: | |||
| if 'display' in r.hints and not r.hints['display']: | |||
| continue | |||
| s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||
| (self.node_name(node_type_row), self.node_name(node_type_column))) | |||
| for r in self.relation_types: | |||
| s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||
| if (r.adjacency_matrix is not None \ | |||
| and r.adjacency_matrix_backward is not None) \ | |||
| or self.is_symmetric \ | |||
| else '%s <- %s' % (self.node_name(self.node_type_row), | |||
| self.node_name(self.node_type_column))) | |||
| return s | |||
| def repr_indented(self): | |||
| s = ' - %s' % self.name | |||
| for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||
| for r in rels: | |||
| if 'display' in r.hints and not r.hints['display']: | |||
| continue | |||
| s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||
| (self.node_name(node_type_row), self.node_name(node_type_column))) | |||
| for r in self.relation_types: | |||
| s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||
| if (r.adjacency_matrix is not None \ | |||
| and r.adjacency_matrix_backward is not None) \ | |||
| or self.is_symmetric \ | |||
| else '%s <- %s' % (self.node_name(self.node_type_row), | |||
| self.node_name(self.node_type_column))) | |||
| return s | |||
| @@ -22,7 +22,6 @@ class DecodeLayer(torch.nn.Module): | |||
| input_dim: List[int], | |||
| data: PreparedData, | |||
| keep_prob: float = 1., | |||
| decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, | |||
| activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | |||
| **kwargs) -> None: | |||
| @@ -37,26 +36,25 @@ class DecodeLayer(torch.nn.Module): | |||
| if not isinstance(data, PreparedData): | |||
| raise TypeError('data must be an instance of PreparedData') | |||
| if not isinstance(decoder_class, type) and \ | |||
| not isinstance(decoder_class, dict): | |||
| raise TypeError('decoder_class must be a Type or a Dict') | |||
| if not isinstance(decoder_class, dict): | |||
| decoder_class = { k: decoder_class \ | |||
| for k in data.relation_types.keys() } | |||
| self.input_dim = input_dim | |||
| self.output_dim = 1 | |||
| self.data = data | |||
| self.keep_prob = keep_prob | |||
| self.decoder_class = decoder_class | |||
| self.activation = activation | |||
| self.decoders = None | |||
| self.build() | |||
| def build(self) -> None: | |||
| self.decoders = {} | |||
| self.decoders = [] | |||
| for fam in self.data.relation_families: | |||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
| for r in rels: | |||
| pass | |||
| dec = fam.decoder_class() | |||
| self.decoders.append(dec) | |||
| for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||
| if len(rels) == 0: | |||
| @@ -39,7 +39,7 @@ class PreparedRelationType(RelationTypeBase): | |||
| @dataclass | |||
| class PreparedRelationFamily(RelationFamilyBase): | |||
| relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
| relation_types: List[PreparedRelationType] | |||
| @dataclass | |||
| @@ -110,36 +110,81 @@ def prepare_adj_mat(adj_mat: torch.Tensor, | |||
| return adj_mat_train, edges_pos, edges_neg | |||
| def prepare_relation_type(r: RelationType, | |||
| def prep_rel_one_node_type(r: RelationType, | |||
| ratios: TrainValTest) -> PreparedRelationType: | |||
| adj_mat = r.adjacency_matrix | |||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
| print('adj_mat_train:', adj_mat_train) | |||
| adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
| adj_mat_train, None, edges_pos, edges_neg) | |||
| def prep_rel_two_node_types_sym(r: RelationType, | |||
| ratios: TrainValTest) -> PreparedRelationType: | |||
| adj_mat = r.adjacency_matrix | |||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
| return PreparedRelationType(r.name, r.node_type_row, | |||
| r.node_type_column, | |||
| norm_adj_mat_two_node_types(adj_mat_train), | |||
| norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | |||
| edges_pos, edges_neg) | |||
| def prep_rel_two_node_types_asym(r: RelationType, | |||
| ratios: TrainValTest) -> PreparedRelationType: | |||
| if r.adjacency_matrix is not None: | |||
| adj_mat_train, edges_pos, edges_neg =\ | |||
| prepare_adj_mat(r.adjacency_matrix, ratios) | |||
| else: | |||
| adj_mat_train, edges_pos, edges_neg = \ | |||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||
| if r.adjacency_matrix_backward is not None: | |||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
| prepare_adj_mat(r.adjacency_matrix_backward, ratios) | |||
| else: | |||
| adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
| None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||
| edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) | |||
| edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) | |||
| return PreparedRelationType(r.name, r.node_type_row, | |||
| r.node_type_column, | |||
| norm_adj_mat_two_node_types(adj_mat_train), | |||
| norm_adj_mat_two_node_types(adj_mat_back_train), | |||
| edges_pos, edges_neg) | |||
| def prepare_relation_type(r: RelationType, | |||
| ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: | |||
| if not isinstance(r, RelationType): | |||
| raise ValueError('r must be a RelationType') | |||
| if not isinstance(ratios, TrainValTest): | |||
| raise ValueError('ratios must be a TrainValTest') | |||
| adj_mat = r.adjacency_matrix | |||
| adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
| print('adj_mat_train:', adj_mat_train) | |||
| if r.node_type_row == r.node_type_column: | |||
| adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||
| return prep_rel_one_node_type(r, ratios) | |||
| elif is_symmetric: | |||
| return prep_rel_two_node_types_sym(r, ratios) | |||
| else: | |||
| adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||
| return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
| adj_mat_train, r.two_way, edges_pos, edges_neg) | |||
| return prep_rel_two_node_types_asym(r, ratios) | |||
| def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||
| relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||
| (fam.node_type_column, fam.node_type_row): [] } | |||
| relation_types = [] | |||
| for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
| for r in rels: | |||
| relation_types[node_type_row, node_type_column].append( | |||
| prepare_relation_type(r, ratios)) | |||
| for r in fam.relation_types: | |||
| relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) | |||
| return PreparedRelationFamily(fam.data, fam.name, | |||
| fam.node_type_row, fam.node_type_column, | |||
| @@ -103,7 +103,7 @@ def test_decagon_layer_04(): | |||
| in_layer = OneHotInputLayer(d) | |||
| multi_dgca = MultiDGCA([10], 32, | |||
| [r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||
| [r.adjacency_matrix for r in fam.relation_types], | |||
| keep_prob=1., activation=lambda x: x) | |||
| d_layer = DecagonLayer(in_layer.output_dim, 32, d, | |||
| @@ -147,7 +147,7 @@ def test_decagon_layer_05(): | |||
| in_layer = OneHotInputLayer(d) | |||
| multi_dgca = MultiDGCA([100, 100], 32, | |||
| [r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||
| [r.adjacency_matrix for r in fam.relation_types], | |||
| keep_prob=1., activation=lambda x: x) | |||
| d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | |||
| @@ -107,7 +107,7 @@ def test_prepare_relation_type_01(): | |||
| adj_mat = (torch.rand((10, 10)) > .5) | |||
| r = RelationType('Test', 0, 0, adj_mat, True) | |||
| ratios = TrainValTest(.8, .1, .1) | |||
| _ = prepare_relation_type(r, ratios) | |||
| _ = prepare_relation_type(r, ratios, False) | |||