|
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- from .matrix import NodeType
- import torch
- from collections import defaultdict
-
-
class AdjListRelationType(object):
    """A named relation between two node types, stored as an adjacency list.

    ``adjacency_list`` is a tensor whose column 0 holds indices of the row
    node type and column 1 holds indices of the column node type (the
    column-swap fallback in ``get_adjacency_list`` relies on this layout).
    """

    def __init__(self, name, node_type_row, node_type_column,
                 adjacency_list, adjacency_list_transposed=None):
        """Store the relation; no validation is performed here.

        Args:
            name: Human-readable name of the relation.
            node_type_row: Node type whose indices appear in column 0.
            node_type_column: Node type whose indices appear in column 1.
            adjacency_list: Tensor of index pairs for (row, column) order.
            adjacency_list_transposed: Optional precomputed list for the
                reversed (column, row) order. When omitted, it is derived on
                demand by swapping the first two columns of
                ``adjacency_list``.
        """
        self.name = name
        self.node_type_row = node_type_row
        self.node_type_column = node_type_column
        self.adjacency_list = adjacency_list
        self.adjacency_list_transposed = adjacency_list_transposed

    def get_adjacency_list(self, node_type_row, node_type_column):
        """Return the adjacency list oriented as (node_type_row, node_type_column).

        If the requested orientation matches the stored one, the stored list
        is returned as-is. If it is the reverse orientation, the precomputed
        transposed list is returned when available; otherwise a new tensor
        with columns 0 and 1 swapped is built on each call. When the relation
        connects a node type to itself, the first branch always matches, so
        the untransposed list is returned.

        Raises:
            ValueError: if the requested types are not the two endpoints of
                this relation.
        """
        if self.node_type_row == node_type_row and \
                self.node_type_column == node_type_column:
            return self.adjacency_list

        if self.node_type_row == node_type_column and \
                self.node_type_column == node_type_row:
            if self.adjacency_list_transposed is not None:
                return self.adjacency_list_transposed
            # Swap columns 0 and 1. The index must live on the same device as
            # the adjacency list; a CPU-only index (as torch.LongTensor would
            # create) breaks index_select for CUDA tensors.
            swap = torch.tensor([1, 0], dtype=torch.long,
                                device=self.adjacency_list.device)
            return torch.index_select(self.adjacency_list, 1, swap)

        raise ValueError('Specified row/column types do not correspond to this relation')
-
-
- def _verify_adjacency_list(adjacency_list, node_count_row, node_count_col):
- assert isinstance(adjacency_list, torch.Tensor)
- assert len(adjacency_list.shape) == 2
- assert torch.all(adjacency_list[:, 0] >= 0)
- assert torch.all(adjacency_list[:, 0] < node_count_row)
- assert torch.all(adjacency_list[:, 1] >= 0)
- assert torch.all(adjacency_list[:, 1] < node_count_col)
-
-
class AdjListData(object):
    """Registry of node types and the adjacency-list relations between them.

    Node types are kept in insertion order in ``node_types``; a type is
    referred to by its position in that list. Relations are grouped in
    ``relation_types`` under the ``(row_type_index, col_type_index)`` pair
    they connect, with any number of relations per pair.
    """

    def __init__(self):
        self.node_types = []
        # (row type index, column type index) -> list of AdjListRelationType
        self.relation_types = defaultdict(list)

    def add_node_type(self, name, count):
        """Append a node type called ``name`` that has ``count`` nodes."""
        self.node_types.append(NodeType(name, count))

    def add_relation_type(self, name, node_type_row, node_type_col,
                          adjacency_list, adjacency_list_transposed=None):
        """Validate an adjacency list and register it as a new relation.

        Args:
            name: Name of the relation.
            node_type_row: Index into ``node_types`` for column 0 of the list.
            node_type_col: Index into ``node_types`` for column 1 of the list.
            adjacency_list: Index-pair tensor in (row, column) orientation.
            adjacency_list_transposed: Optional tensor in the reversed
                (column, row) orientation.

        Raises:
            AssertionError: if a type index is out of range or a list fails
                validation.
        """
        type_total = len(self.node_types)
        assert 0 <= node_type_row < type_total
        assert 0 <= node_type_col < type_total

        rows_expected = self.node_types[node_type_row].count
        cols_expected = self.node_types[node_type_col].count

        _verify_adjacency_list(adjacency_list, rows_expected, cols_expected)
        if adjacency_list_transposed is not None:
            # The reversed list swaps orientation, so its bounds swap too.
            _verify_adjacency_list(adjacency_list_transposed,
                                   cols_expected, rows_expected)

        bucket = self.relation_types[node_type_row, node_type_col]
        bucket.append(AdjListRelationType(name, node_type_row, node_type_col,
                                          adjacency_list,
                                          adjacency_list_transposed))
|