@@ -50,22 +50,46 @@ class DecagonLayer(torch.nn.Module): | |||
self.next_layer_repr = None | |||
self.build() | |||
def build_family(self, fam): | |||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
if len(rels) == 0: | |||
continue | |||
convolutions = [] | |||
def build_fam_one_node_type(self, fam): | |||
convolutions = [] | |||
for r in fam.relation_types: | |||
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||
self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||
self.keep_prob, self.rel_activation) | |||
convolutions.append(conv) | |||
self.next_layer_repr[fam.node_type_row].append( | |||
Convolutions(fam.node_type_column, convolutions)) | |||
def build_fam_two_node_types(self, fam) -> None: | |||
convolutions_row = [] | |||
convolutions_column = [] | |||
for r in fam.relation_types: | |||
if r.adjacency_matrix is not None: | |||
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_column], | |||
self.output_dim[fam.node_type_row], r.adjacency_matrix, | |||
self.keep_prob, self.rel_activation) | |||
convolutions_row.append(conv) | |||
for r in rels: | |||
conv = DropoutGraphConvActivation(self.input_dim[node_type_column], | |||
self.output_dim[node_type_row], r.adjacency_matrix, | |||
if r.adjacency_matrix_backward is not None: | |||
conv = DropoutGraphConvActivation(self.input_dim[fam.node_type_row], | |||
self.output_dim[fam.node_type_column], r.adjacency_matrix_backward, | |||
self.keep_prob, self.rel_activation) | |||
convolutions_column.append(conv) | |||
self.next_layer_repr[fam.node_type_row].append( | |||
Convolutions(fam.node_type_column, convolutions_row)) | |||
convolutions.append(conv) | |||
self.next_layer_repr[fam.node_type_column].append( | |||
Convolutions(fam.node_type_row, convolutions_column)) | |||
self.next_layer_repr[node_type_row].append( | |||
Convolutions(node_type_column, convolutions)) | |||
def build_family(self, fam) -> None: | |||
if fam.node_type_row == fam.node_type_column: | |||
self.build_fam_one_node_type(fam) | |||
else: | |||
self.build_fam_two_node_types(fam) | |||
def build(self): | |||
self.next_layer_repr = [ [] for _ in range(len(self.data.node_types)) ] | |||
@@ -49,12 +49,12 @@ class RelationTypeBase(object): | |||
node_type_row: int | |||
node_type_column: int | |||
adjacency_matrix: torch.Tensor | |||
two_way: bool | |||
adjacency_matrix_backward: torch.Tensor | |||
@dataclass | |||
class RelationType(RelationTypeBase): | |||
hints: Dict[str, Any] = field(default_factory=dict) | |||
pass | |||
@dataclass | |||
@@ -69,7 +69,7 @@ class RelationFamilyBase(object): | |||
@dataclass | |||
class RelationFamily(RelationFamilyBase): | |||
relation_types: Dict[Tuple[int, int], List[RelationType]] = None | |||
relation_types: List[RelationType] = None | |||
def __post_init__(self) -> None: | |||
if not self.is_symmetric and \ | |||
@@ -77,18 +77,18 @@ class RelationFamily(RelationFamilyBase): | |||
self.decoder_class != BilinearDecoder: | |||
raise TypeError('Family is assymetric but the specified decoder_class supports symmetric relations only') | |||
self.relation_types = { (self.node_type_row, self.node_type_column): [], | |||
(self.node_type_column, self.node_type_row): [] } | |||
self.relation_types = [] | |||
def add_relation_type(self, name: str, node_type_row: int, node_type_column: int, | |||
adjacency_matrix: torch.Tensor, adjacency_matrix_backward: torch.Tensor = None, | |||
two_way: bool = True) -> None: | |||
def add_relation_type(self, | |||
name: str, node_type_row: int, node_type_column: int, | |||
adjacency_matrix: torch.Tensor, | |||
adjacency_matrix_backward: torch.Tensor = None) -> None: | |||
name = str(name) | |||
node_type_row = int(node_type_row) | |||
node_type_column = int(node_type_column) | |||
if (node_type_row, node_type_column) not in self.relation_types: | |||
if (node_type_row, node_type_column) != (self.node_type_row, self.node_type_column): | |||
raise ValueError('Specified node_type_row/node_type_column tuple does not belong to this family') | |||
if node_type_row < 0 or node_type_row >= len(self.data.node_types): | |||
@@ -97,21 +97,33 @@ class RelationFamily(RelationFamilyBase): | |||
if node_type_column < 0 or node_type_column >= len(self.data.node_types): | |||
raise ValueError('node_type_column outside of the valid range of node types') | |||
if not isinstance(adjacency_matrix, torch.Tensor): | |||
if adjacency_matrix is None and adjacency_matrix_backward is None: | |||
raise ValueError('adjacency_matrix and adjacency_matrix_backward cannot both be None') | |||
if adjacency_matrix is not None and \ | |||
not isinstance(adjacency_matrix, torch.Tensor): | |||
raise ValueError('adjacency_matrix must be a torch.Tensor') | |||
# if isinstance(adjacency_matrix_backward, str) and \ | |||
# adjacency_matrix_backward == 'symmetric': | |||
# if self.is_symmetric: | |||
# adjacency_matrix_backward = None | |||
# else: | |||
# adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
if adjacency_matrix_backward is not None \ | |||
and not isinstance(adjacency_matrix_backward, torch.Tensor): | |||
raise ValueError('adjacency_matrix_backward must be a torch.Tensor') | |||
if adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||
self.data.node_types[node_type_column].count): | |||
if adjacency_matrix is not None and \ | |||
adjacency_matrix.shape != (self.data.node_types[node_type_row].count, | |||
self.data.node_types[node_type_column].count): | |||
raise ValueError('adjacency_matrix shape must be (num_row_nodes, num_column_nodes)') | |||
if adjacency_matrix_backward is not None and \ | |||
adjacency_matrix_backward.shape != (self.data.node_types[node_type_column].count, | |||
self.data.node_types[node_type_row].count): | |||
raise ValueError('adjacency_matrix shape must be (num_column_nodes, num_row_nodes)') | |||
raise ValueError('adjacency_matrix_backward shape must be (num_column_nodes, num_row_nodes)') | |||
if node_type_row == node_type_column and \ | |||
adjacency_matrix_backward is not None: | |||
@@ -125,19 +137,12 @@ class RelationFamily(RelationFamilyBase): | |||
adjacency_matrix.transpose(0, 1))): | |||
raise ValueError('Relation family is symmetric but adjacency_matrix is assymetric') | |||
two_way = bool(two_way) | |||
if node_type_row != node_type_column and two_way: | |||
print('%d != %d' % (node_type_row, node_type_column)) | |||
if adjacency_matrix_backward is None: | |||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
self.relation_types[node_type_column, node_type_row].append( | |||
RelationType(name, node_type_column, node_type_row, | |||
adjacency_matrix_backward, two_way, { 'display': False })) | |||
if self.is_symmetric and node_type_row != node_type_column: | |||
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1) | |||
self.relation_types[node_type_row, node_type_column].append( | |||
RelationType(name, node_type_row, node_type_column, | |||
adjacency_matrix, two_way)) | |||
self.relation_types.append(RelationType(name, | |||
node_type_row, node_type_column, | |||
adjacency_matrix, adjacency_matrix_backward)) | |||
def node_name(self, index): | |||
return self.data.node_types[index].name | |||
@@ -145,24 +150,26 @@ class RelationFamily(RelationFamilyBase): | |||
def __repr__(self): | |||
s = 'Relation family %s' % self.name | |||
for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||
for r in rels: | |||
if 'display' in r.hints and not r.hints['display']: | |||
continue | |||
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||
(self.node_name(node_type_row), self.node_name(node_type_column))) | |||
for r in self.relation_types: | |||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||
if (r.adjacency_matrix is not None \ | |||
and r.adjacency_matrix_backward is not None) \ | |||
or self.is_symmetric \ | |||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||
self.node_name(self.node_type_column))) | |||
return s | |||
def repr_indented(self): | |||
s = ' - %s' % self.name | |||
for (node_type_row, node_type_column), rels in self.relation_types.items(): | |||
for r in rels: | |||
if 'display' in r.hints and not r.hints['display']: | |||
continue | |||
s += '\n - %s%s' % (r.name, ' (two-way)' if r.two_way else '%s <- %s' % \ | |||
(self.node_name(node_type_row), self.node_name(node_type_column))) | |||
for r in self.relation_types: | |||
s += '\n - %s%s' % (r.name, ' (two-way)' \ | |||
if (r.adjacency_matrix is not None \ | |||
and r.adjacency_matrix_backward is not None) \ | |||
or self.is_symmetric \ | |||
else '%s <- %s' % (self.node_name(self.node_type_row), | |||
self.node_name(self.node_type_column))) | |||
return s | |||
@@ -22,7 +22,6 @@ class DecodeLayer(torch.nn.Module): | |||
input_dim: List[int], | |||
data: PreparedData, | |||
keep_prob: float = 1., | |||
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder, | |||
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid, | |||
**kwargs) -> None: | |||
@@ -37,26 +36,25 @@ class DecodeLayer(torch.nn.Module): | |||
if not isinstance(data, PreparedData): | |||
raise TypeError('data must be an instance of PreparedData') | |||
if not isinstance(decoder_class, type) and \ | |||
not isinstance(decoder_class, dict): | |||
raise TypeError('decoder_class must be a Type or a Dict') | |||
if not isinstance(decoder_class, dict): | |||
decoder_class = { k: decoder_class \ | |||
for k in data.relation_types.keys() } | |||
self.input_dim = input_dim | |||
self.output_dim = 1 | |||
self.data = data | |||
self.keep_prob = keep_prob | |||
self.decoder_class = decoder_class | |||
self.activation = activation | |||
self.decoders = None | |||
self.build() | |||
def build(self) -> None: | |||
self.decoders = {} | |||
self.decoders = [] | |||
for fam in self.data.relation_families: | |||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
for r in rels: | |||
pass | |||
dec = fam.decoder_class() | |||
self.decoders.append(dec) | |||
for (node_type_row, node_type_column), rels in self.data.relation_types.items(): | |||
if len(rels) == 0: | |||
@@ -39,7 +39,7 @@ class PreparedRelationType(RelationTypeBase): | |||
@dataclass | |||
class PreparedRelationFamily(RelationFamilyBase): | |||
relation_types: Dict[Tuple[int, int], List[PreparedRelationType]] | |||
relation_types: List[PreparedRelationType] | |||
@dataclass | |||
@@ -110,36 +110,81 @@ def prepare_adj_mat(adj_mat: torch.Tensor, | |||
return adj_mat_train, edges_pos, edges_neg | |||
def prepare_relation_type(r: RelationType, | |||
def prep_rel_one_node_type(r: RelationType, | |||
ratios: TrainValTest) -> PreparedRelationType: | |||
adj_mat = r.adjacency_matrix | |||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
print('adj_mat_train:', adj_mat_train) | |||
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
adj_mat_train, None, edges_pos, edges_neg) | |||
def prep_rel_two_node_types_sym(r: RelationType, | |||
ratios: TrainValTest) -> PreparedRelationType: | |||
adj_mat = r.adjacency_matrix | |||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
return PreparedRelationType(r.name, r.node_type_row, | |||
r.node_type_column, | |||
norm_adj_mat_two_node_types(adj_mat_train), | |||
norm_adj_mat_two_node_types(adj_mat_train.transpose(0, 1)), | |||
edges_pos, edges_neg) | |||
def prep_rel_two_node_types_asym(r: RelationType, | |||
ratios: TrainValTest) -> PreparedRelationType: | |||
if r.adjacency_matrix is not None: | |||
adj_mat_train, edges_pos, edges_neg =\ | |||
prepare_adj_mat(r.adjacency_matrix, ratios) | |||
else: | |||
adj_mat_train, edges_pos, edges_neg = \ | |||
None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||
if r.adjacency_matrix_backward is not None: | |||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
prepare_adj_mat(r.adjacency_matrix_backward, ratios) | |||
else: | |||
adj_mat_back_train, edges_back_pos, edges_back_neg = \ | |||
None, torch.zeros((0, 2)), torch.zeros((0, 2)) | |||
edges_pos = torch.cat((edges_pos, edges_back_pos), dim=0) | |||
edges_neg = torch.cat((edges_neg, edges_back_neg), dim=0) | |||
return PreparedRelationType(r.name, r.node_type_row, | |||
r.node_type_column, | |||
norm_adj_mat_two_node_types(adj_mat_train), | |||
norm_adj_mat_two_node_types(adj_mat_back_train), | |||
edges_pos, edges_neg) | |||
def prepare_relation_type(r: RelationType, | |||
ratios: TrainValTest, is_symmetric: bool) -> PreparedRelationType: | |||
if not isinstance(r, RelationType): | |||
raise ValueError('r must be a RelationType') | |||
if not isinstance(ratios, TrainValTest): | |||
raise ValueError('ratios must be a TrainValTest') | |||
adj_mat = r.adjacency_matrix | |||
adj_mat_train, edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios) | |||
print('adj_mat_train:', adj_mat_train) | |||
if r.node_type_row == r.node_type_column: | |||
adj_mat_train = norm_adj_mat_one_node_type(adj_mat_train) | |||
return prep_rel_one_node_type(r, ratios) | |||
elif is_symmetric: | |||
return prep_rel_two_node_types_sym(r, ratios) | |||
else: | |||
adj_mat_train = norm_adj_mat_two_node_types(adj_mat_train) | |||
return PreparedRelationType(r.name, r.node_type_row, r.node_type_column, | |||
adj_mat_train, r.two_way, edges_pos, edges_neg) | |||
return prep_rel_two_node_types_asym(r, ratios) | |||
def prepare_relation_family(fam: RelationFamily) -> PreparedRelationFamily: | |||
relation_types = { (fam.node_type_row, fam.node_type_column): [], | |||
(fam.node_type_column, fam.node_type_row): [] } | |||
relation_types = [] | |||
for (node_type_row, node_type_column), rels in fam.relation_types.items(): | |||
for r in rels: | |||
relation_types[node_type_row, node_type_column].append( | |||
prepare_relation_type(r, ratios)) | |||
for r in fam.relation_types: | |||
relation_types.append(prepare_relation_type(r, ratios, fam.is_symmetric)) | |||
return PreparedRelationFamily(fam.data, fam.name, | |||
fam.node_type_row, fam.node_type_column, | |||
@@ -103,7 +103,7 @@ def test_decagon_layer_04(): | |||
in_layer = OneHotInputLayer(d) | |||
multi_dgca = MultiDGCA([10], 32, | |||
[r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||
[r.adjacency_matrix for r in fam.relation_types], | |||
keep_prob=1., activation=lambda x: x) | |||
d_layer = DecagonLayer(in_layer.output_dim, 32, d, | |||
@@ -147,7 +147,7 @@ def test_decagon_layer_05(): | |||
in_layer = OneHotInputLayer(d) | |||
multi_dgca = MultiDGCA([100, 100], 32, | |||
[r.adjacency_matrix for r in fam.relation_types[0, 0]], | |||
[r.adjacency_matrix for r in fam.relation_types], | |||
keep_prob=1., activation=lambda x: x) | |||
d_layer = DecagonLayer(in_layer.output_dim, output_dim=32, data=d, | |||
@@ -107,7 +107,7 @@ def test_prepare_relation_type_01(): | |||
adj_mat = (torch.rand((10, 10)) > .5) | |||
r = RelationType('Test', 0, 0, adj_mat, True) | |||
ratios = TrainValTest(.8, .1, .1) | |||
_ = prepare_relation_type(r, ratios) | |||
_ = prepare_relation_type(r, ratios, False) | |||