IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Ver código fonte

Add first test for declayer.

master
Stanislaw Adaszewski 4 anos atrás
pai
commit
f51855e5b3
1 arquivos alterados com 8 adições e 25 exclusões
  1. +8
    -25
      src/icosagon/declayer.py

+ 8
- 25
src/icosagon/declayer.py Ver arquivo

@@ -22,8 +22,8 @@ class DecodeLayer(torch.nn.Module):
input_dim: List[int],
data: Union[Data, PreparedData],
keep_prob: float = 1.,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
decoder_class: Union[Type, Dict[Tuple[int, int], Type]] = DEDICOMDecoder,
activation: Callable[[torch.Tensor], torch.Tensor] = torch.sigmoid,
**kwargs) -> None:
super().__init__(**kwargs)
@@ -35,9 +35,9 @@ class DecodeLayer(torch.nn.Module):
self.output_dim = 1
self.data = data
self.keep_prob = keep_prob
self.decoder_class = decoder_class
self.activation = activation
self.decoder_class = decoder_class
self.decoders = None
self.build()
@@ -45,6 +45,7 @@ class DecodeLayer(torch.nn.Module):
self.decoders = {}
n = len(self.data.node_types)
relation_types = self.data.relation_types
for node_type_row in range(n):
if node_type_row not in relation_types:
continue
@@ -70,34 +71,16 @@ class DecodeLayer(torch.nn.Module):
decoder_class = self.decoder_class
self.decoders[node_type_row, node_type_column] = \
decoder_class(self.input_dim,
decoder_class(self.input_dim[node_type_row],
num_relation_types = len(rels),
drop_prob = 1. - self.keep_prob,
keep_prob = self.keep_prob,
activation = self.activation)
def forward(self, last_layer_repr: List[torch.Tensor]) -> TrainValTest:
# n = len(self.data.node_types)
# relation_types = self.data.relation_types
# for node_type_row in range(n):
# if node_type_row not in relation_types:
# continue
#
# for node_type_column in range(n):
# if node_type_column not in relation_types[node_type_row]:
# continue
#
# rels = relation_types[node_type_row][node_type_column]
#
# for mode in ['train', 'val', 'test']:
# getattr(relation_types[node_type_row][node_type_column].edges_pos, mode)
# getattr(self.data.edges_neg, mode)
# last_layer[]
def forward(self, last_layer_repr: List[torch.Tensor]) -> Dict[Tuple[int, int], List[torch.Tensor]]:
res = {}
for (node_type_row, node_type_column), dec in self.decoders.items():
inputs_row = last_layer_repr[node_type_row]
inputs_column = last_layer_repr[node_type_column]
pred_adj_matrices = dec(inputs_row, inputs_col)
res[node_type_row, node_type_col] = pred_adj_matrices
pred_adj_matrices = dec(inputs_row, inputs_column)
res[node_type_row, node_type_column] = pred_adj_matrices
return res

Carregando…
Cancelar
Salvar