IF YOU WOULD LIKE TO GET AN ACCOUNT, please write an email to s dot adaszewski at gmail dot com. User accounts are meant only to report issues and/or generate pull requests. This is a purpose-specific Git hosting for ADARED projects. Thank you for your understanding!
Bläddra i källkod

Work on icosagon.trainprep.

master
Stanislaw Adaszewski 4 år sedan
förälder
incheckning
7ed4bc373a
7 ändrade filer med 499 tillägg och 6 borttagningar
  1. +252
    -6
      docs/decagon-diagram.svg
  2. +6
    -0
      src/icosagon/__init__.py
  3. +6
    -0
      src/icosagon/data.py
  4. +29
    -0
      src/icosagon/normalize.py
  5. +42
    -0
      src/icosagon/sampling.py
  6. +106
    -0
      src/icosagon/trainprep.py
  7. +58
    -0
      tests/icosagon/test_trainprep.py

+ 252
- 6
docs/decagon-diagram.svg Visa fil

@@ -18,6 +18,94 @@
sodipodi:docname="decagon-diagram.svg">
<defs
id="defs2">
<marker
inkscape:isstock="true"
style="overflow:visible;"
id="marker18999"
refX="0.0"
refY="0.0"
orient="auto"
inkscape:stockid="Arrow1Lend">
<path
transform="scale(0.8) rotate(180) translate(12.5,0)"
style="fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1;fill:#000000;fill-opacity:1"
d="M 0.0,0.0 L 5.0,-5.0 L -12.5,0.0 L 5.0,5.0 L 0.0,0.0 z "
id="path18997" />
</marker>
<marker
inkscape:stockid="Arrow1Lend"
orient="auto"
refY="0.0"
refX="0.0"
id="marker18243"
style="overflow:visible;"
inkscape:isstock="true"
inkscape:collect="always">
<path
id="path18241"
d="M 0.0,0.0 L 5.0,-5.0 L -12.5,0.0 L 5.0,5.0 L 0.0,0.0 z "
style="fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1;fill:#000000;fill-opacity:1"
transform="scale(0.8) rotate(180) translate(12.5,0)" />
</marker>
<marker
inkscape:isstock="true"
style="overflow:visible;"
id="marker17673"
refX="0.0"
refY="0.0"
orient="auto"
inkscape:stockid="Arrow1Lend"
inkscape:collect="always">
<path
transform="scale(0.8) rotate(180) translate(12.5,0)"
style="fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1;fill:#000000;fill-opacity:1"
d="M 0.0,0.0 L 5.0,-5.0 L -12.5,0.0 L 5.0,5.0 L 0.0,0.0 z "
id="path17671" />
</marker>
<marker
inkscape:stockid="Arrow1Lend"
orient="auto"
refY="0.0"
refX="0.0"
id="marker17199"
style="overflow:visible;"
inkscape:isstock="true"
inkscape:collect="always">
<path
id="path17197"
d="M 0.0,0.0 L 5.0,-5.0 L -12.5,0.0 L 5.0,5.0 L 0.0,0.0 z "
style="fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1;fill:#000000;fill-opacity:1"
transform="scale(0.8) rotate(180) translate(12.5,0)" />
</marker>
<marker
inkscape:isstock="true"
style="overflow:visible;"
id="marker16965"
refX="0.0"
refY="0.0"
orient="auto"
inkscape:stockid="Arrow1Lend"
inkscape:collect="always">
<path
transform="scale(0.8) rotate(180) translate(12.5,0)"
style="fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1;fill:#000000;fill-opacity:1"
d="M 0.0,0.0 L 5.0,-5.0 L -12.5,0.0 L 5.0,5.0 L 0.0,0.0 z "
id="path16963" />
</marker>
<marker
inkscape:isstock="true"
style="overflow:visible;"
id="marker16677"
refX="0.0"
refY="0.0"
orient="auto"
inkscape:stockid="Arrow1Lend">
<path
transform="scale(0.8) rotate(180) translate(12.5,0)"
style="fill-rule:evenodd;stroke:#000000;stroke-width:1pt;stroke-opacity:1;fill:#000000;fill-opacity:1"
d="M 0.0,0.0 L 5.0,-5.0 L -12.5,0.0 L 5.0,5.0 L 0.0,0.0 z "
id="path16675" />
</marker>
<marker
inkscape:stockid="Arrow1Lend"
orient="auto"
@@ -246,8 +334,8 @@
inkscape:pageopacity="0.0"
inkscape:pageshadow="2"
inkscape:zoom="0.98994949"
inkscape:cx="75.316161"
inkscape:cy="262.38784"
inkscape:cx="545.82299"
inkscape:cy="-315.17707"
inkscape:document-units="mm"
inkscape:current-layer="layer1"
showgrid="false"
@@ -257,10 +345,10 @@
fit-margin-left="0"
fit-margin-right="0"
fit-margin-bottom="0"
inkscape:window-width="1367"
inkscape:window-height="1080"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-width="1901"
inkscape:window-height="909"
inkscape:window-x="94"
inkscape:window-y="94"
inkscape:window-maximized="0" />
<metadata
id="metadata5">
@@ -1831,5 +1919,163 @@
x="-6.6224065"
id="tspan2683"
sodipodi:role="line">A'</tspan></text>
<circle
r="7.6171813"
cy="291.65302"
cx="243.53423"
id="circle16627"
style="opacity:1;vector-effect:none;fill:#b9b9b9;fill-opacity:1;stroke:#000000;stroke-width:0.66500002;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
<text
id="text16631"
y="294.96997"
x="240.43907"
style="font-style:normal;font-weight:normal;font-size:9.48118305px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.23702957"
xml:space="preserve"><tspan
style="stroke-width:0.23702957"
y="294.96997"
x="240.43907"
id="tspan16629"
sodipodi:role="line">1</tspan></text>
<circle
style="opacity:1;vector-effect:none;fill:#b9b9b9;fill-opacity:1;stroke:#000000;stroke-width:0.66500002;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
id="circle16633"
cx="243.8015"
cy="317.04361"
r="7.6171813" />
<text
xml:space="preserve"
style="font-style:normal;font-weight:normal;font-size:9.48118305px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.23702957"
x="240.70634"
y="320.36057"
id="text16637"><tspan
sodipodi:role="line"
id="tspan16635"
x="240.70634"
y="320.36057"
style="stroke-width:0.23702957">2</tspan></text>
<circle
r="7.6171813"
cy="303.41287"
cx="267.5885"
id="circle16639"
style="opacity:1;vector-effect:none;fill:#b9b9b9;fill-opacity:1;stroke:#000000;stroke-width:0.66500002;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1" />
<text
id="text16643"
y="306.72983"
x="264.49335"
style="font-style:normal;font-weight:normal;font-size:9.48118305px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.23702957"
xml:space="preserve"><tspan
style="stroke-width:0.23702957"
y="306.72983"
x="264.49335"
id="tspan16641"
sodipodi:role="line">3</tspan></text>
<path
inkscape:connector-curvature="0"
id="path16645"
d="m 260.23859,299.00293 -9.6217,-4.81085"
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16965)" />
<path
sodipodi:nodetypes="cc"
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker1703)"
d="m 243.40061,299.00293 0.26727,10.15624"
id="path16647"
inkscape:connector-curvature="0" />
<path
inkscape:connector-curvature="0"
id="path16649"
d="m 251.15143,314.23729 12.2944,-4.54358"
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker16677)"
sodipodi:nodetypes="cc" />
<text
xml:space="preserve"
style="font-style:normal;font-weight:normal;font-size:47.54595947px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.18864894"
x="160.11197"
y="279.72699"
id="text16653"
transform="scale(0.87333338,1.1450381)"><tspan
sodipodi:role="line"
id="tspan16651"
x="160.11197"
y="279.72699"
style="stroke-width:1.18864894">[</tspan></text>
<text
transform="scale(-0.87333338,1.1450381)"
id="text16657"
y="280.04086"
x="-240.0833"
style="font-style:normal;font-weight:normal;font-size:47.54595947px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.18864894"
xml:space="preserve"><tspan
style="stroke-width:1.18864894"
y="280.04086"
x="-240.0833"
id="tspan16655"
sodipodi:role="line">[</tspan></text>
<flowRoot
xml:space="preserve"
id="flowRoot16669"
style="font-style:normal;font-weight:normal;font-size:40px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none"
transform="matrix(0.26458333,0,0,0.26458333,185.93767,111.06099)"><flowRegion
id="flowRegion16661"><rect
id="rect16659"
width="228.29446"
height="208.09143"
x="-117.1777"
y="631.53986" /></flowRegion><flowPara
id="flowPara16663">0 0 1 1</flowPara><flowPara
id="flowPara16665">1 0 0 0</flowPara><flowPara
id="flowPara16667">0 1 0 1</flowPara><flowPara
id="flowPara16955">0 1 1 0</flowPara></flowRoot> <text
xml:space="preserve"
style="font-style:normal;font-weight:normal;font-size:20.50874329px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.51271856"
x="159.35197"
y="274.14685"
id="text16673"><tspan
sodipodi:role="line"
id="tspan16671"
x="159.35197"
y="274.14685"
style="stroke-width:0.51271856">A</tspan></text>
<circle
style="opacity:1;vector-effect:none;fill:#b9b9b9;fill-opacity:1;stroke:#000000;stroke-width:0.66500002;stroke-linecap:butt;stroke-linejoin:miter;stroke-miterlimit:4;stroke-dasharray:none;stroke-dashoffset:0;stroke-opacity:1"
id="circle16957"
cx="272.9339"
cy="281.22949"
r="7.6171813" />
<text
xml:space="preserve"
style="font-style:normal;font-weight:normal;font-size:9.48118305px;line-height:1.25;font-family:sans-serif;letter-spacing:0px;word-spacing:0px;fill:#000000;fill-opacity:1;stroke:none;stroke-width:0.23702957"
x="269.83875"
y="284.54645"
id="text16961"><tspan
sodipodi:role="line"
id="tspan16959"
x="269.83875"
y="284.54645"
style="stroke-width:0.23702957">4</tspan></text>
<path
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker17199)"
d="m 265.85125,282.69949 -15.50163,5.07812"
id="path17195"
inkscape:connector-curvature="0"
sodipodi:nodetypes="cc" />
<path
sodipodi:nodetypes="cc"
inkscape:connector-curvature="0"
id="path17669"
d="m 269.86029,288.31215 -2.67268,6.949"
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker17673)" />
<path
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker18243)"
d="m 270.39484,295.52842 2.13816,-6.68174"
id="path18239"
inkscape:connector-curvature="0"
sodipodi:nodetypes="cc" />
<path
style="fill:none;stroke:#000000;stroke-width:0.26458332px;stroke-linecap:butt;stroke-linejoin:miter;stroke-opacity:1;marker-end:url(#marker18999)"
d="m 247.40965,323.85898 c 26.32452,12.56583 51.55327,-22.64989 33.94323,-41.15949"
id="path18965"
inkscape:connector-curvature="0"
sodipodi:nodetypes="cc" />
</g>
</svg>

