IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
浏览代码

Store relation types in a list instead of a dictionary and related fixes.

master
Stanislaw Adaszewski 4 年前
父节点
当前提交
832f620a78
共有 6 个文件被更改,包括 154 次插入80 次删除
  1. +36
    -12
      src/icosagon/convlayer.py
  2. +44
    -37
      src/icosagon/data.py
  3. +9
    -11
      src/icosagon/declayer.py
  4. +62
    -17
      src/icosagon/trainprep.py
  5. +2
    -2
      tests/icosagon/test_convlayer.py
  6. +1
    -1
      tests/icosagon/test_trainprep.py

+ 36
- 12
src/icosagon/convlayer.py 查看文件

@@ -50,22 +50,46 @@ class DecagonLayer(torch.nn.Module):
self.next_layer_repr = None
self.build()
def build_family(self, fam):
for (node_type_row, node_type_column), rels in fam.relation_types.items():
if len(rels) == 0:
continue
convolutions = []
def build_fam_one_node_type(self, fam):
convolutions = []
for r in fam.relation_types:
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
self.output_dim[fam.node_type_row], r.adjacency_matrix,
self.keep_prob, self.rel_activation)
convolutions.append(conv)
self.next_layer_repr[fam.node_type_row].append(
Convolutions(fam.node_type_column, convolutions))
def build_fam_two_node_types(self, fam) -> None:
convolutions_row = []
convolutions_column = []
for r in fam.relation_types:
if r.adjacency_matrix is not None:
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column],
self.output_dim[fam.node_type_row], r.adjacency_matrix,
self.keep_prob, self.rel_activation)
convolutions_row.append(conv)
for r in rels:
conv = DropoutGraphConvActivation(self.input_dim[node_type_column],
self.output_dim[node_type_row], r.adjacency_matrix,
if r.adjacency_matrix_backward is not None:
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row],
self.output_dim[fam.node_type_column], r.adjacency_matrix_backward,
self.keep_prob, self.rel_activation)
convolutions_column.append(conv)
self.next_layer_repr[fam.node_type_row].append(
Convolutions(fam.node_type_column, convolutions_row))
convolutions.append(conv)
self.next_layer_repr[fam.node_type_column].append(
Convolutions(fam.node_type_row, convolutions_column))
self.next_layer_repr[node_type_row].append(
Convolutions(node_type_column, convolutions))
def build_family(self, fam) -> None:
if fam.node_type_row == fam.node_type_column:
self.build_fam_one_node_type(fam)
else:
self.build_fam_two_node_types(fam)
def build(self):
self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ]


+ 44
- 37
src/icosagon/data.py 查看文件

