| @@ -21,7 +21,7 @@ class RelationType(object): | |||||
| node_type_row: int | node_type_row: int | ||||
| node_type_column: int | node_type_column: int | ||||
| adjacency_matrix: torch.Tensor | adjacency_matrix: torch.Tensor | ||||
| is_autogenerated: bool | |||||
| is_autogenerated: bool = False | |||||
| class Data(object): | class Data(object): | ||||
| @@ -6,17 +6,90 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import scipy.sparse as sp | import scipy.sparse as sp | ||||
| import torch | |||||
def norm_adj_mat_one_node_type(adj):
    """Normalize a square adjacency matrix after adding self-loops.

    The input is converted to a SciPy COO matrix, the identity is added,
    and the result is multiplied on both sides by a diagonal matrix holding
    the inverse square roots of the row sums of the self-looped matrix.

    Args:
        adj: Anything accepted by ``scipy.sparse.coo_matrix``; must be square.

    Returns:
        A SciPy sparse matrix with the normalized adjacency values.

    Raises:
        AssertionError: If ``adj`` is not square.
    """
    adj = sp.coo_matrix(adj)
    n_rows, n_cols = adj.shape
    assert n_rows == n_cols

    with_self_loops = adj + sp.eye(n_rows)
    row_sums = np.array(with_self_loops.sum(1))
    inv_sqrt_degrees = sp.diags(np.power(row_sums, -0.5).flatten())

    # (A' D)^T D, where A' = A + I and D = diag(rowsum(A') ** -0.5).
    scaled = with_self_loops.dot(inv_sqrt_degrees)
    return scaled.transpose().dot(inv_sqrt_degrees)
def add_eye_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    """Add the identity matrix to a square sparse COO tensor.

    The diagonal entries are appended to the coalesced indices/values of
    ``adj_mat`` rather than merged, so the returned tensor is not coalesced:
    a position already present on the diagonal appears twice until the
    caller coalesces or densifies the result.

    Args:
        adj_mat: Square, two-dimensional sparse tensor.

    Returns:
        A sparse COO tensor of the same shape representing
        ``adj_mat + I``.

    Raises:
        ValueError: If ``adj_mat`` is not a ``torch.Tensor``, is not sparse,
            or is not a square matrix.
    """
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')

    if not adj_mat.is_sparse:
        raise ValueError('adj_mat must be sparse')

    if len(adj_mat.shape) != 2 or \
            adj_mat.shape[0] != adj_mat.shape[1]:
        raise ValueError('adj_mat must be a square matrix')

    adj_mat = adj_mat.coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()
    n = adj_mat.shape[0]

    # Build the identity entries on the same device as the input so that
    # the concatenations below also work for non-CPU tensors.
    diag = torch.arange(n, dtype=indices.dtype, device=indices.device)
    eye_indices = torch.stack((diag, diag), 0)
    eye_values = torch.ones(n, dtype=values.dtype, device=values.device)

    indices = torch.cat((indices, eye_indices), 1)
    values = torch.cat((values, eye_values), 0)

    return torch.sparse_coo_tensor(indices=indices, values=values,
                                   size=adj_mat.shape)
def norm_adj_mat_one_node_type_sparse(adj_mat: torch.Tensor) -> torch.Tensor:
    """Row-normalize a square sparse adjacency matrix after adding self-loops.

    The identity is added via ``add_eye_sparse``, the per-row sums of the
    resulting values are accumulated, and every stored value is divided by
    the sum of its row.

    Args:
        adj_mat: Square, two-dimensional sparse tensor.

    Returns:
        A sparse COO tensor of the same shape whose values are in the default
        floating point dtype.

    Raises:
        ValueError: If ``adj_mat`` is not a square matrix, or if
            ``add_eye_sparse`` rejects it (not a tensor / not sparse).
    """
    if len(adj_mat.shape) != 2 or \
            adj_mat.shape[0] != adj_mat.shape[1]:
        raise ValueError('adj_mat must be a square matrix')

    # Coalesce after adding the identity so duplicated diagonal entries are
    # merged before the row sums are computed.
    adj_mat = add_eye_sparse(adj_mat).coalesce()
    indices = adj_mat.indices()
    values = adj_mat.values()
    rows = indices[0]

    # Accumulate on the same device as the values to avoid a device mismatch.
    degrees = torch.zeros(adj_mat.shape[0], device=values.device)
    degrees = degrees.index_add(0, rows, values.to(degrees.dtype))
    values = values.to(degrees.dtype) / degrees[rows]

    return torch.sparse_coo_tensor(indices=indices, values=values,
                                   size=adj_mat.shape)