+ 6
- 0
src/icosagon/__init__.py Visa fil

@@ -1 +1,7 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
from .data import Data

+ 6
- 0
src/icosagon/data.py Visa fil

@@ -1,3 +1,9 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
from collections import defaultdict
from dataclasses import dataclass
import torch


+ 29
- 0
src/icosagon/normalize.py Visa fil

@@ -0,0 +1,29 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
import numpy as np
import scipy.sparse as sp
def norm_adj_mat_one_node_type(adj):
    """Normalize a square adjacency matrix after adding self-loops.

    The input is converted to a SciPy COO matrix and the identity is added
    to it. With ``D`` the diagonal matrix holding the row sums of that
    augmented matrix raised to the power -1/2, the returned sparse matrix
    is ``(A' . D)^T . D``, i.e. ``D . A'^T . D``.

    Raises:
        AssertionError: if ``adj`` is not square.
    """
    adj = sp.coo_matrix(adj)
    num_rows, num_cols = adj.shape
    assert num_rows == num_cols

    # Self-loops: every node becomes its own neighbour.
    with_self_loops = adj + sp.eye(num_rows)

    row_totals = np.array(with_self_loops.sum(1))
    inv_sqrt_diag = sp.diags(np.power(row_totals, -0.5).flatten())

    # Scale the columns first, then (after transposing) the columns again.
    column_scaled = with_self_loops.dot(inv_sqrt_diag)
    return column_scaled.transpose().dot(inv_sqrt_diag)
