IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Przeglądaj źródła

Use hints instead of is_autogenerated.

master
Stanislaw Adaszewski 4 lat temu
rodzic
commit
15e95bdc72
2 zmienionych plików z 10 dodań i 7 usunięć
  1. +9
    -6
      src/icosagon/data.py
  2. +1
    -1
      src/icosagon/declayer.py

+ 9
- 6
src/icosagon/data.py Wyświetl plik

@@ -5,11 +5,12 @@
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
import torch
from typing import List, \
Dict, \
Tuple
Tuple, \
Any
@dataclass
@@ -24,7 +25,7 @@ class RelationType(object):
node_type_row: int
node_type_column: int
adjacency_matrix: torch.Tensor
is_autogenerated: bool = False
hints: Dict[str, Any] = field(default_factory=dict)
class Data(object):
@@ -81,14 +82,16 @@ class Data(object):
self.relation_types[node_type_row, node_type_column].append(
RelationType(name, node_type_row, node_type_column,
adjacency_matrix, False))
adjacency_matrix))
if node_type_row != node_type_column and two_way:
hints = { 'display': False }
if adjacency_matrix_backward is None:
adjacency_matrix_backward = adjacency_matrix.transpose(0, 1)
hints['symmetric'] = True
self.relation_types[node_type_column, node_type_row].append(
RelationType(name, node_type_column, node_type_row,
adjacency_matrix_backward, True))
adjacency_matrix_backward, hints))
def __repr__(self):
n = len(self.node_types)
@@ -114,7 +117,7 @@ class Data(object):
self.node_types[node_type_column].name + ':\n'
for r in self.relation_types[node_type_row, node_type_column]:
if r.is_autogenerated:
if not r.hints.get('display', True):
continue
s_1 += ' - ' + r.name + '\n'
count += 1


+ 1
- 1
src/icosagon/declayer.py Wyświetl plik

@@ -71,6 +71,6 @@ class DecodeLayer(torch.nn.Module):
for (node_type_row, node_type_column), dec in self.decoders.items():
inputs_row = last_layer_repr[node_type_row]
inputs_column = last_layer_repr[node_type_column]
pred_adj_matrices = dec(inputs_row, inputs_column)
pred_adj_matrices = [ dec(inputs_row, inputs_column, k) for k in range(dec.num_relation_types) ]
res[node_type_row, node_type_column] = pred_adj_matrices
return res

Ładowanie…
Anuluj
Zapisz