@@ -49,12 +49,12 @@ class RelationTypeBase(object):
node_type_row: int
node_type_column: int
adjacency_matrix: torch.Tensor
two_way: bool
adjacency_matrix_backward: torch.Tensor
@dataclass
class RelationType(RelationTypeBase):
hints: Dict[str, Any] = field(default_factory=dict)
pass
@dataclass
@@ -69,7 +69,7 @@ class RelationFamilyBase(object):
@dataclass
class RelationFamily(RelationFamilyBase):
relation_types: Dict[Tuple[int, int], List[RelationType]] = None
relation_types: List[RelationType] = None
def __post_init__(self) -> None:
if not self.is_symmetric and \
@@ -77,18 +77,18 @@ class RelationFamily(RelationFamilyBase):
self.decoder_class != BilinearDecoder:
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only')
self.relation_types = { (self.node_type_row, self.node_type_column): [],
(self.node_type_column, self.node_type_row): [] }
self.relation_types = []
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int,
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None,
two_way: bool = True) -> None:
def add_relation_type(self,
name: str, node_type_row: int, node_type_column: int,
adjacency_matrix: torch.Tensor,
adjacency_matrix_backward: torch.Tensor = None) -> None:
name = str(name)
node_type_row = int(node_type_row)
node_type_column = int(node_type_column)
if (node_type_row, node_type_column) not in self.relation_types:
if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column):
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family')
if node_type_row < 0 or node_type_row >= len(self.data.node_types):
@@ -97,21 +97,33 @@ class RelationFamily(RelationFamilyBase):
if node_type_column < 0 or node_type_column >= len(self.data.node_types):
raise ValueError('node_type_column outside of the valid range of node types')
if not isinstance(adjacency_matrix, torch.Tensor):
if adjacency_matrix is None and adjacency_matrix_backward is None:
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None')
if adjacency_matrix is not None and \
not isinstance(adjacency_matrix, torch.Tensor):
raise ValueError('adjacency_matrix must be a torch.Tensor')
# if isinstance(adjacency_matrix_backward, str) and \
# adjacency_matrix_backward == 'symmetric':
# if self.is_symmetric:
# adjacency_matrix_backward = None
# else:
# adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
if adjacency_matrix_backward is not None \
and not isinstance(adjacency_matrix_backward, torch.Tensor):
raise ValueError('adjacency_matrix_backward must be a torch.Tensor')
if adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
self.data.node_types[node_type_column].count):
if adjacency_matrix is not None and \
adjacency_matrix.shape != (self.data.node_types[node_type_row].count,
self.data.node_types[node_type_column].count):
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)')
if adjacency_matrix_backward is not None and \
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count,
self.data.node_types[node_type_row].count):
raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)')
raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)')
if node_type_row == node_type_column and \
adjacency_matrix_backward is not None:
@@ -125,19 +137,12 @@ class RelationFamily(RelationFamilyBase):
adjacency_matrix.transpose(0, 1))):
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric')
two_way = bool(two_way)
if node_type_row != node_type_column and two_way:
print('%d != %d' % (node_type_row, node_type_column))
if adjacency_matrix_backward is None:
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
self.relation_types[node_type_column, node_type_row].append(
RelationType(name, node_type_column, node_type_row,
adjacency_matrix_backward, two_way, { 'display': False }))
if self.is_symmetric and node_type_row != node_type_column:
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
self.relation_types[node_type_row, node_type_column].append(
RelationType(name, node_type_row, node_type_column,
adjacency_matrix, two_way))
self.relation_types.append(RelationType(name,
node_type_row, node_type_column,
adjacency_matrix, adjacency_matrix_backward))
def node_name(self, index):
return self.data.node_types[index].name
@@ -145,24 +150,26 @@ class RelationFamily(RelationFamilyBase):
def __repr__(self):
s = 'Relation family %s' % self.name
for (node_type_row, node_type_column), rels in self.relation_types.items():
for r in rels:
if 'display' in r.hints and not r.hints['display']:
continue
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
(self.node_name(node_type_row), self.node_name(node_type_column)))
for r in self.relation_types:
s += '\n - %s%s' % (r.name, ' (two-way)' \
if (r.adjacency_matrix is not None \
and r.adjacency_matrix_backward is not None) \
or self.is_symmetric \
else '%s <- %s' % (self.node_name(self.node_type_row),
self.node_name(self.node_type_column)))
return s
def repr_indented(self):
s = ' - %s' % self.name
for (node_type_row, node_type_column), rels in self.relation_types.items():
for r in rels:
if 'display' in r.hints and not r.hints['display']:
continue
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \
(self.node_name(node_type_row), self.node_name(node_type_column)))
for r in self.relation_types:
s += '\n - %s%s' % (r.name, ' (two-way)' \
if (r.adjacency_matrix is not None \
and r.adjacency_matrix_backward is not None) \
or self.is_symmetric \
else '%s <- %s' % (self.node_name(self.node_type_row),
self.node_name(self.node_type_column)))
return s


+ 9
- 11
src/icosagon/declayer.py 查看文件

@@ -22,7 +22,6 @@ class DecodeLayer(torch.nn.Module):
input_dim: List[int],
data: PreparedData,
keep_prob: float = 1.,
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
**kwargs) -> None:
@@ -37,26 +36,25 @@ class DecodeLayer(torch.nn.Module):
if not isinstance(data, PreparedData):
raise TypeError('data must be an instance of PreparedData')
if not isinstance(decoder_class, type) and \
not isinstance(decoder_class, dict):
raise TypeError('decoder_class must be a Type or a Dict')
if not isinstance(decoder_class, dict):
decoder_class = { k: decoder_class \
for k in data.relation_types.keys() }
self.input_dim = input_dim
self.output_dim = 1
self.data = data
self.keep_prob = keep_prob
self.decoder_class = decoder_class
self.activation = activation
self.decoders = None
self.build()
def build(self) -> None:
self.decoders = {}
self.decoders = []
for fam in self.data.relation_families:
for (node_type_row, node_type_column), rels in fam.relation_types.items():
for r in rels:
pass
dec = fam.decoder_class()
self.decoders.append(dec)
for (node_type_row, node_type_column), rels in self.data.relation_types.items():
if len(rels) == 0:


+ 62
- 17
src/icosagon/trainprep.py 查看文件