def norm_adj_mat_one_node_type_dense(adj_mat: torch.Tensor) -> torch.Tensor:
    """Row-normalize a square dense adjacency matrix after adding self-loops.

    The identity (in the input's dtype) is added to ``adj_mat`` and each row
    is divided by its sum.

    Args:
        adj_mat: Square, two-dimensional dense tensor.

    Returns:
        A ``torch.float32`` tensor of the same shape.

    Raises:
        ValueError: If ``adj_mat`` is not a ``torch.Tensor``, is sparse,
            or is not a square matrix.
    """
    if not isinstance(adj_mat, torch.Tensor):
        raise ValueError('adj_mat must be a torch.Tensor')

    if adj_mat.is_sparse:
        raise ValueError('adj_mat must be dense')

    if len(adj_mat.shape) != 2 or \
            adj_mat.shape[0] != adj_mat.shape[1]:
        raise ValueError('adj_mat must be a square matrix')

    # Create the identity on the input's device so the addition also works
    # for tensors that do not live on the CPU.
    eye = torch.eye(adj_mat.shape[0], dtype=adj_mat.dtype,
                    device=adj_mat.device)
    adj_mat = adj_mat + eye

    # Column vector of row sums; broadcasting divides each row by its sum.
    degrees = adj_mat.sum(1).view(-1, 1).to(torch.float32)
    return adj_mat.to(degrees.dtype) / degrees
def norm_adj_mat_one_node_type(adj_mat):
    """Normalize a single-node-type adjacency matrix.

    Dispatches on ``adj_mat.is_sparse`` to the sparse or the dense
    implementation and returns whatever that implementation returns.
    """
    normalize = (norm_adj_mat_one_node_type_sparse if adj_mat.is_sparse
                 else norm_adj_mat_one_node_type_dense)
    return normalize(adj_mat)
| # def norm_adj_mat_one_node_type(adj): | |||||
| # adj = sp.coo_matrix(adj) | |||||
| # assert adj.shape[0] == adj.shape[1] | |||||
| # adj_ = adj + sp.eye(adj.shape[0]) | |||||
| # rowsum = np.array(adj_.sum(1)) | |||||
| # degree_mat_inv_sqrt = np.power(rowsum, -0.5).flatten() | |||||
| # degree_mat_inv_sqrt = sp.diags(degree_mat_inv_sqrt) | |||||
| # adj_normalized = adj_.dot(degree_mat_inv_sqrt).transpose().dot(degree_mat_inv_sqrt) | |||||
| # return adj_normalized | |||||
| def norm_adj_mat_two_node_types(adj): | def norm_adj_mat_two_node_types(adj): | ||||
| @@ -11,7 +11,9 @@ from typing import Any, \ | |||||
| List, \ | List, \ | ||||
| Tuple, \ | Tuple, \ | ||||
| Dict | Dict | ||||
| from .data import NodeType | |||||
| from .data import NodeType, \ | |||||
| RelationType, \ | |||||
| Data | |||||
| from collections import defaultdict | from collections import defaultdict | ||||
| from .normalize import norm_adj_mat_one_node_type, \ | from .normalize import norm_adj_mat_one_node_type, \ | ||||
| norm_adj_mat_two_node_types | norm_adj_mat_two_node_types | ||||
| @@ -73,7 +75,7 @@ def train_val_test_split_edges(edges: torch.Tensor, | |||||
| return TrainValTest(edges_train, edges_val, edges_test) | return TrainValTest(edges_train, edges_val, edges_test) | ||||
| def get_edges_and_degrees(adj_mat): | |||||
| def get_edges_and_degrees(adj_mat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| if adj_mat.is_sparse: | if adj_mat.is_sparse: | ||||
| adj_mat = adj_mat.coalesce() | adj_mat = adj_mat.coalesce() | ||||
| degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64) | degrees = torch.zeros(adj_mat.shape[1], dtype=torch.int64) | ||||
| @@ -109,23 +111,35 @@ def prepare_adj_mat(adj_mat: torch.Tensor, | |||||
| return adj_mat_train, edges_pos, edges_neg | return adj_mat_train, edges_pos, edges_neg | ||||
def prepare_relation_type(r: RelationType,
                          ratios: TrainValTest) -> PreparedRelationType:
    """Prepare one relation type for training.

    The relation's adjacency matrix is passed to ``prepare_adj_mat`` together
    with ``ratios``; the training adjacency matrix it returns is then
    normalized with the one-node-type or two-node-types normalizer, depending
    on whether the row and column node types are equal.

    Args:
        r: The relation type to prepare.
        ratios: Train/validation/test ratios forwarded to ``prepare_adj_mat``.

    Returns:
        A ``PreparedRelationType`` carrying the relation's name and node
        types, the normalized training adjacency matrix, and the positive
        and negative edges returned by ``prepare_adj_mat``.

    Raises:
        ValueError: If ``r`` is not a ``RelationType`` or ``ratios`` is not
            a ``TrainValTest``.
    """
    if not isinstance(r, RelationType):
        raise ValueError('r must be a RelationType')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    adj_mat = r.adjacency_matrix
    adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)

    if r.node_type_row == r.node_type_column:
        adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
    else:
        adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
                                adj_mat_train, edges_pos, edges_neg)