def _inv_sqrt_or_zero(degrees):
    """Return ``degrees ** -0.5`` as a flat array, with 0 where that is not finite."""
    degrees = np.asarray(degrees).flatten()
    # Zero degrees yield inf and negative ones nan; both are expected here,
    # so silence the floating point warnings instead of leaking them.
    with np.errstate(divide='ignore', invalid='ignore'):
        scale = np.power(degrees, -0.5)
    scale[~np.isfinite(scale)] = 0
    return scale


def norm_adj_mat_two_node_types(adj):
    """Normalize a (possibly rectangular) adjacency matrix by row and column sums.

    The input is converted to a SciPy COO matrix. Each entry ``a_ij`` is
    multiplied by ``rowsum_i ** -0.5`` and ``colsum_j ** -0.5``. Rows and
    columns whose sum is zero (nodes without any edge) get a scaling factor
    of exactly 0 rather than a huge finite number, and no NumPy
    divide-by-zero warning is emitted for them.

    Returns:
        The normalized matrix in COO format, with the same shape as ``adj``.
    """
    adj = sp.coo_matrix(adj)

    row_scale = sp.diags(_inv_sqrt_or_zero(adj.sum(1)))
    col_scale = sp.diags(_inv_sqrt_or_zero(adj.sum(0)))

    return row_scale.dot(adj).dot(col_scale).tocoo()

