.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | CapsGNN
2 | ==================
3 | [Papers with Code: graph classification on RE-M5K](https://paperswithcode.com/sota/graph-classification-on-re-m5k?p=capsule-graph-neural-network) · [codebeat](https://codebeat.co/projects/github-com-benedekrozemberczki-capsgnn-master) · [Download (ZIP)](https://github.com/benedekrozemberczki/CapsGNN/archive/master.zip) · [Follow @benrozemberczki on Twitter](https://twitter.com/intent/follow?screen_name=benrozemberczki)
4 |
5 | A **PyTorch** implementation of **Capsule Graph Neural Network (ICLR 2019).**
6 |
7 |
8 |
9 |
10 | ### Abstract
11 |
12 | The high-quality node embeddings learned from the Graph Neural Networks (GNNs) have been applied to a wide range of node-based applications and some of them have achieved state-of-the-art (SOTA) performance. However, when applying node embeddings learned from GNNs to generate graph embeddings, the scalar node representation may not suffice to preserve the node/graph properties efficiently, resulting in sub-optimal graph embeddings. Inspired by the Capsule Neural Network (CapsNet), we propose the Capsule Graph Neural Network (CapsGNN), which adopts the concept of capsules to address the weakness in existing GNN-based graph embedding algorithms. By extracting node features in the form of capsules, a routing mechanism can be utilized to capture important information at the graph level. As a result, our model generates multiple embeddings for each graph to capture graph properties from different aspects. The attention module incorporated in CapsGNN is used to tackle graphs of various sizes, which also enables the model to focus on critical parts of the graphs. Our extensive evaluations with 10 graph-structured datasets demonstrate that CapsGNN has a powerful mechanism that operates to capture macroscopic properties of the whole graph in a data-driven manner. It outperforms other SOTA techniques on several graph classification tasks, by virtue of the new instrument.
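
The "routing mechanism" above comes from CapsNet, where capsule outputs are combined by dynamic routing-by-agreement (Sabour et al., 2017). The NumPy sketch below shows that generic procedure so the idea is concrete; it is not a transcription of this repository's code. The function names, the input shape `(num_input_capsules, num_output_capsules, capsule_dim)` and the three routing iterations are assumptions made for illustration.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-9):
    """Shrink each vector to a length in [0, 1) while keeping its direction."""
    sq_norm = np.sum(s * s, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def dynamic_routing(u_hat, iterations=3):
    """Routing-by-agreement over prediction vectors u_hat of shape
    (num_input_capsules, num_output_capsules, capsule_dim)."""
    num_in, num_out, _ = u_hat.shape
    logits = np.zeros((num_in, num_out))  # routing logits b_ij, start uniform
    for _ in range(iterations):
        # Coupling coefficients: softmax over output capsules for each input capsule.
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        coupling = exp / exp.sum(axis=1, keepdims=True)
        # Coupling-weighted sum of predictions, then squash, gives the output capsules.
        s = np.einsum("io,iod->od", coupling, u_hat)
        v = squash(s)
        # Raise the logits where a prediction agrees (dot product) with the output.
        logits = logits + np.einsum("iod,od->io", u_hat, v)
    return v, coupling


# Hypothetical sizes: 10 lower-level capsules routed to 4 capsules of dimension 8.
rng = np.random.RandomState(0)
u_hat = rng.randn(10, 4, 8)
v, coupling = dynamic_routing(u_hat)
```

Because of the squash non-linearity, every output capsule has length below 1, so its length can be read as a confidence score, and each row of `coupling` sums to 1.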
13 |
14 | This repository provides a PyTorch implementation of CapsGNN as described in the paper:
15 |
16 | > Capsule Graph Neural Network.
17 | > Zhang Xinyi, Lihui Chen.
18 | > ICLR, 2019.
19 | > [[Paper]](https://openreview.net/forum?id=Byl8BnRcYm)
20 |
21 | The core Capsule Neural Network implementation is adapted from [pytorch-capsule](https://github.com/timomernick/pytorch-capsule).
22 |
23 | ### Requirements
24 | The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.
25 | ```
26 | networkx 2.4
27 | tqdm 4.28.1
28 | numpy 1.15.4
29 | pandas 0.23.4
30 | texttable 1.5.0
31 | scipy 1.1.0
32 | argparse 1.1.0
33 | torch 1.1.0
34 | torch-scatter 1.4.0
35 | torch-sparse 0.4.3
36 | torch-cluster 1.4.5
37 | torch-geometric 1.3.2
38 | torchvision 0.3.0
39 | ```
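
To see how a local environment differs from the versions listed above, a small check like the following can help. It is a sketch, not part of the repository: the helper names are hypothetical, and `importlib.metadata` needs Python 3.8 or newer, which is newer than the 3.5.2 the codebase was developed on.

```python
from importlib import metadata  # Python 3.8+ (an assumption, see above)

PINNED = """
networkx 2.4
tqdm 4.28.1
numpy 1.15.4
pandas 0.23.4
texttable 1.5.0
scipy 1.1.0
argparse 1.1.0
torch 1.1.0
torch-scatter 1.4.0
torch-sparse 0.4.3
torch-cluster 1.4.5
torch-geometric 1.3.2
torchvision 0.3.0
"""


def parse_pins(text):
    """Turn 'name version' lines into a {name: version} dict."""
    return dict(line.split() for line in text.strip().splitlines())


def installed_versions(names):
    """Look up installed distribution versions; None means not installed."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def mismatches(pins, installed):
    """Return {name: (pinned, installed)} for every package that differs."""
    return {name: (pinned, installed.get(name))
            for name, pinned in pins.items()
            if installed.get(name) != pinned}
```

Calling `mismatches(pins, installed_versions(pins))` with `pins = parse_pins(PINNED)` lists every deviation. `argparse` ships with the standard library, so it will usually be reported as not installed as a separate distribution, which is harmless.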
40 | ### Datasets
41 |
42 | The code takes the training graphs from an input folder in which each graph is stored as a JSON file. Graphs used for testing are also stored as JSON files. Every node id and node label has to be indexed from 0. Dictionary keys are stored as strings to make JSON serialization possible.
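
Graphs whose node ids are arbitrary integers have to be relabelled before they fit the indexing and string-key rules above. The sketch below does that and serializes one graph; `to_capsgnn_json` is a hypothetical helper, not a function from this repository, and it assumes the node ids are sortable. It only relabels node ids; node labels are written unchanged.

```python
import json


def to_capsgnn_json(edges, labels, target):
    """Relabel nodes to 0..n-1 and serialize one graph (illustrative helper)."""
    order = sorted(labels)  # original node ids, assumed sortable
    index = {node: i for i, node in enumerate(order)}
    payload = {
        "edges": [[index[u], index[v]] for u, v in edges],
        # JSON object keys are always strings, so store the new ids as strings.
        "labels": {str(index[node]): labels[node] for node in order},
        "target": target,
    }
    return json.dumps(payload)


text = to_capsgnn_json(edges=[(10, 20), (20, 30)],
                       labels={10: "A", 20: "B", 30: "A"},
                       target=1)
```

Here node ids 10, 20 and 30 become 0, 1 and 2, and the edge list is rewritten accordingly.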
43 |
44 | Every JSON file has the following key-value structure:
45 |
46 | ```json
47 | {"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
48 | "labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
49 | "target": 1}
50 | ```
51 |
52 | The **edges** key has an edge list value which describes the connectivity structure. The **labels** key holds the label of each node as a nested dictionary in which node identifiers are the keys and labels are the values. The **target** key has an integer value which is the class membership.
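
A quick way to check a file against this structure before training is to parse it and validate the keys. The loader below is a minimal sketch under that reading of the format; `load_graph` is a hypothetical name and the repository may read its inputs differently.

```python
import json


def load_graph(text):
    """Parse one graph file and check the documented structure (illustrative helper)."""
    data = json.loads(text)
    # Keys are strings in the file; convert node ids back to integers.
    labels = {int(node): label for node, label in data["labels"].items()}
    if sorted(labels) != list(range(len(labels))):
        raise ValueError("node ids must run from 0 to n-1")
    edges = [(int(u), int(v)) for u, v in data["edges"]]
    for u, v in edges:
        if u not in labels or v not in labels:
            raise ValueError("edge ({}, {}) refers to a node without a label".format(u, v))
    return edges, labels, int(data["target"])


example = """{"edges": [[0, 1],[1, 2],[2, 3],[3, 4]],
"labels": {"0": "A", "1": "B", "2": "C", "3": "A", "4": "B"},
"target": 1}"""
edges, labels, target = load_graph(example)
```

Applied to the example above, this returns four edges, five labelled nodes and the class `1`.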
53 |
54 | ### Outputs
55 |
56 | The predictions are saved in the `output/` directory. The predictions file has a header and a column with the graph identifiers, and its rows are sorted by the identifier column.
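
The predictions file can be read back with the standard `csv` module. In this sketch the column names `id` and `predicted_label` are hypothetical placeholders, so check the header of your own output file and pass the real identifier column name.

```python
import csv
import io


def read_predictions(handle, id_column="id"):
    """Load a predictions CSV and check it is sorted by graph identifier.
    `id_column` is a hypothetical name; use whatever the header contains."""
    rows = list(csv.DictReader(handle))
    ids = [int(row[id_column]) for row in rows]
    if ids != sorted(ids):
        raise ValueError("rows are not sorted by the identifier column")
    return rows


sample = io.StringIO("id,predicted_label\n0,2\n1,2\n2,0\n")
rows = read_predictions(sample)
```

For a real run, open the file given by `--prediction-path` with `open(path, newline="")` and pass the handle instead of the in-memory sample.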
57 |
58 | ### Options
59 |
60 | Training a CapsGNN model is handled by the `src/main.py` script which provides the following command line arguments.
61 |
62 | #### Input and output options
63 | ```
64 | --training-graphs STR Training graphs folder. Default is `input/train/`.
65 | --testing-graphs STR Testing graphs folder. Default is `input/test/`.
66 | --prediction-path STR Output predictions file. Default is `output/watts_predictions.csv`.
67 | ```
68 | #### Model options
69 | ```
70 | --epochs INT Number of epochs. Default is 100.
71 | --batch-size INT Number of graphs per batch. Default is 32.
72 | --gcn-filters INT Number of filters in GCNs. Default is 20.
73 | --gcn-layers INT Number of GCNs chained together. Default is 2.
74 | --inner-attention-dimension INT Number of neurons in attention. Default is 20.
75 | --capsule-dimensions INT Number of capsule neurons. Default is 8.
76 | --number-of-capsules INT Number of capsules in layer. Default is 8.
77 | --weight-decay FLOAT Weight decay of Adam. Default is 10^-6.
78 | --lambd FLOAT Regularization parameter. Default is 0.5.
79 | --theta FLOAT Reconstruction loss weight. Default is 0.1.
80 | --learning-rate FLOAT Adam learning rate. Default is 0.01.
81 | ```
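
The options above map onto Python attributes in the usual `argparse` way: dashes become underscores, so `--batch-size` is read as `args.batch_size`. The parser below mirrors the documented flags and defaults as a sketch; it is not the repository's own parser, which may differ in details such as help strings.

```python
import argparse


def build_parser():
    """Parser mirroring the documented options (a sketch, not src/main.py's parser)."""
    parser = argparse.ArgumentParser(description="Run CapsGNN.")
    # Input and output options.
    parser.add_argument("--training-graphs", type=str, default="input/train/")
    parser.add_argument("--testing-graphs", type=str, default="input/test/")
    parser.add_argument("--prediction-path", type=str, default="output/watts_predictions.csv")
    # Model options.
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--gcn-filters", type=int, default=20)
    parser.add_argument("--gcn-layers", type=int, default=2)
    parser.add_argument("--inner-attention-dimension", type=int, default=20)
    parser.add_argument("--capsule-dimensions", type=int, default=8)
    parser.add_argument("--number-of-capsules", type=int, default=8)
    parser.add_argument("--weight-decay", type=float, default=10**-6)
    parser.add_argument("--lambd", type=float, default=0.5)
    parser.add_argument("--theta", type=float, default=0.1)
    parser.add_argument("--learning-rate", type=float, default=0.01)
    return parser


# Equivalent of `python src/main.py --batch-size 128`: only the batch size changes.
args = build_parser().parse_args(["--batch-size", "128"])
```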
82 | ### Examples
83 | The following commands learn a model and save the predictions. Training a model on the default dataset:
84 | ```sh
85 | $ python src/main.py
86 | ```
87 |
88 |
89 |
90 |
91 | Training a CapsGNN model for 100 epochs.
92 | ```sh
93 | $ python src/main.py --epochs 100
94 | ```
95 |
96 | Changing the batch size.
97 |
98 | ```sh
99 | $ python src/main.py --batch-size 128
100 | ```
101 | ----------------------
102 |
103 | **License**
104 |
105 | - [GNU License](https://github.com/benedekrozemberczki/CapsGNN/blob/master/LICENSE)
106 |
--------------------------------------------------------------------------------
/capsgnn.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/benedekrozemberczki/CapsGNN/e665c3c78bcee01f9814c885fea27b5c32c0f467/capsgnn.gif
--------------------------------------------------------------------------------
/input/input.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/benedekrozemberczki/CapsGNN/e665c3c78bcee01f9814c885fea27b5c32c0f467/input/input.zip
--------------------------------------------------------------------------------
/input/test/0.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 22], [0, 13], [1, 18], [1, 20], [1, 6], [2, 20], [2, 12], [2, 21], [3, 16], [3, 19], [3, 5], [4, 14], [4, 6], [5, 9], [5, 6], [6, 18], [6, 17], [6, 21], [7, 8], [7, 26], [7, 18], [8, 20], [8, 9], [8, 13], [8, 15], [9, 19], [9, 27], [9, 23], [9, 12], [9, 28], [10, 27], [10, 12], [10, 21], [11, 26], [11, 28], [11, 13], [12, 19], [12, 22], [13, 20], [13, 28], [14, 19], [14, 25], [14, 15], [15, 26], [15, 28], [16, 17], [16, 18], [16, 24], [17, 19], [17, 25], [18, 20], [19, 21], [21, 22], [21, 23], [22, 28], [23, 24], [24, 25]], "target": 2, "labels": {"0": "3", "1": "4", "2": "3", "3": "3", "4": "2", "5": "3", "6": "6", "7": "3", "8": "5", "9": "7", "10": "3", "11": "3", "12": "5", "13": "5", "14": "4", "15": "4", "16": "4", "17": "4", "18": "5", "19": "6", "20": "5", "21": "6", "22": "4", "23": "3", "24": "3", "25": "3", "26": "3", "27": "2", "28": "5"}}
--------------------------------------------------------------------------------
/input/test/1.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 1], [0, 5], [0, 17], [0, 15], [0, 11], [0, 12], [0, 13], [1, 11], [1, 18], [1, 14], [2, 16], [2, 18], [2, 3], [2, 7], [2, 8], [3, 6], [3, 7], [3, 9], [3, 10], [3, 12], [4, 5], [4, 6], [4, 13], [4, 14], [5, 17], [5, 8], [5, 12], [6, 7], [6, 10], [6, 11], [6, 13], [7, 8], [7, 9], [7, 10], [7, 15], [8, 18], [8, 9], [8, 13], [9, 14], [10, 16], [10, 18], [10, 13], [10, 15], [11, 17], [11, 13], [11, 14], [12, 18], [12, 14], [13, 16], [13, 14], [14, 17], [14, 18], [15, 16], [16, 17], [16, 18], [17, 18]], "target": 2, "labels": {"0": "8", "1": "4", "2": "5", "3": "6", "4": "4", "5": "5", "6": "6", "7": "7", "8": "6", "9": "4", "10": "7", "11": "6", "12": "5", "13": "8", "14": "8", "15": "4", "16": "7", "17": "6", "18": "8"}}
--------------------------------------------------------------------------------
/input/test/10.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 5], [1, 2], [1, 4], [1, 6], [2, 3], [2, 4], [2, 6], [3, 4], [3, 5], [4, 5], [4, 6], [5, 6]], "target": 0, "labels": {"0": "3", "1": "4", "2": "5", "3": "3", "4": "5", "5": "4", "6": "4"}}
--------------------------------------------------------------------------------
/input/test/11.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 34], [0, 2], [0, 33], [1, 34], [1, 19], [1, 22], [2, 3], [2, 4], [3, 5], [3, 33], [3, 15], [4, 21], [4, 6], [5, 16], [5, 21], [6, 20], [6, 13], [6, 21], [7, 8], [7, 9], [8, 18], [8, 19], [8, 14], [9, 11], [9, 28], [10, 17], [10, 11], [11, 12], [11, 33], [12, 27], [12, 14], [13, 23], [14, 16], [14, 15], [15, 30], [16, 18], [16, 20], [16, 31], [17, 18], [17, 19], [18, 19], [18, 20], [19, 21], [20, 32], [20, 21], [20, 27], [21, 22], [21, 23], [22, 24], [22, 23], [23, 34], [23, 24], [24, 25], [24, 26], [25, 26], [25, 31], [26, 27], [26, 28], [27, 28], [27, 29], [28, 29], [28, 30], [29, 30], [29, 31], [30, 32], [30, 31], [31, 32], [32, 34], [33, 34]], "target": 1, "labels": {"0": "4", "1": "4", "2": "3", "3": "4", "4": "3", "5": "3", "6": "4", "7": "2", "8": "4", "9": "3", "10": "2", "11": "4", "12": "3", "13": "2", "14": "4", "15": "3", "16": "5", "17": "3", "18": "5", "19": "5", "20": "6", "21": "7", "22": "4", "23": "5", "24": "4", "25": "3", "26": "4", "27": "5", "28": "5", "29": "4", "30": "5", "31": "5", "32": "4", "33": "4", "34": "5"}}
--------------------------------------------------------------------------------
/input/test/12.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [1, 27], [2, 3], [2, 4], [3, 9], [3, 20], [3, 13], [4, 12], [5, 24], [5, 14], [6, 23], [7, 24], [8, 25], [9, 23], [10, 17], [11, 27], [11, 23], [14, 17], [14, 18], [15, 26], [16, 17], [16, 25], [17, 28], [18, 19], [18, 23], [21, 22], [22, 28], [23, 27], [26, 27]], "target": 2, "labels": {"0": "1", "1": "1", "2": "3", "3": "4", "4": "2", "5": "2", "6": "1", "7": "1", "8": "1", "9": "2", "10": "1", "11": "2", "12": "1", "13": "1", "14": "3", "15": "1", "16": "2", "17": "4", "18": "3", "19": "1", "20": "1", "21": "1", "22": "2", "23": "5", "24": "2", "25": "2", "26": "2", "27": "4", "28": "2"}}
--------------------------------------------------------------------------------
/input/test/13.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 18], [0, 3], [0, 2], [0, 17], [1, 16], [1, 2], [1, 4], [1, 6], [1, 14], [1, 13], [2, 18], [2, 4], [2, 5], [2, 8], [2, 11], [2, 15], [3, 18], [3, 5], [3, 10], [4, 7], [6, 16], [6, 9], [6, 12], [6, 13], [7, 8], [7, 9], [7, 10], [8, 9], [8, 11], [9, 17], [9, 10], [9, 13], [10, 16], [10, 12], [10, 15], [11, 13], [12, 13], [14, 15]], "target": 2, "labels": {"0": "4", "1": "6", "2": "8", "3": "4", "4": "3", "5": "2", "6": "5", "7": "4", "8": "4", "9": "6", "10": "6", "11": "3", "12": "3", "13": "5", "14": "2", "15": "3", "16": "3", "17": "2", "18": "3"}}
--------------------------------------------------------------------------------
/input/test/14.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 4], [0, 5], [0, 6], [0, 7], [1, 6], [1, 7], [1, 8], [1, 9], [2, 8], [2, 9], [2, 10], [2, 3], [2, 4], [3, 6], [3, 7], [3, 8], [5, 8], [5, 10], [6, 7], [7, 10], [9, 10]], "target": 2, "labels": {"0": "5", "1": "5", "2": "5", "3": "4", "4": "2", "5": "3", "6": "4", "7": "5", "8": "4", "9": "3", "10": "4"}}
--------------------------------------------------------------------------------
/input/test/15.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 3], [1, 2], [2, 3], [2, 14], [2, 15], [3, 13], [4, 5], [5, 12], [5, 15], [6, 9], [7, 14], [8, 9], [8, 13], [9, 10], [9, 11], [10, 14]], "target": 1, "labels": {"0": "2", "1": "1", "2": "4", "3": "3", "4": "1", "5": "3", "6": "1", "7": "1", "8": "2", "9": "4", "10": "2", "11": "1", "12": "1", "13": "2", "14": "3", "15": "2", "16": "1"}}
--------------------------------------------------------------------------------
/input/test/16.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 34], [0, 15], [0, 10], [0, 33], [1, 2], [1, 3], [2, 24], [2, 26], [2, 13], [3, 4], [3, 5], [4, 5], [4, 6], [5, 32], [5, 6], [5, 7], [5, 30], [6, 8], [6, 14], [6, 7], [7, 8], [7, 9], [8, 9], [8, 31], [9, 11], [9, 12], [10, 11], [10, 28], [11, 12], [11, 13], [11, 28], [11, 31], [12, 13], [12, 14], [13, 21], [13, 15], [14, 15], [15, 21], [16, 17], [16, 18], [17, 18], [17, 19], [18, 19], [18, 20], [19, 20], [19, 28], [20, 21], [20, 22], [21, 22], [21, 23], [22, 24], [22, 23], [23, 24], [23, 25], [24, 26], [25, 26], [25, 27], [26, 27], [26, 28], [26, 34], [27, 28], [27, 29], [28, 30], [29, 30], [29, 31], [30, 32], [31, 32], [32, 33], [33, 34]], "target": 0, "labels": {"0": "5", "1": "3", "2": "4", "3": "3", "4": "3", "5": "6", "6": "5", "7": "4", "8": "4", "9": "4", "10": "3", "11": "6", "12": "4", "13": "5", "14": "3", "15": "4", "16": "2", "17": "3", "18": "4", "19": "4", "20": "4", "21": "5", "22": "4", "23": "4", "24": "4", "25": "3", "26": "6", "27": "4", "28": "6", "29": "3", "30": "4", "31": "4", "32": "4", "33": "3", "34": "3"}}
--------------------------------------------------------------------------------
/input/test/17.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 18], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], [16, 17], [17, 18]], "target": 0, "labels": {"0": "2", "1": "2", "2": "2", "3": "2", "4": "2", "5": "2", "6": "2", "7": "2", "8": "2", "9": "2", "10": "2", "11": "2", "12": "2", "13": "2", "14": "2", "15": "2", "16": "2", "17": "2", "18": "2"}}
--------------------------------------------------------------------------------
/input/test/18.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 33], [1, 2], [2, 3], [3, 4], [4, 16], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 25], [13, 14], [14, 15], [15, 16], [17, 18], [18, 19], [19, 20], [20, 21], [21, 22], [22, 32], [22, 31], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [27, 28], [28, 29], [29, 30], [30, 31]], "target": 0, "labels": {"0": "2", "1": "2", "2": "2", "3": "2", "4": "3", "5": "2", "6": "2", "7": "2", "8": "2", "9": "2", "10": "2", "11": "2", "12": "2", "13": "1", "14": "2", "15": "2", "16": "2", "17": "1", "18": "2", "19": "2", "20": "2", "21": "2", "22": "4", "23": "2", "24": "2", "25": "3", "26": "2", "27": "2", "28": "2", "29": "2", "30": "2", "31": "2", "32": "1", "33": "1"}}
--------------------------------------------------------------------------------
/input/test/19.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 24], [0, 20], [0, 21], [1, 26], [2, 3], [3, 13], [3, 6], [4, 19], [5, 11], [6, 12], [7, 10], [8, 13], [8, 21], [9, 20], [10, 19], [11, 20], [12, 18], [12, 19], [14, 15], [15, 24], [15, 17], [16, 17], [19, 21], [19, 26], [22, 24], [23, 25], [25, 26]], "target": 2, "labels": {"0": "3", "1": "1", "2": "1", "3": "3", "4": "1", "5": "1", "6": "2", "7": "1", "8": "2", "9": "1", "10": "2", "11": "2", "12": "3", "13": "2", "14": "1", "15": "3", "16": "1", "17": "2", "18": "1", "19": "5", "20": "3", "21": "3", "22": "1", "23": "1", "24": "3", "25": "2", "26": "3"}}
--------------------------------------------------------------------------------
/input/test/2.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 18], [0, 3], [0, 17], [0, 2], [0, 14], [0, 15], [1, 17], [1, 2], [1, 4], [1, 6], [1, 12], [1, 18], [2, 18], [2, 4], [2, 5], [2, 13], [2, 15], [3, 4], [3, 5], [3, 6], [4, 5], [4, 6], [4, 7], [4, 8], [5, 6], [5, 7], [5, 8], [5, 14], [6, 7], [6, 8], [6, 14], [7, 8], [7, 9], [7, 10], [8, 10], [8, 12], [9, 10], [9, 11], [9, 12], [10, 11], [10, 12], [10, 13], [11, 12], [11, 13], [11, 14], [12, 13], [12, 14], [13, 16], [13, 14], [13, 15], [14, 16], [14, 17], [15, 18], [16, 17], [16, 18], [17, 18]], "target": 0, "labels": {"0": "7", "1": "6", "2": "7", "3": "4", "4": "7", "5": "7", "6": "7", "7": "6", "8": "6", "9": "4", "10": "6", "11": "5", "12": "7", "13": "7", "14": "8", "15": "4", "16": "5", "17": "5", "18": "6"}}
--------------------------------------------------------------------------------
/input/test/20.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 3], [0, 6], [0, 10], [0, 11], [1, 2], [1, 4], [1, 8], [1, 9], [1, 10], [1, 11], [2, 4], [2, 5], [2, 8], [2, 11], [3, 4], [3, 5], [3, 6], [4, 5], [4, 6], [4, 9], [4, 10], [5, 8], [5, 9], [6, 7], [6, 8], [6, 9], [7, 8], [7, 9], [8, 9], [8, 10], [8, 11], [9, 10], [9, 11], [10, 11]], "target": 0, "labels": {"0": "6", "1": "7", "2": "6", "3": "4", "4": "7", "5": "5", "6": "6", "7": "3", "8": "8", "9": "8", "10": "6", "11": "6"}}
--------------------------------------------------------------------------------
/input/test/21.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 1], [0, 11], [0, 15], [1, 2], [1, 3], [1, 10], [2, 16], [2, 14], [3, 4], [3, 5], [3, 9], [4, 10], [4, 5], [4, 7], [5, 16], [5, 6], [5, 7], [6, 16], [6, 9], [7, 16], [7, 8], [8, 10], [9, 11], [10, 11], [10, 13], [10, 14], [11, 16], [11, 15], [12, 13], [12, 14], [13, 14], [14, 16], [14, 15]], "target": 1, "labels": {"0": "4", "1": "4", "2": "3", "3": "4", "4": "4", "5": "5", "6": "3", "7": "4", "8": "2", "9": "3", "10": "6", "11": "5", "12": "2", "13": "3", "14": "6", "15": "3", "16": "7"}}
--------------------------------------------------------------------------------
/input/test/22.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 4], [1, 4], [1, 9], [1, 14], [2, 3], [2, 5], [2, 6], [2, 11], [3, 12], [3, 7], [4, 16], [4, 6], [4, 10], [4, 13], [5, 9], [5, 14], [6, 8], [6, 7], [7, 8], [7, 9], [7, 11], [7, 13], [7, 15], [8, 9], [8, 10], [8, 12], [8, 13], [8, 15], [9, 16], [9, 10], [9, 12], [11, 12], [13, 16]], "target": 2, "labels": {"0": "2", "1": "4", "2": "4", "3": "3", "4": "6", "5": "3", "6": "4", "7": "7", "8": "7", "9": "7", "10": "3", "11": "3", "12": "4", "13": "4", "14": "2", "15": "2", "16": "3"}}
--------------------------------------------------------------------------------
/input/test/23.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 12], [0, 13], [1, 2], [1, 3], [1, 13], [2, 5], [2, 6], [3, 4], [3, 5], [4, 5], [4, 6], [5, 6], [5, 9], [6, 8], [6, 7], [7, 8], [7, 9], [8, 9], [8, 10], [9, 10], [9, 11], [10, 11], [10, 12], [11, 12], [11, 13], [12, 13]], "target": 0, "labels": {"0": "4", "1": "4", "2": "4", "3": "3", "4": "3", "5": "5", "6": "5", "7": "3", "8": "4", "9": "5", "10": "4", "11": "4", "12": "4", "13": "4"}}
--------------------------------------------------------------------------------
/input/test/24.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 4], [0, 14], [0, 15], [1, 9], [1, 2], [1, 3], [1, 15], [2, 3], [2, 4], [3, 4], [3, 7], [3, 10], [3, 11], [4, 5], [4, 6], [4, 8], [5, 10], [5, 7], [6, 8], [6, 12], [6, 13], [6, 15], [7, 8], [8, 10], [8, 13], [9, 11], [10, 11], [11, 12], [11, 13], [11, 14], [12, 13]], "target": 1, "labels": {"0": "4", "1": "4", "2": "4", "3": "6", "4": "6", "5": "3", "6": "5", "7": "3", "8": "5", "9": "2", "10": "4", "11": "6", "12": "3", "13": "4", "14": "2", "15": "3"}}
--------------------------------------------------------------------------------
/input/test/25.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 6], [1, 3], [2, 11], [3, 4], [3, 14], [4, 14], [5, 6], [6, 7], [7, 8], [8, 9], [8, 11], [9, 10], [10, 11], [12, 13], [13, 19], [13, 14], [15, 16], [16, 17], [17, 18], [18, 19]], "target": 0, "labels": {"0": "1", "1": "1", "2": "1", "3": "3", "4": "2", "5": "1", "6": "3", "7": "2", "8": "3", "9": "2", "10": "2", "11": "3", "12": "1", "13": "3", "14": "3", "15": "1", "16": "2", "17": "2", "18": "2", "19": "2"}}
--------------------------------------------------------------------------------
/input/test/26.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 2], [0, 15], [1, 2], [1, 3], [2, 3], [2, 4], [2, 17], [2, 12], [3, 5], [3, 10], [3, 14], [4, 9], [4, 10], [4, 6], [5, 16], [5, 6], [6, 16], [6, 15], [7, 8], [7, 9], [7, 17], [8, 9], [8, 10], [9, 11], [10, 11], [10, 12], [11, 12], [11, 14], [12, 14], [13, 14], [13, 15], [14, 16], [15, 16], [15, 17], [16, 17]], "target": 1, "labels": {"0": "3", "1": "2", "2": "6", "3": "5", "4": "4", "5": "3", "6": "4", "7": "3", "8": "3", "9": "4", "10": "5", "11": "4", "12": "4", "13": "2", "14": "5", "15": "5", "16": "6", "17": "4"}}
--------------------------------------------------------------------------------
/input/test/27.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 8], [0, 1], [0, 2], [0, 9], [1, 9], [1, 2], [1, 3], [2, 3], [2, 4], [2, 6], [3, 4], [3, 5], [4, 5], [4, 6], [5, 6], [5, 7], [6, 8], [7, 8], [7, 9], [8, 9]], "target": 0, "labels": {"0": "4", "1": "4", "2": "5", "3": "4", "4": "4", "5": "4", "6": "4", "7": "3", "8": "4", "9": "4"}}
--------------------------------------------------------------------------------
/input/test/28.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 17], [0, 18], [0, 3], [0, 21], [0, 24], [0, 14], [1, 2], [1, 3], [1, 22], [1, 24], [1, 10], [1, 12], [2, 4], [2, 7], [2, 8], [2, 14], [3, 4], [3, 23], [3, 8], [3, 15], [4, 16], [4, 17], [4, 5], [4, 6], [4, 9], [4, 12], [5, 18], [5, 25], [5, 22], [5, 8], [5, 9], [5, 10], [5, 12], [5, 13], [6, 19], [6, 24], [6, 18], [6, 22], [6, 8], [7, 19], [7, 21], [7, 22], [7, 10], [7, 20], [8, 12], [9, 24], [9, 10], [9, 15], [10, 14], [11, 16], [11, 17], [11, 21], [11, 23], [11, 24], [11, 12], [11, 13], [13, 17], [13, 18], [13, 19], [13, 14], [13, 21], [14, 22], [14, 24], [15, 19], [15, 21], [15, 23], [16, 22], [17, 22], [18, 19], [19, 21], [20, 21], [20, 22], [20, 23], [21, 22], [21, 23], [21, 25], [23, 24], [23, 25]], "target": 2, "labels": {"0": "6", "1": "6", "2": "5", "3": "6", "4": "8", "5": "9", "6": "6", "7": "6", "8": "5", "9": "5", "10": "5", "11": "7", "12": "5", "13": "7", "14": "6", "15": "5", "16": "3", "17": "5", "18": "5", "19": "6", "20": "4", "21": "10", "22": "9", "23": "7", "24": "7", "25": "3"}}
--------------------------------------------------------------------------------
/input/test/29.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 19], [0, 1], [0, 5], [0, 9], [0, 10], [0, 7], [1, 3], [1, 2], [1, 19], [1, 12], [1, 13], [2, 4], [2, 5], [2, 9], [2, 11], [3, 19], [3, 4], [3, 6], [3, 8], [3, 10], [4, 16], [4, 5], [4, 7], [4, 10], [5, 17], [5, 7], [5, 13], [6, 17], [6, 18], [6, 10], [6, 14], [7, 16], [7, 18], [7, 9], [7, 14], [8, 9], [8, 10], [8, 11], [8, 13], [8, 15], [9, 17], [9, 11], [9, 12], [9, 14], [9, 15], [10, 11], [11, 12], [11, 18], [11, 15], [12, 17], [12, 14], [13, 19], [13, 17], [14, 18], [15, 16], [15, 18], [16, 17], [17, 18], [17, 19], [18, 19]], "target": 2, "labels": {"0": "6", "1": "6", "2": "5", "3": "6", "4": "6", "5": "6", "6": "5", "7": "7", "8": "6", "9": "9", "10": "6", "11": "7", "12": "5", "13": "5", "14": "5", "15": "5", "16": "4", "17": "8", "18": "7", "19": "6"}}
--------------------------------------------------------------------------------
/input/test/3.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 32], [0, 2], [0, 11], [0, 18], [0, 31], [1, 16], [1, 2], [1, 26], [1, 10], [2, 3], [2, 4], [2, 32], [3, 17], [3, 4], [3, 5], [4, 19], [4, 20], [4, 6], [4, 14], [5, 15], [5, 23], [6, 14], [6, 7], [7, 9], [7, 28], [8, 9], [8, 10], [9, 25], [9, 23], [9, 29], [9, 31], [10, 20], [10, 31], [11, 18], [11, 12], [11, 21], [12, 23], [12, 28], [13, 14], [13, 15], [14, 16], [14, 15], [15, 16], [15, 17], [16, 24], [16, 30], [17, 23], [18, 20], [19, 21], [19, 30], [21, 22], [22, 24], [22, 23], [23, 24], [24, 25], [24, 31], [25, 26], [25, 27], [25, 30], [26, 27], [27, 28], [27, 29], [28, 29], [28, 30], [29, 30], [31, 32]], "target": 1, "labels": {"0": "5", "1": "4", "2": "5", "3": "4", "4": "6", "5": "3", "6": "3", "7": "3", "8": "2", "9": "6", "10": "4", "11": "4", "12": "3", "13": "2", "14": "5", "15": "5", "16": "5", "17": "3", "18": "3", "19": "3", "20": "3", "21": "3", "22": "3", "23": "6", "24": "5", "25": "5", "26": "3", "27": "4", "28": "5", "29": "4", "30": "5", "31": "5", "32": "3"}}
--------------------------------------------------------------------------------
/input/test/4.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 10], [0, 14], [1, 4], [1, 9], [1, 12], [1, 13], [1, 15], [2, 7], [2, 14], [3, 4], [3, 6], [3, 14], [4, 16], [4, 10], [4, 11], [5, 16], [5, 6], [5, 8], [5, 15], [7, 10], [7, 11], [7, 13], [8, 9], [8, 14], [8, 13], [9, 11], [9, 12], [9, 15], [10, 16], [10, 12], [11, 12], [11, 13]], "target": 2, "labels": {"0": "4", "1": "6", "2": "3", "3": "3", "4": "5", "5": "4", "6": "2", "7": "4", "8": "4", "9": "5", "10": "5", "11": "5", "12": "4", "13": "4", "14": "4", "15": "3", "16": "3"}}
--------------------------------------------------------------------------------
/input/test/5.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 29], [1, 2], [2, 3], [3, 27], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], [16, 21], [17, 18], [18, 19], [19, 20], [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26], [26, 27], [28, 29]], "target": 0, "labels": {"0": "2", "1": "2", "2": "2", "3": "3", "4": "2", "5": "2", "6": "2", "7": "2", "8": "2", "9": "2", "10": "2", "11": "2", "12": "2", "13": "2", "14": "2", "15": "2", "16": "2", "17": "1", "18": "2", "19": "2", "20": "2", "21": "3", "22": "2", "23": "2", "24": "2", "25": "2", "26": "2", "27": "2", "28": "1", "29": "2"}}
--------------------------------------------------------------------------------
/input/test/6.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 3], [0, 5], [0, 6], [0, 9], [0, 14], [1, 3], [1, 6], [1, 9], [1, 10], [1, 12], [2, 6], [2, 8], [2, 10], [2, 14], [3, 4], [3, 5], [3, 7], [3, 15], [4, 16], [4, 6], [4, 10], [4, 13], [5, 16], [5, 7], [5, 11], [5, 15], [6, 7], [6, 11], [7, 16], [7, 17], [7, 10], [7, 13], [7, 14], [8, 17], [8, 9], [8, 12], [8, 13], [8, 14], [9, 16], [9, 10], [9, 12], [10, 11], [10, 12], [10, 13], [10, 14], [11, 12], [12, 16], [12, 14], [12, 15], [13, 16], [13, 17], [14, 17], [16, 17]], "target": 2, "labels": {"0": "6", "1": "5", "2": "4", "3": "6", "4": "5", "5": "6", "6": "6", "7": "8", "8": "6", "9": "6", "10": "9", "11": "4", "12": "8", "13": "6", "14": "7", "15": "3", "16": "8", "17": "5"}}
--------------------------------------------------------------------------------
/input/test/7.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [1, 16], [1, 2], [1, 8], [1, 26], [2, 3], [2, 4], [2, 6], [3, 4], [3, 5], [3, 26], [3, 18], [3, 13], [4, 9], [4, 27], [5, 21], [5, 23], [5, 29], [5, 15], [6, 20], [7, 8], [7, 13], [8, 9], [9, 27], [9, 21], [9, 22], [9, 20], [10, 12], [10, 21], [11, 16], [11, 17], [11, 14], [12, 13], [12, 30], [12, 15], [13, 23], [13, 29], [14, 16], [14, 27], [14, 26], [14, 15], [15, 16], [15, 17], [16, 18], [16, 27], [17, 18], [17, 28], [18, 19], [19, 24], [19, 21], [20, 22], [21, 23], [22, 24], [23, 24], [23, 27], [24, 25], [24, 26], [25, 26], [25, 28], [25, 30], [26, 28]], "target": 1, "labels": {"0": "2", "1": "5", "2": "5", "3": "6", "4": "4", "5": "5", "6": "2", "7": "2", "8": "3", "9": "6", "10": "2", "11": "3", "12": "4", "13": "5", "14": "5", "15": "5", "16": "6", "17": "4", "18": "4", "19": "3", "20": "3", "21": "5", "22": "3", "23": "5", "24": "5", "25": "4", "26": "6", "27": "5", "28": "3", "29": "2", "30": "2"}}
--------------------------------------------------------------------------------
/input/test/8.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 32], [0, 11], [0, 12], [0, 14], [1, 20], [1, 30], [2, 32], [2, 3], [2, 9], [3, 19], [3, 26], [3, 29], [3, 31], [4, 9], [4, 31], [4, 25], [5, 28], [5, 20], [5, 7], [6, 16], [6, 32], [6, 20], [6, 30], [7, 23], [8, 18], [8, 19], [8, 22], [9, 16], [9, 12], [10, 25], [10, 29], [10, 31], [11, 22], [11, 13], [12, 20], [12, 21], [13, 31], [13, 28], [13, 15], [14, 17], [14, 22], [15, 16], [15, 20], [15, 29], [16, 18], [17, 18], [17, 30], [18, 19], [18, 20], [19, 21], [20, 24], [20, 29], [21, 22], [21, 29], [23, 24], [23, 25], [23, 26], [23, 27], [23, 29], [24, 26], [25, 26], [26, 32], [27, 29], [27, 31], [28, 29], [28, 30]], "target": 2, "labels": {"0": "4", "1": "2", "2": "3", "3": "5", "4": "3", "5": "3", "6": "4", "7": "2", "8": "3", "9": "4", "10": "3", "11": "3", "12": "4", "13": "4", "14": "3", "15": "4", "16": "4", "17": "3", "18": "5", "19": "4", "20": "8", "21": "4", "22": "4", "23": "6", "24": "3", "25": "4", "26": "5", "27": "3", "28": "4", "29": "8", "30": "4", "31": "5", "32": "4"}}
--------------------------------------------------------------------------------
/input/test/9.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 18], [0, 22], [0, 23], [1, 2], [1, 3], [1, 23], [2, 3], [2, 4], [3, 4], [3, 5], [3, 7], [3, 8], [4, 5], [4, 6], [5, 6], [6, 8], [6, 7], [7, 8], [7, 9], [7, 15], [8, 10], [9, 10], [9, 11], [10, 11], [10, 12], [11, 12], [11, 13], [12, 13], [12, 14], [13, 14], [13, 15], [14, 16], [14, 20], [14, 15], [15, 16], [16, 18], [16, 23], [17, 18], [17, 19], [18, 19], [19, 20], [19, 21], [20, 22], [21, 22], [21, 23], [22, 23]], "target": 1, "labels": {"0": "5", "1": "4", "2": "4", "3": "6", "4": "4", "5": "3", "6": "4", "7": "5", "8": "4", "9": "3", "10": "4", "11": "4", "12": "4", "13": "4", "14": "5", "15": "4", "16": "4", "17": "2", "18": "4", "19": "4", "20": "3", "21": "3", "22": "4", "23": "5"}}
--------------------------------------------------------------------------------
/input/train/0.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 6], [0, 29], [1, 4], [1, 21], [1, 7], [1, 12], [2, 28], [2, 3], [2, 4], [3, 4], [3, 8], [3, 20], [3, 26], [3, 30], [4, 5], [4, 23], [4, 28], [5, 18], [5, 19], [5, 7], [5, 27], [5, 12], [6, 19], [6, 31], [7, 23], [8, 17], [8, 10], [9, 16], [9, 20], [9, 24], [9, 28], [10, 24], [10, 19], [10, 22], [10, 23], [10, 26], [10, 11], [10, 31], [11, 16], [11, 12], [12, 18], [12, 25], [12, 29], [13, 30], [13, 14], [14, 19], [14, 15], [15, 26], [15, 30], [15, 31], [16, 18], [16, 20], [17, 18], [18, 30], [19, 25], [20, 27], [20, 29], [21, 22], [22, 24], [23, 25], [24, 27], [25, 27], [27, 28]], "target": 1, "labels": {"0": "3", "1": "4", "2": "3", "3": "6", "4": "6", "5": "6", "6": "3", "7": "3", "8": "3", "9": "4", "10": "8", "11": "3", "12": "6", "13": "2", "14": "3", "15": "4", "16": "5", "17": "2", "18": "5", "19": "5", "20": "5", "21": "2", "22": "3", "23": "4", "24": "4", "25": "4", "26": "3", "27": "5", "28": "4", "29": "3", "30": "4", "31": "3"}}
--------------------------------------------------------------------------------
/input/train/1.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 5], [0, 10], [0, 11], [0, 12], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 8], [1, 10], [1, 12], [2, 3], [2, 4], [2, 6], [2, 8], [2, 11], [3, 4], [3, 5], [3, 7], [3, 10], [3, 12], [4, 8], [4, 11], [4, 12], [5, 6], [5, 8], [5, 9], [5, 10], [5, 12], [6, 7], [6, 11], [7, 11], [7, 12], [8, 10], [9, 11], [9, 12], [11, 12]], "target": 1, "labels": {"0": "5", "1": "8", "2": "7", "3": "7", "4": "6", "5": "8", "6": "5", "7": "4", "8": "5", "9": "3", "10": "5", "11": "7", "12": "8"}}
--------------------------------------------------------------------------------
/input/train/10.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [1, 2], [1, 21], [2, 24], [2, 3], [3, 17], [3, 25], [4, 17], [4, 5], [5, 12], [6, 27], [6, 20], [7, 8], [7, 29], [9, 10], [10, 23], [11, 12], [12, 13], [13, 19], [14, 20], [15, 26], [16, 17], [18, 19], [19, 20], [22, 26], [23, 24], [25, 26], [25, 29], [25, 30], [27, 28], [28, 29]], "target": 1, "labels": {"0": "1", "1": "3", "2": "3", "3": "3", "4": "2", "5": "2", "6": "2", "7": "2", "8": "1", "9": "1", "10": "2", "11": "1", "12": "3", "13": "2", "14": "1", "15": "1", "16": "1", "17": "3", "18": "1", "19": "3", "20": "3", "21": "1", "22": "1", "23": "2", "24": "2", "25": "4", "26": "3", "27": "2", "28": "2", "29": "3", "30": "1"}}
--------------------------------------------------------------------------------
/input/train/11.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 3], [0, 22], [0, 23], [0, 24], [0, 18], [1, 3], [1, 4], [1, 23], [1, 7], [2, 3], [2, 4], [2, 5], [2, 24], [3, 5], [3, 6], [3, 14], [4, 5], [4, 6], [4, 22], [5, 17], [5, 6], [5, 23], [6, 16], [6, 7], [6, 8], [7, 8], [7, 9], [7, 10], [8, 9], [8, 10], [8, 11], [9, 10], [9, 11], [9, 12], [10, 24], [10, 12], [10, 13], [10, 15], [11, 12], [11, 13], [11, 14], [12, 16], [12, 13], [12, 14], [12, 15], [13, 16], [13, 14], [13, 15], [14, 16], [14, 17], [14, 15], [15, 16], [15, 17], [15, 18], [16, 17], [16, 19], [17, 18], [17, 19], [17, 20], [18, 19], [18, 21], [19, 20], [19, 21], [19, 22], [20, 21], [20, 22], [20, 23], [21, 22], [21, 23], [21, 24], [22, 23], [22, 24], [23, 24]], "target": 0, "labels": {"0": "7", "1": "5", "2": "5", "3": "6", "4": "5", "5": "6", "6": "6", "7": "5", "8": "5", "9": "5", "10": "7", "11": "5", "12": "7", "13": "6", "14": "7", "15": "7", "16": "7", "17": "7", "18": "5", "19": "6", "20": "5", "21": "6", "22": "7", "23": "7", "24": "6"}}
--------------------------------------------------------------------------------
/input/train/12.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 9], [0, 5], [0, 7], [1, 2], [1, 13], [2, 10], [2, 13], [3, 4], [3, 6], [4, 16], [7, 11], [8, 9], [12, 13], [13, 16], [14, 15], [15, 16]], "target": 1, "labels": {"0": "4", "1": "3", "2": "3", "3": "2", "4": "2", "5": "1", "6": "1", "7": "2", "8": "1", "9": "2", "10": "1", "11": "1", "12": "1", "13": "4", "14": "1", "15": "2", "16": "3"}}
--------------------------------------------------------------------------------
/input/train/13.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 13], [0, 14], [1, 2], [1, 3], [1, 14], [2, 3], [2, 4], [3, 4], [3, 5], [4, 5], [4, 6], [5, 6], [5, 7], [6, 8], [6, 7], [7, 8], [7, 9], [8, 9], [8, 10], [9, 10], [9, 11], [10, 11], [10, 12], [11, 12], [11, 13], [12, 13], [12, 14], [13, 14]], "target": 0, "labels": {"0": "4", "1": "4", "2": "4", "3": "4", "4": "4", "5": "4", "6": "4", "7": "4", "8": "4", "9": "4", "10": "4", "11": "4", "12": "4", "13": "4", "14": "4"}}
--------------------------------------------------------------------------------
/input/train/14.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 20], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], [13, 14], [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], [19, 20]], "target": 0, "labels": {"0": "2", "1": "2", "2": "2", "3": "2", "4": "2", "5": "2", "6": "2", "7": "2", "8": "2", "9": "2", "10": "2", "11": "2", "12": "2", "13": "2", "14": "2", "15": "2", "16": "2", "17": "2", "18": "2", "19": "2", "20": "2"}}
--------------------------------------------------------------------------------
/input/train/15.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 5], [0, 7], [0, 9], [0, 10], [1, 2], [1, 4], [2, 4], [2, 6], [3, 10], [3, 4], [3, 5], [3, 7], [4, 5], [4, 8], [4, 10], [5, 6], [5, 8], [5, 10], [6, 8], [7, 9], [9, 10]], "target": 1, "labels": {"0": "5", "1": "2", "2": "4", "3": "4", "4": "6", "5": "6", "6": "3", "7": "3", "8": "3", "9": "3", "10": "5"}}
--------------------------------------------------------------------------------
/input/train/16.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 3], [0, 22], [0, 27], [0, 28], [1, 2], [1, 28], [1, 5], [2, 4], [2, 10], [2, 12], [3, 5], [3, 16], [3, 9], [3, 12], [4, 17], [4, 20], [4, 6], [4, 9], [4, 24], [4, 25], [4, 26], [5, 18], [5, 6], [5, 9], [5, 12], [6, 22], [6, 7], [7, 8], [7, 9], [8, 16], [8, 10], [8, 11], [9, 26], [10, 17], [11, 13], [12, 25], [12, 13], [12, 14], [13, 15], [14, 17], [14, 18], [15, 16], [15, 17], [15, 18], [15, 19], [16, 24], [17, 24], [19, 20], [19, 23], [20, 21], [20, 22], [21, 23], [23, 24], [24, 26], [25, 27], [26, 27], [27, 28]], "target": 1, "labels": {"0": "5", "1": "3", "2": "5", "3": "5", "4": "8", "5": "6", "6": "4", "7": "3", "8": "4", "9": "5", "10": "3", "11": "2", "12": "6", "13": "3", "14": "3", "15": "5", "16": "4", "17": "5", "18": "3", "19": "3", "20": "4", "21": "2", "22": "3", "23": "3", "24": "5", "25": "3", "26": "4", "27": "4", "28": "3"}}
--------------------------------------------------------------------------------
/input/train/17.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6], [0, 7], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6], [1, 7], [2, 3], [2, 4], [2, 6], [2, 7], [3, 4], [3, 5], [3, 6], [3, 7], [4, 5], [5, 6], [5, 7]], "target": 1, "labels": {"0": "7", "1": "7", "2": "6", "3": "7", "4": "5", "5": "6", "6": "5", "7": "5"}}
--------------------------------------------------------------------------------
/input/train/18.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 4], [0, 23], [0, 10], [0, 13], [1, 2], [1, 3], [1, 21], [2, 22], [2, 10], [2, 13], [2, 30], [3, 18], [3, 27], [3, 5], [3, 6], [3, 11], [3, 14], [3, 21], [4, 21], [4, 22], [4, 8], [4, 18], [5, 23], [5, 8], [5, 26], [5, 29], [6, 32], [6, 11], [6, 13], [6, 30], [6, 31], [7, 16], [7, 17], [7, 20], [7, 10], [7, 27], [7, 14], [8, 10], [8, 26], [8, 31], [9, 22], [9, 23], [9, 31], [9, 27], [9, 14], [9, 15], [10, 16], [10, 19], [10, 11], [10, 30], [11, 20], [11, 27], [11, 13], [12, 22], [12, 26], [12, 28], [12, 14], [13, 22], [13, 17], [13, 15], [14, 23], [15, 32], [15, 17], [15, 19], [16, 33], [16, 21], [16, 25], [16, 23], [17, 23], [17, 24], [17, 27], [18, 23], [18, 24], [19, 21], [19, 25], [20, 23], [20, 29], [21, 22], [21, 23], [21, 24], [22, 33], [22, 23], [24, 25], [24, 27], [24, 29], [24, 30], [25, 32], [25, 27], [25, 31], [26, 32], [26, 31], [27, 33], [27, 28], [27, 31], [28, 29], [28, 30], [28, 31], [29, 33], [29, 31], [30, 32], [32, 33]], "target": 2, "labels": {"0": "5", "1": "3", "2": "6", "3": "8", "4": "5", "5": "5", "6": "6", "7": "6", "8": "5", "9": "6", "10": "8", "11": "6", "12": "4", "13": "7", "14": "5", "15": "5", "16": "6", "17": "6", "18": "4", "19": "4", "20": "4", "21": "8", "22": "8", "23": "10", "24": "7", "25": "6", "26": "5", "27": "10", "28": "5", "29": "6", "30": "6", "31": "8", "32": "6", "33": "5"}}
--------------------------------------------------------------------------------
/input/train/19.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 4], [0, 8], [1, 2], [1, 3], [1, 8], [1, 10], [2, 3], [2, 8], [3, 4], [3, 5], [3, 13], [4, 11], [5, 7], [5, 8], [5, 10], [5, 11], [6, 10], [6, 13], [7, 8], [7, 11], [7, 12], [9, 10], [9, 11], [10, 11], [11, 12], [12, 13]], "target": 1, "labels": {"0": "4", "1": "5", "2": "4", "3": "5", "4": "3", "5": "5", "6": "2", "7": "4", "8": "5", "9": "2", "10": "5", "11": "6", "12": "3", "13": "3"}}
--------------------------------------------------------------------------------
/input/train/2.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 16], [0, 1], [0, 11], [0, 15], [1, 19], [1, 11], [2, 3], [3, 4], [3, 10], [3, 21], [4, 19], [4, 8], [5, 14], [5, 7], [6, 7], [8, 9], [12, 20], [12, 14], [13, 21], [14, 18], [17, 18], [18, 19]], "target": 2, "labels": {"0": "4", "1": "3", "2": "1", "3": "4", "4": "3", "5": "2", "6": "1", "7": "2", "8": "2", "9": "1", "10": "1", "11": "2", "12": "2", "13": "1", "14": "3", "15": "1", "16": "1", "17": "1", "18": "3", "19": "3", "20": "1", "21": "2"}}
--------------------------------------------------------------------------------
/input/train/20.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 3], [0, 1], [0, 9], [0, 13], [1, 16], [1, 4], [1, 8], [1, 12], [2, 18], [2, 10], [2, 11], [2, 12], [3, 9], [3, 18], [3, 10], [4, 17], [4, 7], [4, 12], [5, 16], [5, 10], [6, 16], [6, 10], [7, 17], [7, 13], [7, 14], [8, 15], [8, 9], [9, 17], [9, 10], [9, 14], [9, 15], [11, 13], [11, 14], [12, 17], [12, 13], [12, 15], [15, 16], [15, 18]], "target": 2, "labels": {"0": "4", "1": "5", "2": "4", "3": "4", "4": "4", "5": "2", "6": "2", "7": "4", "8": "3", "9": "7", "10": "5", "11": "3", "12": "6", "13": "4", "14": "3", "15": "5", "16": "4", "17": "4", "18": "3"}}
--------------------------------------------------------------------------------
/input/train/21.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 2], [0, 5], [0, 6], [1, 9], [1, 11], [2, 7], [2, 12], [2, 13], [2, 14], [3, 16], [3, 9], [3, 10], [3, 13], [4, 19], [4, 12], [4, 5], [5, 18], [6, 8], [6, 12], [6, 17], [6, 7], [7, 9], [7, 12], [8, 16], [8, 18], [8, 9], [8, 10], [9, 18], [9, 10], [10, 19], [10, 15], [11, 19], [11, 13], [13, 19], [13, 18], [14, 16], [14, 15], [15, 16], [16, 17]], "target": 2, "labels": {"0": "4", "1": "3", "2": "5", "3": "4", "4": "3", "5": "3", "6": "5", "7": "4", "8": "5", "9": "6", "10": "5", "11": "3", "12": "4", "13": "5", "14": "3", "15": "3", "16": "5", "17": "2", "18": "4", "19": "4"}}
--------------------------------------------------------------------------------
/input/train/22.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 8], [0, 1], [0, 2], [1, 3], [1, 20], [2, 3], [2, 21], [2, 13], [3, 5], [3, 9], [3, 14], [4, 20], [4, 5], [4, 13], [5, 20], [5, 21], [5, 6], [5, 8], [5, 11], [5, 13], [6, 9], [7, 12], [7, 13], [8, 19], [8, 14], [9, 10], [9, 12], [9, 13], [9, 15], [10, 17], [10, 21], [11, 12], [12, 13], [13, 21], [14, 18], [15, 21], [16, 17], [16, 18], [16, 20], [17, 19], [17, 21], [18, 19], [18, 20], [19, 20]], "target": 1, "labels": {"0": "3", "1": "3", "2": "4", "3": "5", "4": "3", "5": "8", "6": "2", "7": "2", "8": "4", "9": "6", "10": "3", "11": "2", "12": "4", "13": "7", "14": "3", "15": "2", "16": "3", "17": "4", "18": "4", "19": "4", "20": "6", "21": "6"}}
--------------------------------------------------------------------------------
/input/train/23.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 24], [0, 1], [0, 2], [0, 23], [1, 24], [1, 2], [1, 3], [2, 3], [2, 4], [3, 4], [3, 5], [4, 5], [4, 6], [5, 13], [5, 6], [6, 8], [6, 7], [7, 8], [7, 9], [8, 9], [8, 10], [9, 15], [9, 11], [10, 11], [10, 12], [11, 12], [11, 13], [12, 13], [12, 14], [13, 23], [13, 14], [13, 15], [14, 16], [14, 15], [15, 16], [15, 17], [16, 17], [16, 18], [17, 18], [17, 19], [18, 19], [18, 20], [19, 20], [19, 21], [20, 21], [20, 22], [21, 22], [21, 23], [22, 24], [22, 23]], "target": 0, "labels": {"0": "4", "1": "4", "2": "4", "3": "4", "4": "4", "5": "4", "6": "4", "7": "3", "8": "4", "9": "4", "10": "3", "11": "4", "12": "4", "13": "6", "14": "4", "15": "5", "16": "4", "17": "4", "18": "4", "19": "4", "20": "4", "21": "4", "22": "4", "23": "4", "24": "3"}}
--------------------------------------------------------------------------------
/input/train/24.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 6], [1, 2], [2, 11], [2, 3], [3, 8], [3, 4], [4, 19], [4, 5], [4, 13], [5, 6], [6, 7], [7, 15], [9, 10], [9, 23], [10, 16], [10, 11], [11, 17], [12, 13], [13, 14], [15, 19], [18, 19], [20, 21], [21, 22], [22, 23], [23, 26], [24, 25], [25, 26]], "target": 1, "labels": {"0": "1", "1": "1", "2": "3", "3": "3", "4": "4", "5": "2", "6": "3", "7": "2", "8": "1", "9": "2", "10": "3", "11": "3", "12": "1", "13": "3", "14": "1", "15": "2", "16": "1", "17": "1", "18": "1", "19": "3", "20": "1", "21": "2", "22": "2", "23": "3", "24": "1", "25": "2", "26": "2"}}
--------------------------------------------------------------------------------
/input/train/25.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 33], [0, 20], [0, 31], [0, 30], [0, 29], [1, 32], [1, 16], [1, 34], [1, 10], [1, 13], [2, 20], [2, 23], [2, 25], [2, 29], [3, 5], [3, 15], [4, 19], [4, 5], [4, 8], [4, 11], [5, 18], [5, 21], [5, 15], [6, 8], [6, 7], [7, 8], [7, 9], [7, 23], [8, 33], [8, 22], [9, 18], [9, 12], [10, 21], [10, 11], [10, 29], [11, 27], [11, 15], [12, 33], [12, 14], [13, 22], [13, 15], [14, 16], [14, 19], [14, 15], [15, 25], [15, 31], [16, 18], [16, 31], [17, 31], [17, 26], [17, 19], [17, 29], [18, 32], [18, 20], [19, 30], [19, 29], [21, 32], [23, 24], [24, 25], [24, 26], [25, 26], [25, 28], [26, 33], [26, 28], [27, 28], [28, 34], [28, 31], [30, 32], [31, 32], [32, 34]], "target": 1, "labels": {"0": "5", "1": "5", "2": "4", "3": "2", "4": "4", "5": "5", "6": "2", "7": "4", "8": "5", "9": "3", "10": "4", "11": "4", "12": "3", "13": "3", "14": "4", "15": "7", "16": "4", "17": "4", "18": "5", "19": "5", "20": "3", "21": "3", "22": "2", "23": "3", "24": "3", "25": "5", "26": "5", "27": "2", "28": "5", "29": "5", "30": "3", "31": "6", "32": "6", "33": "4", "34": "3"}}
--------------------------------------------------------------------------------
/input/train/26.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 20], [0, 9], [0, 10], [1, 16], [1, 21], [1, 24], [1, 11], [1, 13], [2, 19], [2, 5], [2, 6], [2, 10], [2, 15], [3, 8], [3, 27], [4, 16], [4, 27], [4, 13], [5, 14], [6, 24], [6, 19], [6, 22], [6, 26], [7, 17], [7, 9], [7, 25], [8, 18], [8, 22], [8, 28], [8, 29], [9, 25], [9, 10], [9, 29], [9, 15], [10, 24], [10, 25], [10, 28], [11, 19], [11, 29], [12, 17], [12, 26], [12, 19], [12, 21], [13, 27], [13, 15], [14, 17], [14, 20], [15, 16], [15, 25], [16, 22], [17, 22], [17, 29], [18, 26], [19, 26], [20, 23], [20, 29], [22, 23], [24, 29], [25, 27], [28, 29]], "target": 2, "labels": {"0": "3", "1": "5", "2": "5", "3": "2", "4": "3", "5": "2", "6": "5", "7": "3", "8": "5", "9": "6", "10": "6", "11": "3", "12": "4", "13": "4", "14": "3", "15": "5", "16": "4", "17": "5", "18": "2", "19": "5", "20": "4", "21": "2", "22": "5", "23": "2", "24": "4", "25": "5", "26": "4", "27": "4", "28": "3", "29": "7"}}
--------------------------------------------------------------------------------
/input/train/27.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 32], [0, 1], [0, 2], [0, 3], [0, 33], [0, 31], [1, 33], [1, 2], [1, 3], [1, 4], [1, 32], [2, 4], [2, 5], [2, 33], [2, 13], [3, 4], [3, 5], [3, 6], [4, 5], [4, 6], [4, 7], [5, 6], [5, 7], [5, 8], [6, 7], [6, 8], [6, 9], [7, 8], [7, 9], [7, 10], [8, 9], [8, 10], [8, 11], [9, 10], [9, 11], [9, 12], [10, 25], [10, 11], [10, 12], [10, 13], [11, 12], [11, 13], [11, 14], [12, 13], [12, 14], [12, 15], [13, 16], [13, 14], [13, 15], [14, 16], [14, 17], [14, 15], [15, 16], [15, 17], [15, 18], [16, 17], [16, 18], [16, 19], [17, 18], [17, 19], [17, 20], [18, 19], [18, 20], [18, 21], [19, 20], [19, 21], [19, 22], [20, 21], [20, 22], [20, 23], [21, 22], [21, 23], [21, 24], [22, 23], [22, 24], [22, 25], [23, 24], [23, 25], [23, 26], [24, 25], [24, 26], [24, 27], [25, 27], [25, 28], [26, 27], [26, 28], [26, 29], [27, 28], [27, 29], [27, 30], [28, 29], [28, 30], [28, 31], [29, 32], [29, 30], [29, 31], [30, 32], [30, 33], [30, 31], [31, 32], [31, 33], [32, 33]], "target": 0, "labels": {"0": "6", "1": "6", "2": "6", "3": "5", "4": "6", "5": "6", "6": "6", "7": "6", "8": "6", "9": "6", "10": "7", "11": "6", "12": "6", "13": "7", "14": "6", "15": "6", "16": "6", "17": "6", "18": "6", "19": "6", "20": "6", "21": "6", "22": "6", "23": "6", "24": "6", "25": "6", "26": "5", "27": "6", "28": "6", "29": "6", "30": "6", "31": "6", "32": "6", "33": "6"}}
--------------------------------------------------------------------------------
/input/train/28.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [1, 16], [1, 13], [2, 23], [3, 18], [3, 12], [3, 23], [4, 18], [5, 22], [5, 7], [6, 17], [6, 11], [8, 22], [9, 10], [10, 16], [11, 16], [12, 18], [13, 16], [13, 27], [14, 26], [15, 16], [16, 26], [17, 25], [19, 20], [20, 21], [21, 24], [21, 25], [22, 23], [24, 28]], "target": 2, "labels": {"0": "1", "1": "3", "2": "1", "3": "3", "4": "1", "5": "2", "6": "2", "7": "1", "8": "1", "9": "1", "10": "2", "11": "2", "12": "2", "13": "3", "14": "1", "15": "1", "16": "6", "17": "2", "18": "3", "19": "1", "20": "2", "21": "3", "22": "3", "23": "3", "24": "2", "25": "2", "26": "2", "27": "1", "28": "1"}}
--------------------------------------------------------------------------------
/input/train/29.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 8], [0, 1], [0, 2], [0, 14], [0, 15], [1, 2], [1, 3], [1, 8], [1, 11], [1, 15], [2, 4], [2, 14], [3, 4], [3, 5], [3, 12], [4, 12], [4, 5], [5, 6], [5, 7], [6, 8], [6, 7], [7, 8], [7, 9], [7, 11], [9, 10], [9, 11], [10, 11], [10, 12], [12, 14], [13, 14], [13, 15], [14, 15]], "target": 0, "labels": {"0": "5", "1": "6", "2": "4", "3": "4", "4": "4", "5": "4", "6": "3", "7": "5", "8": "4", "9": "3", "10": "3", "11": "4", "12": "4", "13": "2", "14": "5", "15": "4"}}
--------------------------------------------------------------------------------
/input/train/3.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 12], [0, 13], [1, 2], [2, 21], [2, 7], [3, 19], [3, 11], [3, 5], [4, 23], [5, 22], [5, 6], [6, 7], [7, 8], [8, 23], [9, 10], [10, 18], [11, 26], [12, 26], [13, 24], [14, 15], [15, 16], [16, 17], [17, 18], [18, 19], [20, 21], [23, 24], [25, 26]], "target": 1, "labels": {"0": "2", "1": "1", "2": "3", "3": "3", "4": "1", "5": "3", "6": "2", "7": "3", "8": "2", "9": "1", "10": "2", "11": "2", "12": "2", "13": "2", "14": "1", "15": "2", "16": "2", "17": "2", "18": "3", "19": "2", "20": "1", "21": "2", "22": "1", "23": "3", "24": "2", "25": "1", "26": "3"}}
--------------------------------------------------------------------------------
/input/train/4.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 28], [0, 30], [0, 31], [1, 2], [1, 3], [1, 6], [1, 31], [2, 26], [2, 3], [2, 4], [3, 4], [3, 5], [4, 5], [4, 6], [5, 18], [5, 6], [5, 7], [6, 8], [7, 8], [7, 9], [8, 9], [8, 18], [9, 27], [9, 11], [10, 17], [10, 11], [11, 20], [11, 13], [12, 13], [12, 30], [13, 14], [13, 15], [14, 16], [14, 15], [15, 16], [15, 31], [16, 17], [16, 18], [17, 18], [17, 19], [18, 19], [18, 28], [19, 20], [19, 21], [20, 21], [20, 22], [21, 22], [21, 23], [22, 24], [22, 23], [23, 24], [23, 25], [24, 25], [24, 26], [25, 26], [25, 27], [26, 28], [27, 28], [27, 29], [28, 29], [29, 30], [29, 31], [30, 31]], "target": 0, "labels": {"0": "4", "1": "4", "2": "5", "3": "4", "4": "4", "5": "5", "6": "4", "7": "3", "8": "4", "9": "4", "10": "2", "11": "4", "12": "2", "13": "4", "14": "3", "15": "4", "16": "4", "17": "4", "18": "6", "19": "4", "20": "4", "21": "4", "22": "4", "23": "4", "24": "4", "25": "4", "26": "4", "27": "4", "28": "5", "29": "4", "30": "4", "31": "5"}}
--------------------------------------------------------------------------------
/input/train/5.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 1], [0, 18], [0, 20], [0, 21], [0, 14], [1, 9], [1, 2], [2, 3], [2, 4], [3, 4], [3, 5], [3, 15], [4, 11], [4, 5], [4, 6], [5, 6], [6, 8], [6, 7], [7, 8], [7, 9], [8, 9], [8, 10], [9, 10], [9, 12], [9, 15], [10, 12], [10, 13], [11, 13], [11, 15], [12, 18], [12, 14], [13, 16], [13, 19], [13, 14], [13, 15], [14, 19], [14, 20], [14, 17], [16, 18], [16, 21], [17, 18], [17, 19], [18, 19], [19, 21]], "target": 1, "labels": {"0": "5", "1": "3", "2": "3", "3": "4", "4": "5", "5": "3", "6": "4", "7": "3", "8": "4", "9": "6", "10": "4", "11": "3", "12": "4", "13": "6", "14": "6", "15": "4", "16": "3", "17": "3", "18": "5", "19": "5", "20": "2", "21": "3"}}
--------------------------------------------------------------------------------
/input/train/6.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 8], [0, 2], [0, 7], [0, 9], [1, 2], [1, 3], [2, 3], [2, 4], [3, 4], [3, 5], [4, 5], [4, 6], [5, 6], [5, 8], [5, 9], [6, 8], [6, 7], [7, 8], [7, 9], [8, 9]], "target": 0, "labels": {"0": "4", "1": "2", "2": "4", "3": "4", "4": "4", "5": "5", "6": "4", "7": "4", "8": "5", "9": "4"}}
--------------------------------------------------------------------------------
/input/train/7.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 2], [0, 3], [0, 5], [0, 6], [0, 11], [1, 2], [1, 3], [1, 8], [1, 12], [1, 13], [2, 3], [2, 4], [2, 9], [2, 12], [2, 13], [3, 4], [3, 7], [3, 12], [3, 13], [4, 5], [4, 6], [4, 7], [4, 12], [5, 8], [5, 11], [5, 12], [5, 13], [6, 9], [6, 12], [7, 8], [7, 9], [7, 10], [7, 13], [8, 9], [8, 10], [8, 11], [9, 10], [9, 11], [10, 11], [10, 12], [11, 13], [12, 13]], "target": 1, "labels": {"0": "5", "1": "5", "2": "7", "3": "7", "4": "6", "5": "6", "6": "4", "7": "6", "8": "6", "9": "6", "10": "5", "11": "6", "12": "8", "13": "7"}}
--------------------------------------------------------------------------------
/input/train/8.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 24], [0, 1], [0, 2], [1, 25], [1, 3], [1, 5], [2, 3], [2, 4], [3, 4], [3, 5], [4, 24], [4, 5], [4, 6], [5, 6], [5, 7], [5, 11], [6, 8], [6, 14], [6, 7], [7, 8], [7, 9], [8, 9], [8, 10], [9, 11], [9, 12], [10, 11], [10, 12], [11, 17], [11, 13], [12, 13], [12, 14], [13, 14], [13, 15], [14, 16], [14, 21], [15, 17], [15, 19], [16, 17], [16, 18], [16, 25], [17, 18], [17, 20], [18, 19], [18, 20], [19, 20], [19, 21], [20, 21], [21, 23], [22, 24], [22, 23], [23, 24], [23, 25]], "target": 0, "labels": {"0": "3", "1": "4", "2": "3", "3": "4", "4": "5", "5": "6", "6": "5", "7": "4", "8": "4", "9": "4", "10": "3", "11": "5", "12": "4", "13": "4", "14": "5", "15": "3", "16": "4", "17": "5", "18": "4", "19": "4", "20": "4", "21": "4", "22": "2", "23": "4", "24": "4", "25": "3"}}
--------------------------------------------------------------------------------
/input/train/9.json:
--------------------------------------------------------------------------------
1 | {"edges": [[0, 19], [0, 5], [0, 22], [0, 7], [1, 16], [2, 14], [2, 22], [3, 6], [4, 5], [5, 9], [5, 17], [6, 19], [6, 18], [8, 16], [8, 21], [8, 13], [10, 14], [11, 22], [12, 13], [15, 18], [19, 20], [20, 21], [21, 22]], "target": 2, "labels": {"0": "4", "1": "1", "2": "2", "3": "1", "4": "1", "5": "4", "6": "3", "7": "1", "8": "3", "9": "1", "10": "1", "11": "1", "12": "1", "13": "2", "14": "2", "15": "1", "16": "2", "17": "1", "18": "2", "19": "3", "20": "2", "21": "3", "22": "4"}}
--------------------------------------------------------------------------------
/output/watts_predictions.csv:
--------------------------------------------------------------------------------
1 | id,predictions
2 | 561,2
3 | 305,2
4 | 230,0
5 | 445,2
6 | 180,1
7 | 838,0
8 | 552,2
9 | 362,1
10 | 497,2
11 | 596,1
12 | 520,2
13 | 229,2
14 | 377,2
15 | 382,2
16 | 645,2
17 | 46,2
18 | 634,0
19 | 934,2
20 | 371,2
21 | 239,2
22 | 647,1
23 | 344,2
24 | 692,2
25 | 896,2
26 | 536,0
27 | 883,2
28 | 97,2
29 | 474,2
30 | 512,2
31 | 139,0
32 | 938,2
33 | 525,2
34 | 331,2
35 | 946,0
36 | 498,2
37 | 409,0
38 | 604,2
39 | 844,2
40 | 735,2
41 | 756,2
42 | 157,1
43 | 47,2
44 | 885,2
45 | 745,2
46 | 858,2
47 | 219,2
48 | 172,2
49 | 996,1
50 | 27,0
51 | 29,0
52 | 245,2
53 | 357,0
54 | 187,0
55 | 719,1
56 | 630,1
57 | 893,0
58 | 580,2
59 | 783,2
60 | 828,0
61 | 280,1
62 | 810,2
63 | 418,1
64 | 476,2
65 | 617,2
66 | 575,1
67 | 468,1
68 | 168,2
69 | 665,0
70 | 10,1
71 | 734,2
72 | 432,2
73 | 772,2
74 | 273,2
75 | 609,2
76 | 637,2
77 | 626,2
78 | 176,2
79 | 610,2
80 | 151,2
81 | 121,2
82 | 248,0
83 | 995,2
84 | 215,2
85 | 704,2
86 | 412,2
87 | 567,2
88 | 693,2
89 | 8,2
90 | 235,2
91 | 487,2
92 | 383,2
93 | 165,2
94 | 131,2
95 | 800,2
96 | 698,2
97 | 201,2
98 | 297,0
99 | 944,1
100 | 198,2
101 | 345,2
102 | 87,2
103 | 360,1
104 | 18,2
105 | 943,0
106 | 717,0
107 | 451,1
108 | 587,2
109 | 997,1
110 | 86,1
111 | 393,2
112 | 619,2
113 | 811,2
114 | 254,2
115 | 420,2
116 | 454,0
117 | 292,2
118 | 145,2
119 | 922,1
120 | 182,1
121 | 514,2
122 | 736,2
123 | 794,2
124 | 718,2
125 | 834,2
126 | 72,2
127 | 594,1
128 | 25,0
129 | 766,0
130 | 320,2
131 | 843,1
132 | 798,2
133 | 839,2
134 | 128,2
135 | 11,2
136 | 105,2
137 | 537,2
138 | 741,2
139 | 957,2
140 | 5,2
141 | 173,2
142 | 405,2
143 | 916,1
144 | 142,2
145 | 902,2
146 | 385,2
147 | 654,2
148 | 61,2
149 | 276,2
150 | 649,0
151 | 392,1
152 | 366,2
153 | 13,2
154 | 421,2
155 | 119,2
156 | 64,1
157 | 763,0
158 | 960,2
159 | 41,0
160 | 694,2
161 | 605,2
162 | 124,0
163 | 442,2
164 | 155,1
165 | 258,2
166 | 739,2
167 | 57,0
168 | 865,2
169 | 874,2
170 | 90,2
171 | 495,2
172 | 160,2
173 | 965,1
174 | 370,2
175 | 796,2
176 | 722,2
177 | 206,2
178 | 485,1
179 | 847,2
180 | 658,1
181 | 792,2
182 | 193,1
183 | 466,1
184 | 956,2
185 | 639,2
186 | 249,2
187 | 731,0
188 | 716,1
189 | 791,2
190 | 471,0
191 | 394,2
192 | 270,2
193 | 237,2
194 | 431,2
195 | 699,2
196 | 877,2
197 | 71,0
198 | 212,2
199 | 480,2
200 | 104,2
201 | 863,2
202 | 422,0
203 | 453,2
204 | 708,2
205 | 289,2
206 | 224,0
207 | 54,2
208 | 590,2
209 | 161,1
210 | 40,2
211 | 499,0
212 | 184,2
213 | 333,2
214 | 430,1
215 | 0,2
216 | 162,2
217 | 66,2
218 | 652,1
219 | 864,1
220 | 703,2
221 | 469,2
222 | 448,2
223 | 808,2
224 | 832,1
225 | 361,2
226 | 517,2
227 | 776,2
228 | 809,1
229 | 984,2
230 | 459,2
231 | 505,1
232 | 356,0
233 | 92,2
234 | 372,2
235 | 849,1
236 | 376,1
237 | 758,2
238 | 767,2
239 | 542,2
240 | 701,2
241 | 282,0
242 | 822,2
243 | 363,1
244 | 479,1
245 | 295,2
246 | 601,2
247 | 702,2
248 | 891,2
249 | 188,2
250 | 310,2
251 | 888,2
252 | 755,2
253 | 116,1
254 | 343,0
255 | 968,2
256 | 400,2
257 | 959,2
258 | 696,2
259 | 465,2
260 | 636,2
261 | 528,2
262 | 216,1
263 | 403,1
264 | 58,1
265 | 879,2
266 | 241,2
267 | 608,2
268 | 573,2
269 | 274,2
270 | 631,2
271 | 777,1
272 | 287,1
273 | 183,2
274 | 293,2
275 | 247,2
276 | 688,2
277 | 554,2
278 | 975,2
279 | 208,2
280 | 407,2
281 | 416,2
282 | 534,2
283 | 660,2
284 | 927,2
285 | 669,1
286 | 655,2
287 | 842,2
288 | 196,1
289 | 126,2
290 | 217,2
291 | 368,2
292 | 887,2
293 | 560,2
294 | 33,2
295 | 819,2
296 | 62,2
297 | 369,2
298 | 535,2
299 | 531,1
300 | 621,0
301 | 199,0
302 | 73,2
303 | 821,1
304 | 478,0
305 | 256,1
306 | 695,1
307 | 962,2
308 | 662,2
309 | 4,1
310 | 890,0
311 | 740,2
312 | 757,0
313 | 44,2
314 | 336,2
315 | 986,2
316 | 837,2
317 | 70,2
318 | 365,2
319 | 483,1
320 | 355,2
321 | 213,2
322 | 95,2
323 | 820,2
324 | 275,2
325 | 570,1
326 | 294,2
327 | 81,2
328 | 63,2
329 | 79,2
330 | 84,2
331 | 747,0
332 | 150,2
333 | 589,2
334 | 538,1
335 | 496,2
336 | 103,2
337 | 374,1
338 | 475,2
339 | 80,1
340 | 503,2
341 | 857,2
342 | 339,0
343 | 137,2
344 | 114,2
345 | 948,2
346 | 576,2
347 | 970,2
348 | 964,2
349 | 78,2
350 | 319,2
351 | 950,1
352 | 795,2
353 | 737,0
354 | 461,2
355 | 159,0
356 | 759,0
357 | 39,2
358 | 59,2
359 | 928,2
360 | 481,2
361 | 624,1
362 | 195,2
363 | 906,0
364 | 806,2
365 | 582,1
366 | 337,2
367 | 812,1
368 | 550,2
369 | 782,0
370 | 49,2
371 | 438,2
372 | 102,2
373 | 252,2
374 | 464,1
375 | 322,2
376 | 993,1
377 | 12,2
378 | 786,1
379 | 680,1
380 | 824,2
381 | 653,1
382 | 925,1
383 | 522,2
384 | 686,2
385 | 171,2
386 | 88,2
387 | 335,1
388 | 181,2
389 | 488,0
390 | 132,2
391 | 980,1
392 | 860,2
393 | 153,2
394 | 829,2
395 | 588,2
396 | 175,2
397 | 272,0
398 | 255,2
399 | 449,0
400 | 557,2
401 | 6,2
402 | 527,2
403 | 555,2
404 | 170,2
405 | 558,2
406 | 426,1
407 | 947,2
408 | 761,2
409 | 24,1
410 | 544,2
411 | 661,2
412 | 494,0
413 | 914,0
414 | 68,1
415 | 764,2
416 | 622,2
417 | 330,2
418 | 664,0
419 | 942,2
420 | 404,2
421 | 439,2
422 | 714,2
423 | 629,2
424 | 742,2
425 | 850,0
426 | 991,1
427 | 9,2
428 | 846,2
429 | 123,2
430 | 790,2
431 | 42,2
432 | 918,0
433 | 296,1
434 | 130,0
435 | 166,2
436 | 804,2
437 | 983,2
438 | 112,1
439 | 926,2
440 | 789,2
441 | 388,2
442 | 912,2
443 | 156,2
444 | 644,2
445 | 700,2
446 | 406,2
447 | 687,2
448 | 387,2
449 | 328,2
450 | 866,2
451 | 853,2
452 | 424,2
453 | 220,2
454 | 775,2
455 | 593,2
456 | 178,1
457 | 221,2
458 | 955,2
459 | 977,2
460 | 602,2
461 | 941,2
462 | 760,2
463 | 562,2
464 | 852,2
465 | 973,2
466 | 233,2
467 | 723,2
468 | 467,2
469 | 30,2
470 | 681,1
471 | 301,1
472 | 441,2
473 | 236,1
474 | 584,2
475 | 298,2
476 | 227,2
477 | 384,2
478 | 45,2
479 | 186,2
480 | 286,0
481 | 633,1
482 | 936,2
483 | 504,0
484 | 827,1
485 | 329,2
486 | 547,2
487 | 859,2
488 | 749,1
489 | 913,2
490 | 643,2
491 | 710,2
492 | 578,2
493 | 500,2
494 | 803,2
495 | 232,1
496 | 565,1
497 | 223,2
498 | 923,2
499 | 300,2
500 | 397,2
501 | 267,2
502 | 773,2
503 | 876,0
504 | 987,1
505 | 108,2
506 | 729,2
507 | 513,1
508 | 50,1
509 | 94,2
510 | 895,2
511 | 491,2
512 | 303,0
513 | 91,1
514 | 907,2
515 | 307,2
516 | 569,1
517 | 910,2
518 | 386,2
519 | 650,0
520 | 74,2
521 | 452,2
522 | 961,2
523 | 728,2
524 | 516,2
525 | 836,1
526 | 265,2
527 | 462,2
528 | 571,2
529 | 65,2
530 | 586,2
531 | 861,2
532 | 278,2
533 | 988,2
534 | 733,0
535 | 311,2
536 | 599,2
537 | 482,2
538 | 526,2
539 | 638,2
540 | 541,1
541 | 347,1
542 | 750,2
543 | 262,1
544 | 419,2
545 | 595,2
546 | 200,2
547 | 533,1
548 | 613,2
549 | 117,1
550 | 146,2
551 | 190,2
552 | 253,2
553 | 990,2
554 | 646,2
555 | 807,2
556 | 210,1
557 | 67,2
558 | 539,2
559 | 651,0
560 | 204,2
561 | 141,2
562 | 1,2
563 | 620,2
564 | 974,1
565 | 110,0
566 | 69,2
567 | 373,2
568 | 612,2
569 | 317,1
570 | 396,2
571 | 781,0
572 | 628,2
573 | 978,2
574 | 935,0
575 | 841,2
576 | 878,2
577 | 2,2
578 | 689,2
579 | 127,2
580 | 243,2
581 | 240,2
582 | 359,2
583 | 851,0
584 | 556,2
585 | 76,0
586 | 831,2
587 | 712,2
588 | 998,2
589 | 706,2
590 | 7,2
591 | 730,2
592 | 830,2
593 | 656,1
594 | 769,1
595 | 290,2
596 | 31,2
597 | 346,0
598 | 158,2
599 | 55,2
600 | 225,2
601 | 263,2
602 | 989,2
603 | 250,2
604 | 341,2
605 | 506,2
606 | 194,2
607 | 780,2
608 | 444,2
609 | 378,2
610 | 51,2
611 | 917,0
612 | 540,1
613 | 214,2
614 | 940,2
615 | 350,2
616 | 949,2
617 | 869,2
618 | 401,2
619 | 530,2
620 | 572,2
621 | 211,2
622 | 489,2
623 | 937,1
624 | 16,2
625 | 477,1
626 | 868,2
627 | 678,2
628 | 673,2
629 | 304,2
630 | 774,1
631 | 976,2
632 | 338,2
633 | 611,2
634 | 107,2
635 | 264,2
636 | 285,2
637 | 271,2
638 | 269,2
639 | 924,2
640 | 835,2
641 | 640,2
642 | 642,2
643 | 334,2
644 | 35,0
645 | 353,1
646 | 967,2
647 | 915,2
648 | 518,2
649 | 411,2
650 | 870,2
651 | 147,0
652 | 875,2
653 | 501,2
654 | 606,2
655 | 470,2
656 | 545,1
657 | 493,2
658 | 316,2
659 | 234,0
660 | 592,0
661 | 981,0
662 | 415,0
663 | 691,2
664 | 823,2
665 | 882,0
666 | 674,1
667 | 395,0
668 | 724,0
669 | 625,2
670 | 707,2
671 | 705,2
672 | 436,0
673 | 14,2
674 | 484,2
675 | 818,1
676 | 677,1
677 | 564,2
678 | 854,2
679 | 222,2
680 | 427,2
681 | 492,0
682 | 559,2
683 | 799,2
684 | 423,0
685 | 748,2
686 | 867,1
687 | 744,2
688 | 502,2
689 | 787,2
690 | 472,2
691 | 825,1
692 | 473,0
693 | 325,2
694 | 440,1
695 | 894,2
696 | 202,1
697 | 709,1
698 | 242,2
699 | 53,2
700 | 597,2
701 | 207,2
702 | 455,2
703 | 523,2
704 | 726,2
705 | 880,2
706 | 291,2
707 | 367,0
708 | 667,2
709 | 668,0
710 | 379,2
711 | 164,2
712 | 963,2
713 | 309,2
714 | 314,2
715 | 551,2
716 | 135,2
717 | 332,0
718 | 721,2
719 | 460,1
720 | 340,2
721 | 15,1
722 | 768,1
723 | 521,2
724 | 566,1
725 | 261,2
726 | 632,2
727 | 585,2
728 | 641,2
729 | 994,2
730 | 129,2
731 | 244,0
732 | 903,1
733 | 872,1
734 | 197,2
735 | 192,1
736 | 813,0
737 | 318,2
738 | 149,2
739 | 148,2
740 | 753,2
741 | 873,0
742 | 96,2
743 | 817,2
744 | 746,2
745 | 971,2
746 | 549,2
747 | 784,2
748 | 364,2
749 | 713,2
750 | 985,0
751 | 862,2
752 | 34,2
753 | 447,2
754 | 32,2
755 | 312,2
756 | 111,1
757 | 683,2
758 | 548,2
759 | 939,2
760 | 623,2
761 | 177,1
762 | 515,2
763 | 603,2
764 | 433,2
765 | 788,2
766 | 36,0
767 | 315,2
768 | 856,2
769 | 120,0
770 | 48,0
771 | 577,2
772 | 797,2
773 | 302,2
774 | 228,1
775 | 375,0
776 | 886,2
777 | 390,2
778 | 98,2
779 | 954,2
780 | 553,0
781 | 616,1
782 | 635,0
783 | 408,2
784 | 732,1
785 | 174,2
786 | 770,2
787 | 897,0
788 | 529,2
789 | 399,2
790 | 892,2
791 | 288,2
792 | 352,1
793 | 511,2
794 | 28,2
795 | 143,2
796 | 670,2
797 | 697,1
798 | 260,2
799 | 679,2
800 | 598,0
801 | 900,2
802 | 77,2
803 | 720,2
804 | 106,2
805 | 89,1
806 | 281,0
807 | 163,0
808 | 815,0
809 | 672,2
810 | 754,1
811 | 889,2
812 | 762,2
813 | 60,2
814 | 901,2
815 | 953,2
816 | 113,0
817 | 905,1
818 | 966,2
819 | 508,2
820 | 299,2
821 | 101,2
822 | 663,2
823 | 154,2
824 | 437,0
825 | 952,0
826 | 930,2
827 | 591,2
828 | 82,1
829 | 972,1
830 | 793,2
831 | 524,2
832 | 19,2
833 | 615,2
834 | 324,0
835 | 945,2
836 | 675,2
837 | 909,2
838 | 323,1
839 | 185,2
840 | 685,2
841 | 711,0
842 | 627,2
843 | 929,2
844 | 919,2
845 | 327,2
846 | 321,2
847 | 169,1
848 | 122,0
849 | 581,2
850 | 568,2
851 | 532,1
852 | 779,2
853 | 189,2
854 | 690,2
855 | 358,2
856 | 814,0
857 | 277,0
858 | 308,1
859 | 725,1
860 | 83,2
861 | 752,2
862 | 205,2
863 | 463,0
864 | 134,2
865 | 676,2
866 | 563,1
867 | 969,2
868 | 618,0
869 | 600,2
870 | 203,1
871 | 671,2
872 | 99,2
873 | 457,0
874 | 251,1
875 | 446,1
876 | 434,2
877 | 840,2
878 | 999,2
879 | 684,2
880 | 179,0
881 | 398,2
882 | 443,2
883 | 52,2
884 | 816,2
885 | 425,2
886 | 326,2
887 | 802,2
888 | 23,2
889 | 743,1
890 | 226,1
891 | 93,2
892 | 191,2
893 | 109,2
894 | 657,2
895 | 682,2
896 | 519,2
897 | 992,1
898 | 381,1
899 | 429,2
900 | 543,2
901 | 979,2
902 | 833,1
903 | 133,2
904 | 659,2
905 | 805,0
906 | 871,2
907 | 510,1
908 | 20,2
909 | 26,2
910 | 354,2
911 | 951,2
912 | 3,2
913 | 218,2
914 | 490,2
915 | 933,2
916 | 257,1
917 | 306,2
918 | 391,1
919 | 313,1
920 | 486,2
921 | 118,2
922 | 778,2
923 | 43,2
924 | 648,2
925 | 115,2
926 | 607,2
927 | 845,2
928 | 140,1
929 | 380,2
930 | 351,2
931 | 414,2
932 | 904,2
933 | 38,2
934 | 715,0
935 | 417,2
936 | 855,1
937 | 268,2
938 | 911,2
939 | 546,1
940 | 348,2
941 | 738,2
942 | 56,0
943 | 389,1
944 | 428,1
945 | 583,0
946 | 284,1
947 | 931,2
948 | 238,2
949 | 921,2
950 | 771,2
951 | 402,0
952 | 884,2
953 | 100,1
954 | 456,2
955 | 167,2
956 | 349,2
957 | 266,2
958 | 342,0
959 | 413,2
960 | 435,2
961 | 22,1
962 | 246,2
963 | 209,2
964 | 259,2
965 | 881,1
966 | 231,1
967 | 152,2
968 | 579,2
969 | 17,2
970 | 279,2
971 | 574,2
972 | 765,2
973 | 898,2
974 | 144,2
975 | 136,2
976 | 982,2
977 | 727,2
978 | 826,2
979 | 932,2
980 | 507,2
981 | 21,1
982 | 410,2
983 | 908,2
984 | 666,2
985 | 283,2
986 | 138,0
987 | 751,2
988 | 37,1
989 | 801,2
990 | 920,2
991 | 125,2
992 | 614,2
993 | 75,2
994 | 450,0
995 | 509,2
996 | 785,2
997 | 899,1
998 | 85,2
999 | 958,2
1000 | 848,2
1001 | 458,0
1002 |
--------------------------------------------------------------------------------
/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/benedekrozemberczki/CapsGNN/e665c3c78bcee01f9814c885fea27b5c32c0f467/paper.pdf
--------------------------------------------------------------------------------
/src/capsgnn.py:
--------------------------------------------------------------------------------
1 | """CapsGNN Trainer."""
2 |
3 | import glob
4 | import json
5 | import random
6 | import torch
7 | import numpy as np
8 | import pandas as pd
9 | from tqdm import tqdm, trange
10 | from torch_geometric.nn import GCNConv
11 | from utils import create_numeric_mapping
12 | from layers import ListModule, PrimaryCapsuleLayer, Attention, SecondaryCapsuleLayer
13 | from layers import margin_loss
14 |
15 | class CapsGNN(torch.nn.Module):
16 | """
17 | An implementation of the model described in the following paper:
18 | https://openreview.net/forum?id=Byl8BnRcYm
19 | """
20 | def __init__(self, args, number_of_features, number_of_targets):
21 | """
22 | :param args: Arguments object.
23 | :param number_of_features: Number of vertex features.
24 | :param number_of_targets: Number of classes.
25 | """
26 | super(CapsGNN, self).__init__()
27 | self.args = args
28 | self.number_of_features = number_of_features
29 | self.number_of_targets = number_of_targets
30 | self._setup_layers()
31 |
32 | def _setup_base_layers(self):
33 | """
34 | Creating GCN layers.
35 | """
36 | self.base_layers = [GCNConv(self.number_of_features, self.args.gcn_filters)]
37 | for _ in range(self.args.gcn_layers-1):
38 | self.base_layers.append(GCNConv(self.args.gcn_filters, self.args.gcn_filters))
39 | self.base_layers = ListModule(*self.base_layers)
40 |
41 | def _setup_primary_capsules(self):
42 | """
43 | Creating primary capsules.
44 | """
45 | self.first_capsule = PrimaryCapsuleLayer(in_units=self.args.gcn_filters,
46 | in_channels=self.args.gcn_layers,
47 | num_units=self.args.gcn_layers,
48 | capsule_dimensions=self.args.capsule_dimensions)
49 |
50 | def _setup_attention(self):
51 | """
52 | Creating attention layer.
53 | """
54 | self.attention = Attention(self.args.gcn_layers*self.args.capsule_dimensions,
55 | self.args.inner_attention_dimension)
56 |
57 | def _setup_graph_capsules(self):
58 | """
59 | Creating graph capsules.
60 | """
61 | self.graph_capsule = SecondaryCapsuleLayer(self.args.gcn_layers,
62 | self.args.capsule_dimensions,
63 | self.args.number_of_capsules,
64 | self.args.capsule_dimensions)
65 |
66 | def _setup_class_capsule(self):
67 | """
68 | Creating class capsules.
69 | """
70 | self.class_capsule = SecondaryCapsuleLayer(self.args.capsule_dimensions,
71 | self.args.number_of_capsules,
72 | self.number_of_targets,
73 | self.args.capsule_dimensions)
74 |
75 | def _setup_reconstruction_layers(self):
76 | """
77 | Creating histogram reconstruction layers.
78 | """
79 | self.reconstruction_layer_1 = torch.nn.Linear(self.number_of_targets*self.args.capsule_dimensions,
80 | int((self.number_of_features*2)/3))
81 |
82 | self.reconstruction_layer_2 = torch.nn.Linear(int((self.number_of_features*2)/3),
83 | int((self.number_of_features*3)/2))
84 |
85 | self.reconstruction_layer_3 = torch.nn.Linear(int((self.number_of_features*3)/2),
86 | self.number_of_features)
87 |
88 | def _setup_layers(self):
89 | """
90 | Creating layers of model.
91 | 1. GCN layers.
92 | 2. Primary capsules.
93 | 3. Attention
94 | 4. Graph capsules.
95 | 5. Class capsules.
96 | 6. Reconstruction layers.
97 | """
98 | self._setup_base_layers()
99 | self._setup_primary_capsules()
100 | self._setup_attention()
101 | self._setup_graph_capsules()
102 | self._setup_class_capsule()
103 | self._setup_reconstruction_layers()
104 |
105 | def calculate_reconstruction_loss(self, capsule_input, features):
106 | """
107 | Calculating the reconstruction loss of the model.
108 | :param capsule_input: Output of class capsule.
109 | :param features: Feature matrix.
110 | :return reconstruction_loss: Loss of reconstruction.
111 | """
112 |
113 | v_mag = torch.sqrt((capsule_input**2).sum(dim=1))
114 | _, v_max_index = v_mag.max(dim=0)
115 | v_max_index = v_max_index.data
116 |
117 | capsule_masked = torch.zeros_like(capsule_input)
118 | capsule_masked[v_max_index, :] = capsule_input[v_max_index, :]
119 | capsule_masked = capsule_masked.view(1, -1)
120 |
121 | feature_counts = features.sum(dim=0)
122 | feature_counts = feature_counts/feature_counts.sum()
123 |
124 | reconstruction_output = torch.nn.functional.relu(self.reconstruction_layer_1(capsule_masked))
125 | reconstruction_output = torch.nn.functional.relu(self.reconstruction_layer_2(reconstruction_output))
126 | reconstruction_output = torch.softmax(self.reconstruction_layer_3(reconstruction_output), dim=1)
127 | reconstruction_output = reconstruction_output.view(1, self.number_of_features)
128 | reconstruction_loss = torch.sum((feature_counts-reconstruction_output)**2)
129 | return reconstruction_loss
130 |
131 | def forward(self, data):
132 | """
133 | Forward propagation pass.
134 | :param data: Dictionary of tensors with features and edges.
135 | :return class_capsule_output: Class capsule outputs.
136 | """
137 | features = data["features"]
138 | edges = data["edges"]
139 | hidden_representations = []
140 |
141 | for layer in self.base_layers:
142 | features = torch.nn.functional.relu(layer(features, edges))
143 | hidden_representations.append(features)
144 |
145 | hidden_representations = torch.cat(tuple(hidden_representations))
146 | hidden_representations = hidden_representations.view(1, self.args.gcn_layers, self.args.gcn_filters, -1)
147 | first_capsule_output = self.first_capsule(hidden_representations)
148 | first_capsule_output = first_capsule_output.view(-1, self.args.gcn_layers*self.args.capsule_dimensions)
149 | rescaled_capsule_output = self.attention(first_capsule_output)
150 | rescaled_first_capsule_output = rescaled_capsule_output.view(-1, self.args.gcn_layers,
151 | self.args.capsule_dimensions)
152 | graph_capsule_output = self.graph_capsule(rescaled_first_capsule_output)
153 | reshaped_graph_capsule_output = graph_capsule_output.view(-1, self.args.capsule_dimensions,
154 | self.args.number_of_capsules)
155 | class_capsule_output = self.class_capsule(reshaped_graph_capsule_output)
156 | class_capsule_output = class_capsule_output.view(-1, self.number_of_targets*self.args.capsule_dimensions)
157 | class_capsule_output = torch.mean(class_capsule_output, dim=0).view(1,
158 | self.number_of_targets,
159 | self.args.capsule_dimensions)
160 | recon = class_capsule_output.view(self.number_of_targets, self.args.capsule_dimensions)
161 | reconstruction_loss = self.calculate_reconstruction_loss(recon, data["features"])
162 | return class_capsule_output, reconstruction_loss
163 |
164 |
165 | class CapsGNNTrainer(object):
166 | """
167 | CapsGNN training and scoring.
168 | """
169 | def __init__(self, args):
170 | """
171 | :param args: Arguments object.
172 | """
173 | self.args = args
174 | self.setup_model()
175 |
176 | def enumerate_unique_labels_and_targets(self):
177 | """
178 | Enumerating the features and targets in order to set up the weights later.
179 | """
180 | print("\nEnumerating feature and target values.\n")
181 | ending = "*.json"
182 |
183 | self.train_graph_paths = glob.glob(self.args.train_graph_folder+ending)
184 | self.test_graph_paths = glob.glob(self.args.test_graph_folder+ending)
185 | graph_paths = self.train_graph_paths + self.test_graph_paths
186 |
187 | targets = set()
188 | features = set()
189 | for path in tqdm(graph_paths):
190 | data = json.load(open(path))
191 | targets = targets.union(set([data["target"]]))
192 | features = features.union(set(data["labels"]))
193 |
194 | self.target_map = create_numeric_mapping(targets)
195 | self.feature_map = create_numeric_mapping(features)
196 |
197 | self.number_of_features = len(self.feature_map)
198 | self.number_of_targets = len(self.target_map)
199 |
200 | def setup_model(self):
201 | """
202 | Enumerating labels and initializing a CapsGNN.
203 | """
204 | self.enumerate_unique_labels_and_targets()
205 | self.model = CapsGNN(self.args, self.number_of_features, self.number_of_targets)
206 |
207 | def create_batches(self):
208 | """
209 | Batching the graphs for training.
210 | """
211 | self.batches = []
212 | for i in range(0, len(self.train_graph_paths), self.args.batch_size):
213 | self.batches.append(self.train_graph_paths[i:i+self.args.batch_size])
214 |
215 | def create_data_dictionary(self, target, edges, features):
216 | """
217 | Creating a data dictionary.
218 | :param target: Target vector.
219 | :param edges: Edge list tensor.
220 | :param features: Feature tensor.
221 | """
222 | to_pass_forward = dict()
223 | to_pass_forward["target"] = target
224 | to_pass_forward["edges"] = edges
225 | to_pass_forward["features"] = features
226 | return to_pass_forward
227 |
228 | def create_target(self, data):
229 | """
230 | Creating the one-hot target vector from the data dictionary.
231 | :param data: Data dictionary.
232 | :return : Target vector.
233 | """
234 | return torch.FloatTensor([0.0 if i != data["target"] else 1.0 for i in range(self.number_of_targets)])
235 |
236 | def create_edges(self, data):
237 | """
238 | Create an edge matrix.
239 | :param data: Data dictionary.
240 | :return : Edge matrix.
241 | """
242 | edges = [[edge[0], edge[1]] for edge in data["edges"]]
243 | edges = edges + [[edge[1], edge[0]] for edge in data["edges"]]
244 | return torch.t(torch.LongTensor(edges))
245 |
246 | def create_features(self, data):
247 | """
248 | Create feature matrix.
249 | :param data: Data dictionary.
250 | :return features: Matrix of features.
251 | """
252 | features = np.zeros((len(data["labels"]), self.number_of_features))
253 | node_indices = [node for node in range(len(data["labels"]))]
254 | feature_indices = [self.feature_map[label] for label in data["labels"].values()]
255 | features[node_indices, feature_indices] = 1.0
256 | features = torch.FloatTensor(features)
257 | return features
258 |
259 | def create_input_data(self, path):
260 | """
261 | Creating a data dictionary of Torch tensors from a graph JSON file.
262 | :param path: path to the data JSON.
263 | :return to_pass_forward: Data dictionary.
264 | """
265 | data = json.load(open(path))
266 | target = self.create_target(data)
267 | edges = self.create_edges(data)
268 | features = self.create_features(data)
269 | to_pass_forward = self.create_data_dictionary(target, edges, features)
270 | return to_pass_forward
271 |
272 | def fit(self):
273 | """
274 | Training a model on the training set.
275 | """
276 | print("\nTraining started.\n")
277 | self.model.train()
278 | optimizer = torch.optim.Adam(self.model.parameters(),
279 | lr=self.args.learning_rate,
280 | weight_decay=self.args.weight_decay)
281 |
282 | for _ in tqdm(range(self.args.epochs), desc="Epochs: ", leave=True):
283 | random.shuffle(self.train_graph_paths)
284 | self.create_batches()
285 | losses = 0
286 | self.steps = trange(len(self.batches), desc="Loss")
287 | for step in self.steps:
288 | accumulated_losses = 0
289 | optimizer.zero_grad()
290 | batch = self.batches[step]
291 | for path in batch:
292 | data = self.create_input_data(path)
293 | prediction, reconstruction_loss = self.model(data)
294 | loss = margin_loss(prediction,
295 | data["target"],
296 | self.args.lambd)
297 | loss = loss+self.args.theta*reconstruction_loss
298 | accumulated_losses = accumulated_losses + loss
299 | accumulated_losses = accumulated_losses/len(batch)
300 | accumulated_losses.backward()
301 | optimizer.step()
302 | losses = losses + accumulated_losses.item()
303 | average_loss = losses/(step + 1)
304 | self.steps.set_description("CapsGNN (Loss=%g)" % round(average_loss, 4))
305 |
306 | def score(self):
307 | """
308 | Scoring on the test set.
309 | """
310 | print("\n\nScoring.\n")
311 | self.model.eval()
312 | self.predictions = []
313 | self.hits = []
314 | for path in tqdm(self.test_graph_paths):
315 | data = self.create_input_data(path)
316 | prediction, _ = self.model(data)
317 | prediction_mag = torch.sqrt((prediction**2).sum(dim=2))
318 | _, prediction_max_index = prediction_mag.max(dim=1)
319 | prediction = prediction_max_index.data.view(-1).item()
320 | self.predictions.append(prediction)
321 | self.hits.append(data["target"][prediction].item() == 1.0)
322 |
323 | print("\nAccuracy: " + str(round(np.mean(self.hits), 4)))
324 |
325 | def save_predictions(self):
326 | """
327 | Saving the test set predictions.
328 | """
329 | identifiers = [path.split("/")[-1][:-len(".json")] for path in self.test_graph_paths]
330 | out = pd.DataFrame()
331 | out["id"] = identifiers
332 | out["predictions"] = self.predictions
333 | out.to_csv(self.args.prediction_path, index=None)
334 |
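Before the layers, it helps to see how one graph is encoded. The sketch below mirrors, in plain Python and without PyTorch, what `create_target`, `create_edges` and `create_features` do. It also covers the normalized label histogram (`feature_counts` in `calculate_reconstruction_loss`) and the choice of the predicted class by capsule length. It assumes a graph JSON shaped as the trainer reads it (`target`, `edges`, `labels`). All helper names and the sample graph are hypothetical illustrations, not part of the repository.

```python
"""Plain-Python sketch of how CapsGNNTrainer encodes one graph (hypothetical helpers)."""
import math


def one_hot_target(target, number_of_targets):
    # Mirrors create_target: 1.0 at the target index, 0.0 elsewhere.
    return [1.0 if i == target else 0.0 for i in range(number_of_targets)]


def symmetric_edge_index(edges):
    # Mirrors create_edges: each edge is added in both directions, then the
    # pair list is transposed into a 2 x (2 * |E|) source/target index.
    pairs = [[u, v] for u, v in edges] + [[v, u] for u, v in edges]
    return [[u for u, _ in pairs], [v for _, v in pairs]]


def one_hot_features(labels, feature_map):
    # Mirrors create_features: node i gets 1.0 in the column of its label.
    # Like the original, assumes labels.values() is ordered by node index.
    rows = [[0.0] * len(feature_map) for _ in labels]
    for node, label in enumerate(labels.values()):
        rows[node][feature_map[label]] = 1.0
    return rows


def label_histogram(features):
    # Column sums of the one-hot matrix, normalized to sum to 1
    # (the feature_counts vector).
    counts = [sum(column) for column in zip(*features)]
    total = sum(counts)
    return [count / total for count in counts]


def predict_and_mask(class_capsules):
    # The predicted class is the capsule with the largest Euclidean length;
    # every other capsule is zeroed and the result is flattened, as in the
    # masking step before the reconstruction layers.
    lengths = [math.sqrt(sum(x * x for x in capsule)) for capsule in class_capsules]
    predicted = max(range(len(lengths)), key=lengths.__getitem__)
    masked = [list(c) if k == predicted else [0.0] * len(c)
              for k, c in enumerate(class_capsules)]
    return predicted, [x for capsule in masked for x in capsule]


graph = {"target": 1, "edges": [[0, 1], [1, 2]], "labels": {"0": "a", "1": "b", "2": "a"}}
feature_map = {"a": 0, "b": 1}

target = one_hot_target(graph["target"], number_of_targets=3)
edge_index = symmetric_edge_index(graph["edges"])
features = one_hot_features(graph["labels"], feature_map)
histogram = label_histogram(features)
predicted, masked_input = predict_and_mask([[0.1, 0.2], [0.6, 0.0], [0.0, 0.3]])
```

Because the histogram is a probability vector, a softmax output layer (as in `reconstruction_layer_3`) is a natural match for it.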
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | """CapsGNN layers."""
2 |
3 | import torch
4 | from torch.autograd import Variable
5 |
6 | class ListModule(torch.nn.Module):
7 | """
8 | Container that registers a list of layers as submodules (similar to torch.nn.ModuleList).
9 | """
10 | def __init__(self, *args):
11 | """
12 | Model initializing.
13 | """
14 | super(ListModule, self).__init__()
15 | idx = 0
16 | for module in args:
17 | self.add_module(str(idx), module)
18 | idx += 1
19 |
20 | def __getitem__(self, idx):
21 | """
22 | Getting the indexed layer.
23 | """
24 | if idx < 0 or idx >= len(self._modules):
25 | raise IndexError('index {} is out of range'.format(idx))
26 | it = iter(self._modules.values())
27 | for _ in range(idx):
28 | next(it)
29 | return next(it)
30 |
31 | def __iter__(self):
32 | """
33 | Iterating on the layers.
34 | """
35 | return iter(self._modules.values())
36 |
37 | def __len__(self):
38 | """
39 | Number of layers.
40 | """
41 | return len(self._modules)
42 |
43 | class PrimaryCapsuleLayer(torch.nn.Module):
44 | """
45 | Primary Convolutional Capsule Layer class based on:
46 | https://github.com/timomernick/pytorch-capsule.
47 | """
48 | def __init__(self, in_units, in_channels, num_units, capsule_dimensions):
49 | """
50 | :param in_units: Number of input units (GCN filters).
51 | :param in_channels: Number of input channels (GCN layers).
52 | :param num_units: Number of capsules.
53 | :param capsule_dimensions: Number of neurons in a capsule.
54 | """
55 | super(PrimaryCapsuleLayer, self).__init__()
56 | self.num_units = num_units
57 | self.units = []
58 | for i in range(self.num_units):
59 | unit = torch.nn.Conv2d(in_channels=in_channels,
60 | out_channels=capsule_dimensions,
61 | kernel_size=(in_units, 1),
62 | stride=1,
63 | bias=True)
64 |
65 | self.add_module("unit_" + str(i), unit)
66 | self.units.append(unit)
67 |
68 | @staticmethod
69 | def squash(s):
70 | """
71 | Squash activations.
72 | :param s: Signal.
73 | :return s: Activated signal.
74 | """
75 | mag_sq = torch.sum(s**2, dim=2, keepdim=True)
76 | mag = torch.sqrt(mag_sq)
77 | s = (mag_sq / (1.0 + mag_sq)) * (s / (mag + 1e-8))
78 | return s
79 |
80 | def forward(self, x):
81 | """
82 | Forward propagation pass.
83 | :param x: Input features.
84 | :return : Primary capsule features.
85 | """
86 | u = [self.units[i](x) for i in range(self.num_units)]
87 | u = torch.stack(u, dim=1)
88 | u = u.view(x.size(0), self.num_units, -1)
89 | return PrimaryCapsuleLayer.squash(u)
90 |
91 | class SecondaryCapsuleLayer(torch.nn.Module):
92 | """
93 | Secondary Convolutional Capsule Layer class based on this repository:
94 | https://github.com/timomernick/pytorch-capsule
95 | """
96 | def __init__(self, in_units, in_channels, num_units, unit_size):
97 | """
98 | :param in_units: Number of input units (GCN layers for the graph capsules, capsule dimensions for the class capsules).
99 | :param in_channels: Number of input channels.
100 | :param num_units: Number of output capsules.
101 | :param unit_size: Number of neurons in an output capsule.
102 | """
103 | super(SecondaryCapsuleLayer, self).__init__()
104 | self.in_units = in_units
105 | self.in_channels = in_channels
106 | self.num_units = num_units
107 | self.W = torch.nn.Parameter(torch.randn(1, in_channels, num_units, unit_size, in_units))
108 |
109 | @staticmethod
110 | def squash(s):
111 | """
112 | Squash activations.
113 | :param s: Signal.
114 | :return s: Activated signal.
115 | """
116 | mag_sq = torch.sum(s**2, dim=3, keepdim=True)
117 | mag = torch.sqrt(mag_sq)
118 | s = (mag_sq / (1.0 + mag_sq)) * (s / (mag + 1e-8))
119 | return s
120 |
121 | def forward(self, x):
122 | """
123 | Forward propagation pass.
124 | :param x: Input features.
125 | :return : Capsule output.
126 | """
127 | batch_size = x.size(0)
128 | x = x.transpose(1, 2)
129 | x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)
130 | W = torch.cat([self.W] * batch_size, dim=0)
131 | u_hat = torch.matmul(W, x)
132 | b_ij = Variable(torch.zeros(1, self.in_channels, self.num_units, 1))
133 |
134 | num_iterations = 3
135 |
136 | for _ in range(num_iterations):
137 | c_ij = torch.nn.functional.softmax(b_ij, dim=2)
138 | c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4)
139 | s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
140 | v_j = SecondaryCapsuleLayer.squash(s_j)
141 | v_j1 = torch.cat([v_j] * self.in_channels, dim=1)
142 | u_vj1 = torch.matmul(u_hat.transpose(3, 4), v_j1).squeeze(4).mean(dim=0, keepdim=True)
143 | b_ij = b_ij + u_vj1
144 | # b_max = torch.max(b_ij, dim = 2, keepdim = True)
145 | # b_ij = b_ij / b_max.values ## values can be zero so loss would be nan
146 | return v_j.squeeze(1)
147 |
148 | class Attention(torch.nn.Module):
149 | """
150 | 2 Layer Attention Module.
151 | See the CapsGNN paper for details.
152 | """
153 | def __init__(self, attention_size_1, attention_size_2):
154 | """
155 | :param attention_size_1: Input (and output) size of the attention module.
156 | :param attention_size_2: Number of hidden neurons between the two attention layers.
157 | """
158 | super(Attention, self).__init__()
159 | self.attention_1 = torch.nn.Linear(attention_size_1, attention_size_2)
160 | self.attention_2 = torch.nn.Linear(attention_size_2, attention_size_1)
161 |
162 | def forward(self, x_in):
163 | """
164 | Forward propagation pass.
165 | :param x_in: Primary capsule output.
166 | :return condensed_x: Attention-rescaled capsule output.
167 | """
168 | attention_score_base = self.attention_1(x_in)
169 | attention_score_base = torch.nn.functional.relu(attention_score_base)
170 | attention_score = self.attention_2(attention_score_base)
171 | attention_score = torch.nn.functional.softmax(attention_score, dim=0)
172 | condensed_x = x_in *attention_score
173 | return condensed_x
174 |
175 | def margin_loss(scores, target, loss_lambda):
176 | """
177 | The margin loss from the original paper. Based on:
178 | https://github.com/timomernick/pytorch-capsule
179 | :param scores: Capsule scores.
180 | :param target: Target groundtruth.
181 | :param loss_lambda: Down-weighting factor for the absent-class loss term.
182 | :return L_c: Classification loss.
183 | """
184 | scores = scores.squeeze()
185 | v_mag = torch.sqrt((scores**2).sum(dim=1, keepdim=True))
186 | zero = Variable(torch.zeros(1))
187 | m_plus = 0.9
188 | m_minus = 0.1
189 | max_l = torch.max(m_plus - v_mag, zero).view(1, -1)**2
190 | max_r = torch.max(v_mag - m_minus, zero).view(1, -1)**2
191 | # T_c is the one-hot indicator of the true class, i.e. the target vector itself.
192 | T_c = target
193 | L_c = T_c * max_l + loss_lambda * (1.0 - T_c) * max_r
194 | L_c = L_c.sum(dim=1)
195 | L_c = L_c.mean()
196 | return L_c
197 |
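The `squash` non-linearity and the routing-by-agreement loop in `SecondaryCapsuleLayer.forward` are easier to follow on plain lists. The sketch below is a simplified, single-graph version that assumes the standard CapsNet formulation. It squashes each output capsule over its own vector components, runs three iterations, and takes the prediction vectors `u_hat[i][j]` as given, with no learned `W`, no batching and no averaging of the logit update over the batch. The function names are hypothetical, and this is not a drop-in replacement for the layer.

```python
import math


def squash(vector, eps=1e-8):
    # v = (|s|^2 / (1 + |s|^2)) * s / |s|; eps avoids 0/0 for a zero vector.
    norm_sq = sum(x * x for x in vector)
    scale = norm_sq / (1.0 + norm_sq) / (math.sqrt(norm_sq) + eps)
    return [scale * x for x in vector]


def softmax(values):
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat, num_iterations=3):
    """u_hat[i][j] is input capsule i's prediction vector for output capsule j."""
    num_inputs, num_outputs = len(u_hat), len(u_hat[0])
    dim = len(u_hat[0][0])
    logits = [[0.0] * num_outputs for _ in range(num_inputs)]  # b_ij
    outputs = []
    for _ in range(num_iterations):
        # c_ij: each input capsule splits its vote over the output capsules.
        couplings = [softmax(row) for row in logits]
        outputs = []
        for j in range(num_outputs):
            s_j = [sum(couplings[i][j] * u_hat[i][j][d] for i in range(num_inputs))
                   for d in range(dim)]
            outputs.append(squash(s_j))
        # Agreement update: b_ij += <u_hat_ij, v_j>.
        for i in range(num_inputs):
            for j in range(num_outputs):
                logits[i][j] += sum(a * b for a, b in zip(u_hat[i][j], outputs[j]))
    return outputs


v = squash([3.0, 4.0])
# Both inputs agree on output capsule 0 and disagree (weakly) on capsule 1.
routed = dynamic_routing([[[1.0, 0.0], [0.0, 0.1]],
                          [[1.0, 0.0], [0.1, 0.0]]])
```

The squashed length stays below 1 (here 25/26 for a vector of length 5). Output capsules that receive agreeing predictions end up longer, which is what the margin loss and the prediction step read off.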
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | """Running CapsGNN."""
2 |
3 | from utils import tab_printer
4 | from capsgnn import CapsGNNTrainer
5 | from param_parser import parameter_parser
6 |
7 | def main():
8 | """
9 | Parsing command line parameters, processing graphs, fitting a CapsGNN.
10 | """
11 | args = parameter_parser()
12 | tab_printer(args)
13 | model = CapsGNNTrainer(args)
14 | model.fit()
15 | model.score()
16 | model.save_predictions()
17 |
18 | if __name__ == "__main__":
19 | main()
20 |
--------------------------------------------------------------------------------
/src/param_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def parameter_parser():
4 | """
5 | A method to parse the command line parameters. By default it learns on the Watts-Strogatz dataset.
6 | The default hyperparameters give good results without cross-validation.
7 | """
8 | parser = argparse.ArgumentParser(description="Run CapsGNN.")
9 |
10 | parser.add_argument("--train-graph-folder",
11 | nargs="?",
12 | default="./input/train/",
13 | help="Training graphs folder.")
14 |
15 | parser.add_argument("--test-graph-folder",
16 | nargs="?",
17 | default="./input/test/",
18 | help="Testing graphs folder.")
19 |
20 | parser.add_argument("--prediction-path",
21 | nargs="?",
22 | default="./output/watts_predictions.csv",
23 | help="Path to store the predicted graph labels.")
24 |
25 | parser.add_argument("--epochs",
26 | type=int,
27 | default=100,
28 | help="Number of training epochs. Default is 100.")
29 |
30 | parser.add_argument("--batch-size",
31 | type=int,
32 | default=32,
33 | help="Number of graphs processed per batch. Default is 32.")
34 |
35 | parser.add_argument("--gcn-filters",
36 | type=int,
37 | default=20,
38 | help="Number of Graph Convolutional filters. Default is 20.")
39 |
40 | parser.add_argument("--gcn-layers",
41 | type=int,
42 | default=2,
43 | help="Number of Graph Convolutional Layers. Default is 2.")
44 |
45 | parser.add_argument("--inner-attention-dimension",
46 | type=int,
47 | default=20,
48 | help="Number of hidden neurons in the attention module. Default is 20.")
49 |
50 | parser.add_argument("--capsule-dimensions",
51 | type=int,
52 | default=8,
53 | help="Capsule dimensions. Default is 8.")
54 |
55 | parser.add_argument("--number-of-capsules",
56 | type=int,
57 | default=8,
58 | help="Number of graph capsules. Default is 8.")
59 |
60 | parser.add_argument("--weight-decay",
61 | type=float,
62 | default=10**-6,
63 | help="Weight decay. Default is 10^-6.")
64 |
65 | parser.add_argument("--learning-rate",
66 | type=float,
67 | default=0.01,
68 | help="Learning rate. Default is 0.01.")
69 |
70 | parser.add_argument("--lambd",
71 | type=float,
72 | default=0.5,
73 | help="Down-weighting of the absent-class margin loss term. Default is 0.5.")
74 |
75 | parser.add_argument("--theta",
76 | type=float,
77 | default=0.1,
78 | help="Reconstruction loss weight. Default is 0.1.")
79 |
80 | return parser.parse_args()
81 |
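`--lambd` and `--theta` enter the per-graph objective in two places. `lambd` down-weights the absent-class term of `margin_loss`, whose margins are fixed at m+ = 0.9 and m- = 0.1. `theta` scales the reconstruction loss before it is added in `fit`. The plain-Python sketch below shows that objective with the default values. It assumes `capsule_lengths` stands in for the class-capsule norms and takes `reconstruction_loss` as a given number rather than the model's output. The helper names are hypothetical.

```python
def margin_loss(capsule_lengths, target, lambd=0.5, m_plus=0.9, m_minus=0.1):
    # Present class: penalize capsule lengths below m_plus.
    # Absent classes: penalize lengths above m_minus, down-weighted by lambd.
    loss = 0.0
    for length, t in zip(capsule_lengths, target):
        present = max(0.0, m_plus - length) ** 2
        absent = max(0.0, length - m_minus) ** 2
        loss += t * present + lambd * (1.0 - t) * absent
    return loss


def graph_loss(capsule_lengths, target, reconstruction_loss, lambd=0.5, theta=0.1):
    # Per-graph objective that CapsGNNTrainer.fit accumulates over a batch.
    return margin_loss(capsule_lengths, target, lambd) + theta * reconstruction_loss


# Margin part: 0.1**2 (true class) + 0.5 * 0.2**2 (class 1) = 0.03; plus 0.1 * 0.2.
loss = graph_loss([0.8, 0.3, 0.05], [1.0, 0.0, 0.0], reconstruction_loss=0.2)
```

A confident, correct prediction (true-class length above 0.9, all others below 0.1) gives a margin loss of exactly zero. Only the reconstruction term then contributes.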
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | """Data reading and printing utils."""
2 |
3 | from texttable import Texttable
4 |
5 | def tab_printer(args):
6 | """
7 | Function to print the logs in a nice tabular format.
8 | :param args: Parameters used for the model.
9 | """
10 | args = vars(args)
11 | keys = sorted(args.keys())
12 | t = Texttable()
13 | t.add_rows([["Parameter", "Value"]])
14 | t.add_rows([[k.replace("_", " ").capitalize(), args[k]] for k in keys])
15 | print(t.draw())
16 |
17 | def create_numeric_mapping(node_properties):
18 | """
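19 | Create a mapping from values (node labels or graph targets) to consecutive integer indices.
20 | :param node_properties: Collection of values; sorted first so the mapping is reproducible across runs.
21 | :return : Feature numeric map.
22 | """
23 | return {value: i for i, value in enumerate(sorted(node_properties))}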
24 |
--------------------------------------------------------------------------------