@@ -39,7 +39,7 @@ class PreparedRelationType(RelationTypeBase):
@dataclass
class PreparedRelationFamily(RelationFamilyBase):
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]]
relation_types: List[PreparedRelationType]
@dataclass
@@ -110,36 +110,81 @@ def prepare_adj_mat(adj_mat: torch.Tensor,
return adj_mat_train, edges_pos, edges_neg
def prepare_relation_type(r: RelationType,
def prep_rel_one_node_type(r: RelationType,
ratios: TrainValTest) -> PreparedRelationType:
adj_mat = r.adjacency_matrix
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
print('adj_mat_train:', adj_mat_train)
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
adj_mat_train, None, edges_pos, edges_neg)
def prep_rel_two_node_types_sym(r: RelationType,
ratios: TrainValTest) -> PreparedRelationType:
adj_mat = r.adjacency_matrix
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
return PreparedRelationType(r.name, r.node_type_row,
r.node_type_column,
norm_adj_mat_two_node_types(adj_mat_train),
norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)),
edges_pos, edges_neg)
def prep_rel_two_node_types_asym(r: RelationType,
ratios: TrainValTest) -> PreparedRelationType:
if r.adjacency_matrix is not None:
adj_mat_train, edges_pos, edges_neg =\
prepare_adj_mat(r.adjacency_matrix, ratios)
else:
adj_mat_train, edges_pos, edges_neg = \
None, torch.zeros((0, 2)), torch.zeros((0, 2))
if r.adjacency_matrix_backward is not None:
adj_mat_back_train, edges_back_pos, edges_back_neg = \
prepare_adj_mat(r.adjacency_matrix_backward, ratios)
else:
adj_mat_back_train, edges_back_pos, edges_back_neg = \
None, torch.zeros((0, 2)), torch.zeros((0, 2))
edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0)
edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0)
return PreparedRelationType(r.name, r.node_type_row,
r.node_type_column,
norm_adj_mat_two_node_types(adj_mat_train),
norm_adj_mat_two_node_types(adj_mat_back_train),
edges_pos, edges_neg)
def prepare_relation_type(r: RelationType,
ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType:
if not isinstance(r, RelationType):
raise ValueError('r must be a RelationType')
if not isinstance(ratios, TrainValTest):
raise ValueError('ratios must be a TrainValTest')
adj_mat = r.adjacency_matrix
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)
print('adj_mat_train:', adj_mat_train)
if r.node_type_row == r.node_type_column:
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train)
return prep_rel_one_node_type(r, ratios)
elif is_symmetric:
return prep_rel_two_node_types_sym(r, ratios)
else:
adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train)
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
adj_mat_train, r.two_way, edges_pos, edges_neg)
return prep_rel_two_node_types_asym(r, ratios)
def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily:
relation_types = { (fam.node_type_row, fam.node_type_column): [],
(fam.node_type_column, fam.node_type_row): [] }
relation_types = []
for (node_type_row, node_type_column), rels in fam.relation_types.items():
for r in rels:
relation_types[node_type_row, node_type_column].append(
prepare_relation_type(r, ratios))
for r in fam.relation_types:
relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric))
return PreparedRelationFamily(fam.data, fam.name,
fam.node_type_row, fam.node_type_column,


+ 2
- 2
tests/icosagon/test_convlayer.py 查看文件

@@ -103,7 +103,7 @@ def test_decagon_layer_04():
in_layer = OneHotInputLayer(d)
multi_dgca = MultiDGCA([10], 32,
[r.adjacency_matrix for r in fam.relation_types[0, 0]],
[r.adjacency_matrix for r in fam.relation_types],
keep_prob=1., activation=lambda x: x)
d_layer = DecagonLayer(in_layer.output_dim, 32, d,
@@ -147,7 +147,7 @@ def test_decagon_layer_05():
in_layer = OneHotInputLayer(d)
multi_dgca = MultiDGCA([100, 100], 32,
[r.adjacency_matrix for r in fam.relation_types[0, 0]],
[r.adjacency_matrix for r in fam.relation_types],
keep_prob=1., activation=lambda x: x)
d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d,


+ 1
- 1
tests/icosagon/test_trainprep.py 查看文件

@@ -107,7 +107,7 @@ def test_prepare_relation_type_01():
adj_mat = (torch.rand((10, 10)) > .5)
r = RelationType('Test', 0, 0, adj_mat, True)
ratios = TrainValTest(.8, .1, .1)
_ = prepare_relation_type(r, ratios)
_ = prepare_relation_type(r, ratios, False)


正在加载...
取消
保存