+ 42
- 0
src/icosagon/sampling.py Visa fil

@@ -0,0 +1,42 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
import numpy as np
import torch
import torch.utils.data
from typing import List, \
Union
def fixed_unigram_candidate_sampler(
        true_classes: Union[np.ndarray, torch.Tensor],
        num_samples: int,
        unigrams: List[Union[int, float]],
        distortion: float = 1.):
    """Draw one class per row that differs from all of that row's true classes.

    Classes are drawn with probability proportional to
    ``unigrams ** distortion`` using
    ``torch.utils.data.WeightedRandomSampler``. Rows whose candidate equals
    one of their true classes are redrawn until no collision remains.

    Args:
        true_classes: matrix of shape ``(num_samples, num_true)`` holding,
            for every row, the class indices that must not be returned.
            Tensors are detached and moved to the CPU first.
        num_samples: number of rows, and therefore of returned samples.
        unigrams: sampling weight of every class, indexed by class.
        distortion: exponent applied to the weights; 1 leaves them as is.

    Returns:
        ``np.ndarray`` of dtype int64 and shape ``(num_samples,)``.

    Raises:
        ValueError: if ``true_classes`` is not 2D or its first dimension
            differs from ``num_samples``.

    NOTE: the rejection loop does not terminate when, for some row, every
    class with a non-zero weight is listed among its true classes.
    """
    if isinstance(true_classes, torch.Tensor):
        true_classes = true_classes.detach().cpu().numpy()

    # The row indexing below needs a real matrix, so reject 1D input here
    # with the documented error instead of failing later with IndexError.
    if true_classes.ndim != 2 or true_classes.shape[0] != num_samples:
        raise ValueError('true_classes must be a 2D matrix with shape (num_samples, num_true)')

    unigrams = np.array(unigrams)
    if distortion != 1.:
        unigrams = unigrams.astype(np.float64) ** distortion

    # Rows that still need a (new) candidate.
    indices = np.arange(num_samples)
    result = np.zeros(num_samples, dtype=np.int64)

    while len(indices) > 0:
        sampler = torch.utils.data.WeightedRandomSampler(unigrams, len(indices))
        candidates = np.array(list(sampler)).reshape(len(indices), 1)
        result[indices] = candidates[:, 0]

        # Keep only the rows whose candidate hit one of their true classes.
        # (The builtin ``bool`` replaces ``np.bool``, removed in NumPy 1.24.)
        collides = (candidates == true_classes[indices, :]).any(axis=1)
        indices = indices[collides]

    return result

+ 106
- 0
src/icosagon/trainprep.py Visa fil

@@ -0,0 +1,106 @@
#
# Copyright (C) Stanislaw Adaszewski, 2020
# License: GPLv3
#
from .sampling import fixed_unigram_candidate_sampler
import torch
from dataclasses import dataclass
from typing import Any, \
List, \
Tuple, \
Dict
from .data import NodeType
from collections import defaultdict
@dataclass
class TrainValTest:
    """Container holding one value each for the train, validation and test subsets.

    The fields are untyped; within this module the class carries both the
    split ratios and the resulting edge tensors.
    """
    # Value associated with the training subset.
    train: Any
    # Value associated with the validation subset.
    val: Any
    # Value associated with the test subset.
    test: Any
@dataclass
class PreparedEdges:
    """Pair of train/val/test splits: one for positive and one for negative edges."""
    # Split of the positive edges.
    positive: TrainValTest
    # Split of the negative edges.
    negative: TrainValTest
@dataclass
class PreparedRelationType:
    """A relation type together with its training adjacency matrix and edge splits."""
    # Name of the relation type.
    name: str
    # Index of the node type on the rows of the adjacency matrix.
    node_type_row: int
    # Index of the node type on the columns of the adjacency matrix.
    node_type_column: int
    # Adjacency matrix restricted to the training edges.
    adj_mat_train: torch.Tensor
    # Train/val/test split of the positive edges.
    edges_pos: TrainValTest
    # Train/val/test split of the negative edges.
    edges_neg: TrainValTest
