|
@@ -5,6 +5,37 @@ from .data import Data |
|
|
from .trainprep import PreparedData
|
|
|
from .trainprep import PreparedData
|
|
|
import torch
|
|
|
import torch
|
|
|
from .weights import init_glorot
|
|
|
from .weights import init_glorot
|
|
|
|
|
|
from .normalize import _sparse_coo_tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sparse_diag_cat(matrices: List[torch.Tensor]):
|
|
|
|
|
|
if len(matrices) == 0:
|
|
|
|
|
|
raise ValueError('The list of matrices must be non-empty')
|
|
|
|
|
|
|
|
|
|
|
|
if not all(m.is_sparse for m in matrices):
|
|
|
|
|
|
raise ValueError('All matrices must be sparse')
|
|
|
|
|
|
|
|
|
|
|
|
if not all(len(m.shape) == 2 for m in matrices):
|
|
|
|
|
|
raise ValueError('All matrices must be 2D')
|
|
|
|
|
|
|
|
|
|
|
|
indices = []
|
|
|
|
|
|
values = []
|
|
|
|
|
|
row_offset = 0
|
|
|
|
|
|
col_offset = 0
|
|
|
|
|
|
|
|
|
|
|
|
for m in matrices:
|
|
|
|
|
|
ind = m._indices().clone()
|
|
|
|
|
|
ind[0] += row_offset
|
|
|
|
|
|
ind[1] += col_offset
|
|
|
|
|
|
indices.append(ind)
|
|
|
|
|
|
values.append(m._values())
|
|
|
|
|
|
row_offset += m.shape[0]
|
|
|
|
|
|
col_offset += m.shape[1]
|
|
|
|
|
|
|
|
|
|
|
|
indices = torch.cat(indices, dim=1)
|
|
|
|
|
|
values = torch.cat(values)
|
|
|
|
|
|
|
|
|
|
|
|
return _sparse_coo_tensor(indices, values, size=(row_offset, col_offset))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cat(matrices: List[torch.Tensor]):
|
|
|
def _cat(matrices: List[torch.Tensor]):
|
|
@@ -79,7 +110,7 @@ class FastConvLayer(torch.nn.Module): |
|
|
output_dim: List[int],
|
|
|
output_dim: List[int],
|
|
|
data: Union[Data, PreparedData],
|
|
|
data: Union[Data, PreparedData],
|
|
|
keep_prob: float = 1.,
|
|
|
keep_prob: float = 1.,
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x
|
|
|
|
|
|
|
|
|
rel_activation: Callable[[torch.Tensor], torch.Tensor] = lambda x: x,
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
layer_activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu,
|
|
|
**kwargs):
|
|
|
**kwargs):
|
|
|
|
|
|
|
|
|