def prepare_training(data: Data, ratios: TrainValTest = None) -> PreparedData:
    """Prepare every relation type in ``data`` for training.

    Each relation is passed through ``prepare_relation_type`` and the results
    are grouped in a nested mapping indexed first by row node type and then
    by column node type.

    Args:
        data: The data set whose relation types should be prepared.
        ratios: Train/validation/test ratios forwarded to
            ``prepare_relation_type``. The parameter has a default only to
            keep the one-argument call form importable; a ``TrainValTest``
            must actually be supplied.

    Returns:
        A ``PreparedData`` built from ``data.node_types`` and the prepared
        relation types.

    Raises:
        ValueError: If ``data`` is not a ``Data`` or ``ratios`` is not a
            ``TrainValTest``.
    """
    if not isinstance(data, Data):
        raise ValueError('data must be of class Data')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    # NOTE(review): the container type of data.relation_types is not visible
    # here. If it is a dict keyed by (row, column), iterating it directly
    # would yield keys only and the unpacking below would fail, so use
    # .items() in that case; otherwise iterate it as-is — confirm.
    source = data.relation_types
    pairs = source.items() if isinstance(source, dict) else source

    relation_types = defaultdict(lambda: defaultdict(list))
    for (node_type_row, node_type_column), rels in pairs:
        for r in rels:
            relation_types[node_type_row][node_type_column].append(
                prepare_relation_type(r, ratios))

    return PreparedData(data.node_types, relation_types)
| @@ -0,0 +1,95 @@ | |||||
| from icosagon.normalize import add_eye_sparse, \ | |||||
| norm_adj_mat_one_node_type_sparse, \ | |||||
| norm_adj_mat_one_node_type_dense, \ | |||||
| norm_adj_mat_one_node_type | |||||
| import decagon_pytorch.normalize | |||||
| import torch | |||||
| import pytest | |||||
| import numpy as np | |||||
def test_add_eye_sparse_01():
    """add_eye_sparse agrees with adding a dense identity matrix."""
    n = 10
    dense = torch.rand((n, n))
    sparse = dense.to_sparse()

    dense += torch.eye(n)
    result = add_eye_sparse(sparse)

    assert torch.all(result.to_dense() == dense)
def test_add_eye_sparse_02():
    """A non-square sparse matrix is rejected with ValueError."""
    non_square = torch.rand((10, 20)).to_sparse()

    with pytest.raises(ValueError):
        add_eye_sparse(non_square)
def test_add_eye_sparse_03():
    """A dense tensor is rejected with ValueError."""
    dense = torch.rand((10, 10))

    with pytest.raises(ValueError):
        add_eye_sparse(dense)
def test_add_eye_sparse_04():
    """A NumPy array (not a torch.Tensor) is rejected with ValueError."""
    not_a_tensor = np.random.rand(10, 10)

    with pytest.raises(ValueError):
        add_eye_sparse(not_a_tensor)
def test_norm_adj_mat_one_node_type_sparse_01():
    """Smoke test: the sparse normalizer runs on a random boolean matrix."""
    mask = torch.rand((10, 10)) > .5
    norm_adj_mat_one_node_type_sparse(mask.to_sparse())
def test_norm_adj_mat_one_node_type_sparse_02():
    """The sparse and dense normalizers agree on the same random matrix."""
    mask = torch.rand((10, 10)) > .5

    from_sparse = norm_adj_mat_one_node_type_sparse(mask.to_sparse())
    from_dense = norm_adj_mat_one_node_type_dense(mask)

    assert torch.all(from_sparse.to_dense() == from_dense)
