|
|
@@ -8,7 +8,12 @@ from typing import List, \ |
|
|
|
Dict, \
|
|
|
|
Callable, \
|
|
|
|
Tuple
|
|
|
|
from .util import _sparse_coo_tensor
|
|
|
|
from .util import _sparse_coo_tensor, \
|
|
|
|
_sparse_diag_cat, \
|
|
|
|
_mm
|
|
|
|
from .normalize import norm_adj_mat_one_node_type, \
|
|
|
|
norm_adj_mat_two_node_types
|
|
|
|
from .dropout import dropout
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -19,6 +24,44 @@ class TrainingBatch(object): |
|
|
|
edges: torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
def _per_layer_required_vertices(data: Data, batch: TrainingBatch,
|
|
|
|
num_layers: int) -> List[List[EdgeType]]:
|
|
|
|
|
|
|
|
Q = [
|
|
|
|
( batch.vertex_type_row, batch.edges[:, 0] ),
|
|
|
|
( batch.vertex_type_column, batch.edges[:, 1] )
|
|
|
|
]
|
|
|
|
print('Q:', Q)
|
|
|
|
res = []
|
|
|
|
|
|
|
|
for _ in range(num_layers):
|
|
|
|
R = []
|
|
|
|
required_rows = [ [] for _ in range(len(data.vertex_types)) ]
|
|
|
|
|
|
|
|
for vertex_type, vertices in Q:
|
|
|
|
for et in data.edge_types.values():
|
|
|
|
if et.vertex_type_row == vertex_type:
|
|
|
|
required_rows[vertex_type].append(vertices)
|
|
|
|
indices = et.total_connectivity.indices()
|
|
|
|
mask = torch.zeros(et.total_connectivity.shape[0])
|
|
|
|
mask[vertices] = 1
|
|
|
|
mask = torch.nonzero(mask[indices[0]], as_tuple=True)[0]
|
|
|
|
R.append((et.vertex_type_column,
|
|
|
|
indices[1, mask]))
|
|
|
|
else:
|
|
|
|
pass # required_rows[et.vertex_type_row].append(torch.zeros(0))
|
|
|
|
|
|
|
|
required_rows = [ torch.unique(torch.cat(x)) \
|
|
|
|
if len(x) > 0 \
|
|
|
|
else None \
|
|
|
|
for x in required_rows ]
|
|
|
|
|
|
|
|
res.append(required_rows)
|
|
|
|
Q = R
|
|
|
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
|
|
class Model(torch.nn.Module):
|
|
|
|
def __init__(self, data: Data, layer_dimensions: List[int],
|
|
|
|
keep_prob: float,
|
|
|
@@ -30,11 +73,11 @@ class Model(torch.nn.Module): |
|
|
|
if not isinstance(data, Data):
|
|
|
|
raise TypeError('data must be an instance of Data')
|
|
|
|
|
|
|
|
if not isinstance(conv_activation, types.FunctionType):
|
|
|
|
raise TypeError('conv_activation must be a function')
|
|
|
|
if not callable(conv_activation):
|
|
|
|
raise TypeError('conv_activation must be callable')
|
|
|
|
|
|
|
|
if not isinstance(dec_activation, types.FunctionType):
|
|
|
|
raise TypeError('dec_activation must be a function')
|
|
|
|
if not callable(dec_activation):
|
|
|
|
raise TypeError('dec_activation must be callable')
|
|
|
|
|
|
|
|
self.data = data
|
|
|
|
self.layer_dimensions = list(layer_dimensions)
|
|
|
@@ -42,35 +85,90 @@ class Model(torch.nn.Module): |
|
|
|
self.conv_activation = conv_activation
|
|
|
|
self.dec_activation = dec_activation
|
|
|
|
|
|
|
|
self.adj_matrices = None
|
|
|
|
self.conv_weights = None
|
|
|
|
self.dec_weights = None
|
|
|
|
self.build()
|
|
|
|
|
|
|
|
|
|
|
|
def build(self) -> None:
|
|
|
|
self.adj_matrices = torch.nn.ParameterDict()
|
|
|
|
for _, et in self.data.edge_types.items():
|
|
|
|
adj_matrices = [
|
|
|
|
norm_adj_mat_one_node_type(x) \
|
|
|
|
if et.vertex_type_row == et.vertex_type_column \
|
|
|
|
else norm_adj_mat_two_node_types(x) \
|
|
|
|
for x in et.adjacency_matrices
|
|
|
|
]
|
|
|
|
adj_matrices = _sparse_diag_cat(et.adjacency_matrices)
|
|
|
|
print('adj_matrices:', adj_matrices)
|
|
|
|
self.adj_matrices['%d-%d' % (et.vertex_type_row, et.vertex_type_column)] = \
|
|
|
|
torch.nn.Parameter(adj_matrices, requires_grad=False)
|
|
|
|
|
|
|
|
self.conv_weights = torch.nn.ParameterDict()
|
|
|
|
for i in range(len(self.layer_dimensions) - 1):
|
|
|
|
in_dimension = self.layer_dimensions[i]
|
|
|
|
out_dimension = self.layer_dimensions[i + 1]
|
|
|
|
|
|
|
|
for _, et in self.data.edge_types.items():
|
|
|
|
weight = init_glorot(in_dimension, out_dimension)
|
|
|
|
self.conv_weights[et.vertex_type_row, et.vertex_type_column, i] = \
|
|
|
|
torch.nn.Parameter(weight)
|
|
|
|
weights = [ init_glorot(in_dimension, out_dimension) \
|
|
|
|
for _ in range(len(et.adjacency_matrices)) ]
|
|
|
|
weights = torch.cat(weights, dim=1)
|
|
|
|
self.conv_weights['%d-%d-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
|
|
|
|
torch.nn.Parameter(weights)
|
|
|
|
|
|
|
|
self.dec_weights = torch.nn.ParameterDict()
|
|
|
|
for _, et in self.data.edge_types.items():
|
|
|
|
global_interaction, local_variation = \
|
|
|
|
et.decoder_factory(self.layer_dimensions[-1],
|
|
|
|
len(et.adjacency_matrices))
|
|
|
|
self.dec_weights[et.vertex_type_row, et.vertex_type_column] = \
|
|
|
|
torch.nn.ParameterList([
|
|
|
|
torch.nn.Parameter(global_interaction),
|
|
|
|
torch.nn.Parameter(local_variation)
|
|
|
|
])
|
|
|
|
self.dec_weights['%d-%d-global-interaction' % (et.vertex_type_row, et.vertex_type_column)] = \
|
|
|
|
torch.nn.Parameter(global_interaction)
|
|
|
|
for i in range(len(local_variation)):
|
|
|
|
self.dec_weights['%d-%d-local-variation-%d' % (et.vertex_type_row, et.vertex_type_column, i)] = \
|
|
|
|
torch.nn.Parameter(local_variation[i])
|
|
|
|
|
|
|
|
|
|
|
|
def convolve(self, in_layer_repr: List[torch.Tensor]) -> \
|
|
|
|
List[torch.Tensor]:
|
|
|
|
|
|
|
|
def convolve(self, batch: TrainingBatch) -> List[torch.Tensor]:
|
|
|
|
cur_layer_repr = in_layer_repr
|
|
|
|
next_layer_repr = [ None ] * len(self.data.vertex_types)
|
|
|
|
|
|
|
|
for i in range(len(self.layer_dimensions) - 1):
|
|
|
|
for _, et in self.data.edge_types.items():
|
|
|
|
vt_row, vt_col = et.vertex_type_row, et.vertex_type_column
|
|
|
|
adj_matrices = self.adj_matrices['%d-%d' % (vt_row, vt_col)]
|
|
|
|
conv_weights = self.conv_weights['%d-%d-%d' % (vt_row, vt_col, i)]
|
|
|
|
|
|
|
|
num_relation_types = len(et.adjacency_matrices)
|
|
|
|
x = cur_layer_repr[vt_col]
|
|
|
|
if self.keep_prob != 1:
|
|
|
|
x = dropout(x, self.keep_prob)
|
|
|
|
|
|
|
|
print('a, Layer:', i, 'x.shape:', x.shape)
|
|
|
|
|
|
|
|
x = _mm(x, conv_weights)
|
|
|
|
x = torch.split(x,
|
|
|
|
x.shape[1] // num_relation_types,
|
|
|
|
dim=1)
|
|
|
|
x = torch.cat(x)
|
|
|
|
x = _mm(adj_matrices, x)
|
|
|
|
x = x.view(num_relation_types,
|
|
|
|
self.data.vertex_types[vt_row].count,
|
|
|
|
self.layer_dimensions[i + 1])
|
|
|
|
|
|
|
|
print('b, Layer:', i, 'x.shape:', x.shape)
|
|
|
|
|
|
|
|
x = x.sum(dim=0)
|
|
|
|
x = torch.nn.functional.normalize(x, p=2, dim=1)
|
|
|
|
x = self.conv_activation(x)
|
|
|
|
|
|
|
|
next_layer_repr[vt_row] = x
|
|
|
|
cur_layer_repr = next_layer_repr
|
|
|
|
return next_layer_repr
|
|
|
|
|
|
|
|
def convolve_old(self, batch: TrainingBatch) -> List[torch.Tensor]:
|
|
|
|
edges = []
|
|
|
|
cur_edges = batch.edges
|
|
|
|
for _ in range(len(self.layer_dimensions) - 1):
|
|
|
|