|
@@ -5,11 +5,12 @@ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import defaultdict
|
|
|
from collections import defaultdict
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
import torch
|
|
|
import torch
|
|
|
from typing import List, \
|
|
|
from typing import List, \
|
|
|
Dict, \
|
|
|
Dict, \
|
|
|
Tuple
|
|
|
|
|
|
|
|
|
Tuple, \
|
|
|
|
|
|
Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@dataclass
|
|
@@ -24,7 +25,7 @@ class RelationType(object): |
|
|
node_type_row: int
|
|
|
node_type_row: int
|
|
|
node_type_column: int
|
|
|
node_type_column: int
|
|
|
adjacency_matrix: torch.Tensor
|
|
|
adjacency_matrix: torch.Tensor
|
|
|
is_autogenerated: bool = False
|
|
|
|
|
|
|
|
|
hints: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Data(object):
|
|
|
class Data(object):
|
|
@@ -81,14 +82,16 @@ class Data(object): |
|
|
|
|
|
|
|
|
self.relation_types[node_type_row, node_type_column].append(
|
|
|
self.relation_types[node_type_row, node_type_column].append(
|
|
|
RelationType(name, node_type_row, node_type_column,
|
|
|
RelationType(name, node_type_row, node_type_column,
|
|
|
adjacency_matrix, False))
|
|
|
|
|
|
|
|
|
adjacency_matrix))
|
|
|
|
|
|
|
|
|
if node_type_row != node_type_column and two_way:
|
|
|
if node_type_row != node_type_column and two_way:
|
|
|
|
|
|
hints = { 'display': False }
|
|
|
if adjacency_matrix_backward is None:
|
|
|
if adjacency_matrix_backward is None:
|
|
|
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
|
|
|
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
|
|
|
|
|
|
hints['symmetric'] = True
|
|
|
self.relation_types[node_type_column, node_type_row].append(
|
|
|
self.relation_types[node_type_column, node_type_row].append(
|
|
|
RelationType(name, node_type_column, node_type_row,
|
|
|
RelationType(name, node_type_column, node_type_row,
|
|
|
adjacency_matrix_backward, True))
|
|
|
|
|
|
|
|
|
adjacency_matrix_backward, hints))
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
def __repr__(self):
|
|
|
n = len(self.node_types)
|
|
|
n = len(self.node_types)
|
|
@@ -114,7 +117,7 @@ class Data(object): |
|
|
self.node_types[node_type_column].name + ':\n'
|
|
|
self.node_types[node_type_column].name + ':\n'
|
|
|
|
|
|
|
|
|
for r in self.relation_types[node_type_row, node_type_column]:
|
|
|
for r in self.relation_types[node_type_row, node_type_column]:
|
|
|
if r.is_autogenerated:
|
|
|
|
|
|
|
|
|
if not r.hints.get('display', True):
|
|
|
continue
|
|
|
continue
|
|
|
s_1 += ' - ' + r.name + '\n'
|
|
|
s_1 += ' - ' + r.name + '\n'
|
|
|
count += 1
|
|
|
count += 1
|
|
|