def test_norm_adj_mat_one_node_type_dense_01():
    """Smoke test: the dense normalizer runs on a random boolean matrix."""
    mask = torch.rand((10, 10)) > .5
    norm_adj_mat_one_node_type_dense(mask)
def test_norm_adj_mat_one_node_type_dense_02():
    """The dense normalizer row-normalizes a small hand-checked matrix.

    The expected values divide each row of ``adj_mat + I`` by its row sum,
    which is what ``norm_adj_mat_one_node_type_dense`` computes. The result
    is compared as torch tensors with a tolerance, since ``torch.all`` does
    not accept NumPy arrays and the values are float32.
    """
    adj_mat = torch.tensor([
        [0, 1, 1, 0],  # row sum with self-loop: 3
        [1, 0, 1, 0],  # 3
        [1, 1, 0, 1],  # 4
        [0, 0, 1, 0],  # 2
    ])
    expect = torch.tensor([
        [1/3, 1/3, 1/3, 0],
        [1/3, 1/3, 1/3, 0],
        [1/4, 1/4, 1/4, 1/4],
        [0, 0, 1/2, 1/2],
    ], dtype=torch.float32)

    res = norm_adj_mat_one_node_type_dense(adj_mat)

    assert torch.allclose(res, expect)
@pytest.mark.skip
def test_norm_adj_mat_one_node_type_dense_03():
    """Compare the decagon_pytorch normalizer with the dense one (skipped)."""
    mask = torch.rand((10, 10)) > .5

    adj_mat_dec = decagon_pytorch.normalize.norm_adj_mat_one_node_type(mask)
    adj_mat_ico = norm_adj_mat_one_node_type_dense(mask)

    # Bring both results to dense NumPy form before comparing.
    adj_mat_dec = adj_mat_dec.todense()
    adj_mat_ico = adj_mat_ico.detach().cpu().numpy()

    print('adj_mat_dec:', adj_mat_dec)
    print('adj_mat_ico:', adj_mat_ico)
    assert np.all(adj_mat_dec == adj_mat_ico)
| @@ -7,11 +7,13 @@ | |||||
| from icosagon.trainprep import TrainValTest, \ | from icosagon.trainprep import TrainValTest, \ | ||||
| train_val_test_split_edges, \ | train_val_test_split_edges, \ | ||||
| get_edges_and_degrees, \ | get_edges_and_degrees, \ | ||||
| prepare_adj_mat | |||||
| prepare_adj_mat, \ | |||||
| prepare_relation_type | |||||
| import torch | import torch | ||||
| import pytest | import pytest | ||||
| import numpy as np | import numpy as np | ||||
| from itertools import chain | from itertools import chain | ||||
| from icosagon.data import RelationType | |||||
| def test_train_val_test_split_edges_01(): | def test_train_val_test_split_edges_01(): | ||||
| @@ -100,17 +102,23 @@ def test_prepare_adj_mat_02(): | |||||
| assert len(edges.shape) == 2 | assert len(edges.shape) == 2 | ||||
| assert edges.shape[1] == 2 | assert edges.shape[1] == 2 | ||||
| # def prepare_adj_mat(adj_mat: torch.Tensor, | |||||
| # ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]: | |||||
| # | |||||
| # degrees = adj_mat.sum(0) | |||||
| # edges_pos = torch.nonzero(adj_mat) | |||||
| # | |||||
| # neg_neighbors = fixed_unigram_candidate_sampler(edges_pos[:, 1], | |||||
| # len(edges), degrees, 0.75) | |||||
| # edges_neg = torch.cat((edges_pos[:, 0], neg_neighbors.view(-1, 1)), 1) | |||||
def test_prepare_relation_type_01():
    """Smoke test: prepare_relation_type runs on a random same-type relation."""
    mask = torch.rand((10, 10)) > .5
    relation = RelationType('Test', 0, 0, mask)
    split = TrainValTest(.8, .1, .1)

    prepare_relation_type(relation, split)
| # def prepare_relation(r, ratios): | |||||
| # adj_mat = r.adjacency_matrix | |||||
| # adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat) | |||||
| # | # | ||||
| # edges_pos = train_val_test_split_edges(edges_pos, ratios) | |||||
| # edges_neg = train_val_test_split_edges(edges_neg, ratios) | |||||
| # if r.node_type_row == r.node_type_column: | |||||
| # adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||||
| # else: | |||||
| # adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||||
| # | # | ||||
| # return edges_pos, edges_neg | |||||
| # return PreparedRelation(r.name, r.node_type_row, r.node_type_column, | |||||
| # adj_mat_train, edges_pos, edges_neg) | |||||