@dataclass
class PreparedData:
    """Node types plus the prepared relation types of a data set."""
    # Node types of the data set.
    node_types: List[NodeType]
    # Prepared relation types, indexed first by row node type and then by
    # column node type.
    relation_types: Dict[int, Dict[int, List[PreparedRelationType]]]
def train_val_test_split_edges(edges: torch.Tensor,
        ratios: TrainValTest) -> TrainValTest:
    """Randomly partition edges into train, validation and test subsets.

    The rows of ``edges`` are shuffled with ``torch.randperm`` and then cut
    at ``round(num_edges * ratios.train)`` and
    ``round(num_edges * (ratios.train + ratios.val))``.

    Args:
        edges: tensor of shape ``(num_edges, 2)``.
        ratios: fractions of edges for each subset; they must add up to 1.

    Returns:
        ``TrainValTest`` with the three edge tensors.

    Raises:
        ValueError: if ``edges`` is not a tensor of shape ``(num_edges, 2)``,
            if ``ratios`` is not a ``TrainValTest`` or if the ratios do not
            add up to 1.
    """
    if not isinstance(edges, torch.Tensor):
        raise ValueError('edges must be a torch.Tensor')

    if len(edges.shape) != 2 or edges.shape[1] != 2:
        raise ValueError('edges shape must be (num_edges, 2)')

    if not isinstance(ratios, TrainValTest):
        raise ValueError('ratios must be a TrainValTest')

    # Compare with a tolerance: exact float equality rejects valid splits
    # such as (.7, .2, .1), which sum to 0.9999999999999999.
    if abs(ratios.train + ratios.val + ratios.test - 1.0) > 1e-9:
        raise ValueError('Train, validation and test ratios must add up to 1')

    num_edges = len(edges)
    shuffled = edges[torch.randperm(num_edges), :]

    train_end = round(num_edges * ratios.train)
    val_end = round(num_edges * (ratios.train + ratios.val))

    return TrainValTest(shuffled[:train_end],
        shuffled[train_end:val_end],
        shuffled[val_end:])
def prepare_adj_mat(adj_mat: torch.Tensor,
        ratios: TrainValTest) -> Tuple[TrainValTest, TrainValTest]:
    """Extract positive edges, sample negative ones and split both sets.

    Positive edges are the non-zero entries of ``adj_mat``. For every
    positive edge ``(row, column)`` one negative edge ``(row, column')`` is
    produced, where ``column'`` is drawn with probability proportional to
    the column sums of ``adj_mat`` raised to the power 0.75. A sampled
    column is only guaranteed to differ from the column of the positive
    edge it is paired with.

    Args:
        adj_mat: adjacency matrix; assumed to be a dense 2D tensor
            (``torch.nonzero`` is applied to it directly) - TODO confirm.
        ratios: train/val/test fractions applied to both edge sets.

    Returns:
        ``(edges_pos, edges_neg)``, each a ``TrainValTest`` of
        ``(num_edges, 2)`` tensors. The two sets are shuffled independently.
    """
    degrees = adj_mat.sum(0)
    edges_pos = torch.nonzero(adj_mat)

    # The sampler requires true classes of shape (num_samples, num_true),
    # hence the column vector.
    true_columns = edges_pos[:, 1].view(-1, 1)
    neg_neighbors = fixed_unigram_candidate_sampler(true_columns,
        len(edges_pos), degrees.detach().cpu().numpy(), 0.75)

    # The sampler hands back a NumPy array; bring it back to a tensor that
    # matches the positive edges before concatenating.
    neg_neighbors = torch.as_tensor(neg_neighbors,
        dtype=edges_pos.dtype, device=edges_pos.device)

    edges_neg = torch.cat((edges_pos[:, 0].view(-1, 1),
        neg_neighbors.view(-1, 1)), 1)

    edges_pos = train_val_test_split_edges(edges_pos, ratios)
    edges_neg = train_val_test_split_edges(edges_neg, ratios)

    return edges_pos, edges_neg
def prepare_relation(r, ratios):
    """Build the ``PreparedRelationType`` for a single relation.

    Args:
        r: relation object providing ``adjacency_matrix``, ``name``,
            ``node_type_row`` and ``node_type_column``.
        ratios: ``TrainValTest`` with the train/val/test fractions.

    Returns:
        ``PreparedRelationType`` whose ``adj_mat_train`` is a sparse COO
        tensor with a one at every positive training edge and the same
        shape as the relation's adjacency matrix.
    """
    adj_mat = r.adjacency_matrix
    edges_pos, edges_neg = prepare_adj_mat(adj_mat, ratios)

    train_edges = edges_pos.train
    # Pass the size explicitly: otherwise it is inferred from the largest
    # index present, which shrinks the matrix whenever trailing rows or
    # columns have no training edge.
    adj_mat_train = torch.sparse_coo_tensor(
        indices=train_edges.transpose(0, 1),
        values=torch.ones(len(train_edges), dtype=adj_mat.dtype,
            device=train_edges.device),
        size=adj_mat.shape)

    return PreparedRelationType(r.name, r.node_type_row, r.node_type_column,
        adj_mat_train, edges_pos, edges_neg)
def prepare_training(data, ratios=None):
    """Prepare every relation of ``data`` for training.

    Args:
        data: object providing ``node_types`` and ``relation_types``.
            NOTE(review): ``relation_types`` is assumed to be a mapping from
            ``(node_type_row, node_type_column)`` to a list of relations -
            confirm against the ``Data`` class.
        ratios: ``TrainValTest`` with the train/val/test fractions. When
            omitted, an 80/10/10 split is used.

    Returns:
        ``PreparedData`` with the original node types and the prepared
        relations, indexed by row node type and then column node type.
    """
    if ratios is None:
        ratios = TrainValTest(.8, .1, .1)

    relation_types = defaultdict(lambda: defaultdict(list))
    for (node_type_row, node_type_column), rels in data.relation_types.items():
        for r in rels:
            relation_types[node_type_row][node_type_column].append(
                prepare_relation(r, ratios))

    return PreparedData(data.node_types, relation_types)

+ 58
- 0
tests/icosagon/test_trainprep.py Visa fil

@@ -0,0 +1,58 @@
from icosagon.trainprep import TrainValTest, \
train_val_test_split_edges
import torch
import pytest
import numpy as np
def test_train_val_test_split_edges_01():
    """Check argument validation and subset sizes of train_val_test_split_edges."""
    edges = torch.randint(0, 10, (10, 2))
    good_ratios = TrainValTest(.8, .1, .1)

    # Every pair has exactly one invalid argument, so each case exercises a
    # single validation branch.
    invalid_calls = [
        # Ratios that do not add up to 1.
        (edges, TrainValTest(.5, .5, .5)),
        (edges, TrainValTest(.2, .2, .2)),
        # Ratios that are not a TrainValTest.
        (edges, None),
        (edges, (.8, .1, .1)),
        # Edges that are not a torch.Tensor.
        (np.random.randint(0, 10, (10, 2)), good_ratios),
        (None, good_ratios),
        # Edges with the wrong shape.
        (torch.randint(0, 10, (10, 3)), good_ratios),
        (torch.randint(0, 10, (10, 2, 1)), good_ratios),
    ]
    for bad_edges, bad_ratios in invalid_calls:
        with pytest.raises(ValueError):
            _ = train_val_test_split_edges(bad_edges, bad_ratios)

    # Ratios and the expected number of (train, val, test) edges out of 10.
    expected_sizes = [
        (TrainValTest(.8, .1, .1), (8, 1, 1)),
        (TrainValTest(.8, .0, .2), (8, 0, 2)),
        (TrainValTest(.8, .2, .0), (8, 2, 0)),
        (TrainValTest(.0, .5, .5), (0, 5, 5)),
        (TrainValTest(.0, .0, 1.), (0, 0, 10)),
        (TrainValTest(.0, 1., .0), (0, 10, 0)),
    ]
    for ratios, (n_train, n_val, n_test) in expected_sizes:
        res = train_val_test_split_edges(edges, ratios)
        assert res.train.shape == (n_train, 2)
        assert res.val.shape == (n_val, 2)
        assert res.test.shape == (n_test, 2)
# if ratios.train + ratios.val + ratios.test != 1.0:
# raise ValueError('Train, validation and test ratios must add up to 1')
#
# order = torch.randperm(len(edges))
# edges = edges[order, :]
# n = round(len(edges) * ratios.train)
# edges_train = edges[:n]
# n_1 = round(len(edges) * (ratios.train + ratios.val))
# edges_val = edges[n:n_1]
# edges_test = edges[n_1:]
#
# return TrainValTest(edges_train, edges_val, edges_test)

Laddar…
Avbryt
Spara