├── .gitignore ├── LICENSE ├── README.md ├── configs └── LayerDAG │ └── tpu_tile.yaml ├── data_files └── tpu_tile_processed │ ├── LICENSE │ ├── README.md │ ├── test.pth │ ├── train.pth │ └── val.pth ├── sample.py ├── setup_utils.py ├── src ├── dataset │ ├── __init__.py │ ├── general.py │ ├── layer_dag.py │ └── tpu_tile.py ├── eval │ ├── __init__.py │ ├── discriminator │ │ ├── __init__.py │ │ ├── base.py │ │ ├── data_utils.py │ │ └── mpnn.py │ └── tpu_tile.py └── model │ ├── __init__.py │ ├── diffusion.py │ └── layer_dag.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | wandb 3 | *.pth 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LayerDAG 2 | 3 | [[Paper]](https://arxiv.org/abs/2411.02322) 4 | 5 | ## Table of Contents 6 | 7 | - [Installation](#installation) 8 | - [Train](#train) 9 | - [Sample](#sample) 10 | - [Eval](#eval) 11 | - [Frequently Asked Questions](#frequently-asked-questions) 12 | * [Q1: libcusparse.so](#q1-libcusparseso) 13 | - [Citation](#citation) 14 | 15 | ## Installation 16 | 17 | ```bash 18 | conda create -n LayerDAG python=3.10 -y 19 | conda activate LayerDAG 20 | pip install torch==1.12.0+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 21 | conda install -c conda-forge cudatoolkit=11.6 22 | conda clean --all -y 23 | pip install dgl==1.1.0+cu116 -f https://data.dgl.ai/wheels/cu116/repo.html 24 | pip install tqdm einops wandb pydantic pandas 25 | pip install numpy==1.26.3 26 | ``` 27 | 28 | ## Train 29 | 30 | To train a LayerDAG model, run 31 | 32 | ```bash 33 | python train.py --config_file configs/LayerDAG/tpu_tile.yaml 34 | ``` 35 | 36 | The trained model checkpoint will be saved to a file `model_tpu_tile_{time_stamp}.pth`. 37 | 38 | ## Sample 39 | 40 | For sampling and evaluation, run 41 | 42 | ```bash 43 | python sample.py --model_path X 44 | ``` 45 | 46 | where `X` is the path to the `model_tpu_tile_{time_stamp}.pth` checkpoint saved above. 47 | 48 | ## Frequently Asked Questions 49 | 50 | ### Q1: libcusparse.so 51 | 52 | **The program fails with an error that it cannot find `libcusparse.so`**, e.g., `OSError: libcusparse.so.11: cannot open shared object file: No such file or directory`. 53 | 54 | To search for its location on Linux, run 55 | 56 | ```bash 57 | find /path/to/directory -name libcusparse.so.11 -exec realpath {} \; 58 | ``` 59 | 60 | where `/path/to/directory` is the directory you want to search. Suppose the search returns `/home/miniconda3/envs/LayerDAG/lib/libcusparse.so.11`. You then need to manually set the environment variable as follows.
61 | 62 | ```bash 63 | export LD_LIBRARY_PATH=/home/miniconda3/envs/LayerDAG/lib:$LD_LIBRARY_PATH 64 | ``` 65 | 66 | ## Citation 67 | 68 | ``` 69 | @inproceedings{li2024layerdag, 70 | title={Layer{DAG}: A Layerwise Autoregressive Diffusion Model for Directed Acyclic Graph Generation}, 71 | author={Mufei Li and Viraj Shitole and Eli Chien and Changhai Man and Zhaodong Wang and Srinivas Sridharan and Ying Zhang and Tushar Krishna and Pan Li}, 72 | booktitle={International Conference on Learning Representations}, 73 | year={2025} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /configs/LayerDAG/tpu_tile.yaml: -------------------------------------------------------------------------------- 1 | general : 2 | dataset: tpu_tile 3 | conditional: true 4 | patience: 10 5 | 6 | node_count : 7 | loader : 8 | batch_size: 128 9 | num_workers: 4 10 | model : 11 | x_n_emb_size: 64 12 | y_emb_size: 256 13 | num_mpnn_layers: 3 14 | pool: 'sum' 15 | num_epochs : 500 16 | optimizer : 17 | lr: 0.0003 18 | amsgrad: true 19 | 20 | node_pred : 21 | T : 64 22 | loader : 23 | batch_size: 256 24 | num_workers: 4 25 | num_epochs : 700 26 | graph_encoder : 27 | x_n_emb_size: 512 28 | y_emb_size: 512 29 | num_mpnn_layers: 2 30 | pool: 'sum' 31 | predictor : 32 | t_emb_size: 256 33 | out_hidden_size: 512 34 | num_transformer_layers: 1 35 | num_heads: 4 36 | dropout: 0 37 | optimizer : 38 | lr: 0.0003 39 | amsgrad: true 40 | 41 | edge_pred : 42 | T : 16 43 | loader : 44 | batch_size: 256 45 | num_workers: 4 46 | num_epochs : 1000 47 | graph_encoder : 48 | x_n_emb_size: 256 49 | y_emb_size: 64 50 | num_mpnn_layers: 4 51 | predictor : 52 | t_emb_size: 256 53 | out_hidden_size: 320 54 | optimizer : 55 | lr: 0.0003 56 | amsgrad: true 57 | -------------------------------------------------------------------------------- /data_files/tpu_tile_processed/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /data_files/tpu_tile_processed/README.md: -------------------------------------------------------------------------------- 1 | # TPU Tile 2 | 3 | The TpuGraphs dataset is released at https://github.com/google-research-datasets/tpu_graphs/tree/main under the Apache-2.0 license. We adapt the tile collection of the dataset by averaging the normalized runtimes across compiler configurations.
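The processed splits appear to be plain `torch.save` files. A minimal sketch for inspecting one of them, assuming the dict layout consumed by `src/dataset/tpu_tile.py` and written by `sample.py` (keys `src_list`, `dst_list`, `x_n_list`, `y_list`):

```python
import torch

# Hedged example: the keys below are assumed from to_dag_dataset() in
# src/dataset/tpu_tile.py and dump_to_file() in sample.py.
data_dict = torch.load('data_files/tpu_tile_processed/train.pth')
print(list(data_dict.keys()))      # expected: ['src_list', 'dst_list', 'x_n_list', 'y_list']
print(len(data_dict['src_list']))  # number of DAGs in the split

# Edge endpoints, node types, and the (averaged runtime) label of the first DAG.
src, dst = data_dict['src_list'][0], data_dict['dst_list'][0]
x_n, y = data_dict['x_n_list'][0], data_dict['y_list'][0]
```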
4 | -------------------------------------------------------------------------------- /data_files/tpu_tile_processed/test.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-COM/LayerDAG/bc22fed6f112c3eef2e991551bb0677038df3467/data_files/tpu_tile_processed/test.pth -------------------------------------------------------------------------------- /data_files/tpu_tile_processed/train.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-COM/LayerDAG/bc22fed6f112c3eef2e991551bb0677038df3467/data_files/tpu_tile_processed/train.pth -------------------------------------------------------------------------------- /data_files/tpu_tile_processed/val.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Graph-COM/LayerDAG/bc22fed6f112c3eef2e991551bb0677038df3467/data_files/tpu_tile_processed/val.pth -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from pprint import pprint 5 | from tqdm import tqdm 6 | 7 | from setup_utils import set_seed 8 | from src.dataset import load_dataset, DAGDataset 9 | from src.eval import TPUTileEvaluator 10 | from src.model import DiscreteDiffusion, EdgeDiscreteDiffusion, LayerDAG 11 | 12 | def sample_tpu_subset(args, device, dummy_category, model, subset): 13 | syn_set = DAGDataset(dummy_category, label=True) 14 | 15 | raw_y_batch = [] 16 | for i, y in enumerate(tqdm(subset.y)): 17 | raw_y_batch.append(y) 18 | if (len(raw_y_batch) == args.batch_size) or (i == len(subset.y) - 1): 19 | batch_edge_index, batch_x_n, batch_y = model.sample( 20 | device, len(raw_y_batch), raw_y_batch, 21 | min_num_steps_n=args.min_num_steps_n, 22 | max_num_steps_n=args.max_num_steps_n, 23 | min_num_steps_e=args.min_num_steps_e, 24 | max_num_steps_e=args.max_num_steps_e) 25 | 26 | for j in range(len(batch_edge_index)): 27 | edge_index_j = batch_edge_index[j] 28 | dst_j, src_j = edge_index_j.cpu() 29 | syn_set.add_data(src_j, dst_j, batch_x_n[j].cpu(), 30 | batch_y[j]) 31 | 32 | raw_y_batch = [] 33 | 34 | return syn_set 35 | 36 | def dump_to_file(syn_set, file_name, sample_dir): 37 | file_path = os.path.join(sample_dir, file_name) 38 | data_dict = { 39 | 'src_list': [], 40 | 'dst_list': [], 41 | 'x_n_list': [], 42 | 'y_list': [] 43 | } 44 | for i in range(len(syn_set)): 45 | src_i, dst_i, x_n_i, y_i = syn_set[i] 46 | 47 | data_dict['src_list'].append(src_i) 48 | data_dict['dst_list'].append(dst_i) 49 | data_dict['x_n_list'].append(x_n_i) 50 | data_dict['y_list'].append(y_i) 51 | 52 | torch.save(data_dict, file_path) 53 | 54 | def eval_tpu_tile(args, device, model): 55 | sample_dir = 'tpu_tile_samples' 56 | os.makedirs(sample_dir, exist_ok=True) 57 | 58 | evaluator = TPUTileEvaluator() 59 | train_set, val_set, _ = load_dataset('tpu_tile') 60 | 61 | train_syn_set = sample_tpu_subset(args, device, train_set.dummy_category, model, train_set) 62 | val_syn_set = sample_tpu_subset(args, device, train_set.dummy_category, model, val_set) 63 | 64 | evaluator.eval(train_syn_set, val_syn_set) 65 | 66 | dump_to_file(train_syn_set, 'train.pth', sample_dir) 67 | dump_to_file(val_syn_set, 'val.pth', sample_dir) 68 | 69 | def main(args): 70 | torch.set_num_threads(args.num_threads) 71 | 72 | device_str = 'cuda' if torch.cuda.is_available() else 
'cpu' 73 | device = torch.device(device_str) 74 | 75 | ckpt = torch.load(args.model_path) 76 | 77 | dataset = ckpt['dataset'] 78 | assert dataset == 'tpu_tile' 79 | 80 | node_diffusion = DiscreteDiffusion(**ckpt['node_diffusion_config']) 81 | edge_diffusion = EdgeDiscreteDiffusion(**ckpt['edge_diffusion_config']) 82 | 83 | model = LayerDAG(device=device, 84 | node_diffusion=node_diffusion, 85 | edge_diffusion=edge_diffusion, 86 | **ckpt['model_config']) 87 | pprint(ckpt['model_config']) 88 | model.load_state_dict(ckpt['model_state_dict']) 89 | model.eval() 90 | set_seed(args.seed) 91 | 92 | eval_tpu_tile(args, device, model) 93 | 94 | if __name__ == '__main__': 95 | from argparse import ArgumentParser 96 | 97 | parser = ArgumentParser() 98 | parser.add_argument("--model_path", type=str, help="Path to the model.") 99 | parser.add_argument("--batch_size", type=int, default=256) 100 | parser.add_argument("--num_threads", type=int, default=24) 101 | parser.add_argument("--min_num_steps_n", type=int, default=None) 102 | parser.add_argument("--min_num_steps_e", type=int, default=None) 103 | parser.add_argument("--max_num_steps_n", type=int, default=None) 104 | parser.add_argument("--max_num_steps_e", type=int, default=None) 105 | parser.add_argument("--seed", type=int, default=0) 106 | args = parser.parse_args() 107 | 108 | main(args) 109 | -------------------------------------------------------------------------------- /setup_utils.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import numpy as np 3 | import pydantic 4 | import random 5 | import torch 6 | import yaml 7 | 8 | from typing import Optional 9 | 10 | def set_seed(seed=0): 11 | if seed is None: 12 | return 13 | 14 | dgl.seed(seed) 15 | np.random.seed(seed) 16 | random.seed(seed) 17 | torch.manual_seed(seed) 18 | if torch.cuda.is_available(): 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | 24 | class DataLoaderYaml(pydantic.BaseModel): 25 | batch_size: int 26 | num_workers: int 27 | 28 | class BiMPNNYaml(pydantic.BaseModel): 29 | x_n_emb_size: int 30 | pe_emb_size: Optional[int] = 0 31 | y_emb_size: Optional[int] = 0 32 | num_mpnn_layers: int 33 | pool: Optional[str] = None 34 | pe: Optional[str] = None 35 | 36 | class OptimizerYaml(pydantic.BaseModel): 37 | lr: float 38 | amsgrad: bool 39 | 40 | class NodeCountYaml(pydantic.BaseModel): 41 | loader: DataLoaderYaml 42 | model: BiMPNNYaml 43 | num_epochs: int 44 | optimizer: OptimizerYaml 45 | 46 | class NodePredictorYaml(pydantic.BaseModel): 47 | t_emb_size: int 48 | out_hidden_size: int 49 | num_transformer_layers: int 50 | num_heads: int 51 | dropout: float 52 | 53 | class NodePredYaml(pydantic.BaseModel): 54 | T: int 55 | loader: DataLoaderYaml 56 | num_epochs: int 57 | graph_encoder: BiMPNNYaml 58 | predictor: NodePredictorYaml 59 | optimizer: OptimizerYaml 60 | 61 | class EdgePredictorYaml(pydantic.BaseModel): 62 | t_emb_size: int 63 | out_hidden_size: int 64 | 65 | class EdgePredYaml(pydantic.BaseModel): 66 | T: int 67 | loader: DataLoaderYaml 68 | num_epochs: int 69 | graph_encoder: BiMPNNYaml 70 | predictor: EdgePredictorYaml 71 | optimizer: OptimizerYaml 72 | 73 | class GeneralYaml(pydantic.BaseModel): 74 | dataset: str 75 | conditional: bool 76 | patience: Optional[int] = None 77 | 78 | class LayerDAGYaml(pydantic.BaseModel): 79 | general: GeneralYaml 80 | node_count: NodeCountYaml 81 | node_pred: 
NodePredYaml 82 | edge_pred: EdgePredYaml 83 | 84 | def load_yaml(config_file): 85 | with open(config_file) as f: 86 | yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader) 87 | 88 | return LayerDAGYaml(**yaml_data).model_dump() 89 | -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer_dag import * 2 | from .general import DAGDataset 3 | from .tpu_tile import get_tpu_tile 4 | 5 | def load_dataset(dataset_name): 6 | if dataset_name == 'tpu_tile': 7 | return get_tpu_tile() 8 | else: 9 | raise NotImplementedError(f'Unsupported dataset {dataset_name}') 10 | -------------------------------------------------------------------------------- /src/dataset/general.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.utils.data import Dataset 4 | 5 | class DAGDataset(Dataset): 6 | """ 7 | Parameters 8 | ---------- 9 | label : bool 10 | Whether each DAG has a label like runtime/latency. 11 | """ 12 | def __init__(self, num_categories, label=False): 13 | self.src = [] 14 | self.dst = [] 15 | self.x_n = [] 16 | 17 | self.label = label 18 | if self.label: 19 | self.y = [] 20 | 21 | self.dummy_category = num_categories 22 | if isinstance(self.dummy_category, torch.Tensor): 23 | self.dummy_category = self.dummy_category.tolist() 24 | 25 | self.num_categories = num_categories + 1 26 | 27 | def __len__(self): 28 | return len(self.src) 29 | 30 | def __getitem__(self, index): 31 | if self.label: 32 | return self.src[index], self.dst[index], self.x_n[index], self.y[index] 33 | else: 34 | return self.src[index], self.dst[index], self.x_n[index] 35 | 36 | def add_data(self, src, dst, x_n, y=None): 37 | self.src.append(src) 38 | self.dst.append(dst) 39 | self.x_n.append(x_n) 40 | if (y is not None) and (self.label): 41 | self.y.append(y) 42 | -------------------------------------------------------------------------------- /src/dataset/layer_dag.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from collections import defaultdict 4 | from torch.utils.data import Dataset 5 | 6 | __all__ = ['LayerDAGNodeCountDataset', 7 | 'LayerDAGNodePredDataset', 8 | 'LayerDAGEdgePredDataset', 9 | 'collate_node_count', 10 | 'collate_node_pred', 11 | 'collate_edge_pred'] 12 | 13 | class LayerDAGBaseDataset(Dataset): 14 | def __init__(self, conditional=False): 15 | self.input_src = [] 16 | self.input_dst = [] 17 | self.input_x_n = [] 18 | self.input_level = [] 19 | 20 | self.input_e_start = [] 21 | self.input_e_end = [] 22 | self.input_n_start = [] 23 | self.input_n_end = [] 24 | 25 | self.conditional = conditional 26 | if conditional: 27 | self.input_y = [] 28 | self.input_g = [] 29 | 30 | def get_in_deg(self, dst, num_nodes): 31 | return torch.bincount(dst, minlength=num_nodes).tolist() 32 | 33 | def get_out_adj_list(self, src, dst): 34 | out_adj_list = defaultdict(list) 35 | num_edges = len(src) 36 | for i in range(num_edges): 37 | out_adj_list[src[i]].append(dst[i]) 38 | return out_adj_list 39 | 40 | def get_in_adj_list(self, src, dst): 41 | in_adj_list = defaultdict(list) 42 | num_edges = len(src) 43 | for i in range(num_edges): 44 | in_adj_list[dst[i]].append(src[i]) 45 | return in_adj_list 46 | 47 | def base_postprocess(self): 48 | self.input_src = torch.LongTensor(self.input_src) 49 | self.input_dst = torch.LongTensor(self.input_dst) 50 | 51 | # Case1: self.input_x_n[0] is an int.
52 | # Case2: self.input_x_n[0] is a tensor of shape (F). 53 | self.input_x_n = torch.LongTensor(self.input_x_n) 54 | self.input_level = torch.LongTensor(self.input_level) 55 | 56 | self.input_e_start = torch.LongTensor(self.input_e_start) 57 | self.input_e_end = torch.LongTensor(self.input_e_end) 58 | self.input_n_start = torch.LongTensor(self.input_n_start) 59 | self.input_n_end = torch.LongTensor(self.input_n_end) 60 | 61 | if self.conditional: 62 | self.input_y = torch.tensor(self.input_y) 63 | self.input_g = torch.LongTensor(self.input_g) 64 | 65 | class LayerDAGNodeCountDataset(LayerDAGBaseDataset): 66 | def __init__(self, dag_dataset, conditional=False): 67 | super().__init__(conditional) 68 | 69 | # Size of the next layer to predict. 70 | self.label = [] 71 | 72 | for i in range(len(dag_dataset)): 73 | data_i = dag_dataset[i] 74 | 75 | if conditional: 76 | src, dst, x_n, y = data_i 77 | # Index of y in self.input_y 78 | input_g = len(self.input_y) 79 | self.input_y.append(y) 80 | else: 81 | src, dst, x_n = data_i 82 | 83 | # For recording indices of the node attributes in self.input_x_n 84 | input_n_start = len(self.input_x_n) 85 | input_n_end = len(self.input_x_n) 86 | 87 | # For recording indices of the edges in self.input_src/self.input_dst 88 | input_e_start = len(self.input_src) 89 | input_e_end = len(self.input_src) 90 | 91 | # Use a dummy node for representing the initial empty DAG. 92 | self.input_x_n.append(dag_dataset.dummy_category) 93 | input_n_end += 1 94 | src = src + 1 95 | dst = dst + 1 96 | 97 | # Layer ID 98 | level = 0 99 | self.input_level.append(level) 100 | 101 | num_nodes = len(x_n) + 1 102 | in_deg = self.get_in_deg(dst, num_nodes) 103 | 104 | src = src.tolist() 105 | dst = dst.tolist() 106 | x_n = x_n.tolist() 107 | out_adj_list = self.get_out_adj_list(src, dst) 108 | in_adj_list = self.get_in_adj_list(src, dst) 109 | 110 | frontiers = [ 111 | u for u in range(1, num_nodes) if in_deg[u] == 0 112 | ] 113 | frontier_size = len(frontiers) 114 | while frontier_size > 0: 115 | # There is another layer. 116 | level += 1 117 | 118 | # Record indices for retrieving edges in the previous layers 119 | # for model input. 120 | self.input_e_start.append(input_e_start) 121 | self.input_e_end.append(input_e_end) 122 | 123 | # Record indices for retrieving node attributes in the previous 124 | # layers for model input. 125 | self.input_n_start.append(input_n_start) 126 | self.input_n_end.append(input_n_end) 127 | 128 | if conditional: 129 | # Record the index for retrieving graph-level conditional 130 | # information for model input. 131 | self.input_g.append(input_g) 132 | self.label.append(frontier_size) 133 | 134 | # (1) Add the node attributes/edges for the current layer. 135 | # (2) Get the next layer. 136 | next_frontiers = [] 137 | for u in frontiers: 138 | # -1 for the initial dummy node 139 | self.input_x_n.append(x_n[u - 1]) 140 | self.input_level.append(level) 141 | 142 | for t in in_adj_list[u]: 143 | self.input_src.append(t) 144 | self.input_dst.append(u) 145 | input_e_end += 1 146 | 147 | for v in out_adj_list[u]: 148 | in_deg[v] -= 1 149 | if in_deg[v] == 0: 150 | next_frontiers.append(v) 151 | input_n_end += frontier_size 152 | 153 | frontiers = next_frontiers 154 | frontier_size = len(frontiers) 155 | 156 | # Handle termination, namely predicting the layer size to be 0. 
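# frontier_size is 0 once the while loop above exits, so this final record pairs the complete DAG with a label of 0, which trains the node count model to stop adding layers.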
157 | self.input_e_start.append(input_e_start) 158 | self.input_e_end.append(input_e_end) 159 | self.input_n_start.append(input_n_start) 160 | self.input_n_end.append(input_n_end) 161 | if conditional: 162 | self.input_g.append(input_g) 163 | self.label.append(frontier_size) 164 | 165 | self.base_postprocess() 166 | self.label = torch.LongTensor(self.label) 167 | # Maximum number of nodes in a layer. 168 | self.max_layer_size = self.label.max().item() 169 | 170 | def __len__(self): 171 | return len(self.label) 172 | 173 | def __getitem__(self, index): 174 | input_e_start = self.input_e_start[index] 175 | input_e_end = self.input_e_end[index] 176 | input_n_start = self.input_n_start[index] 177 | input_n_end = self.input_n_end[index] 178 | 179 | # Absolute and relative (with respect to the new layer) layer idx 180 | # for potential extra encodings. 181 | input_abs_level = self.input_level[input_n_start:input_n_end] 182 | input_rel_level = input_abs_level.max() - input_abs_level 183 | 184 | if self.conditional: 185 | input_g = self.input_g[index] 186 | input_y = self.input_y[input_g].item() 187 | 188 | return self.input_src[input_e_start:input_e_end],\ 189 | self.input_dst[input_e_start:input_e_end],\ 190 | self.input_x_n[input_n_start:input_n_end],\ 191 | input_abs_level, input_rel_level, input_y, self.label[index] 192 | else: 193 | return self.input_src[input_e_start:input_e_end],\ 194 | self.input_dst[input_e_start:input_e_end],\ 195 | self.input_x_n[input_n_start:input_n_end],\ 196 | input_abs_level, input_rel_level, self.label[index] 197 | 198 | class LayerDAGNodePredDataset(LayerDAGBaseDataset): 199 | def __init__(self, dag_dataset, conditional=False, get_marginal=True): 200 | super().__init__(conditional) 201 | 202 | # Indices for retrieving the labels 203 | # (node attributes for the next layer) 204 | self.label_start = [] 205 | self.label_end = [] 206 | 207 | for i in range(len(dag_dataset)): 208 | data_i = dag_dataset[i] 209 | 210 | if conditional: 211 | src, dst, x_n, y = data_i 212 | # Index of y in self.input_y 213 | input_g = len(self.input_y) 214 | self.input_y.append(y) 215 | else: 216 | src, dst, x_n = data_i 217 | 218 | # For recording indices of the node attributes in self.input_x_n, 219 | # which will be model input. 220 | input_n_start = len(self.input_x_n) 221 | input_n_end = len(self.input_x_n) 222 | 223 | # For recording indices of the edges in self.input_src/self.input_dst, 224 | # which will be model input. 225 | input_e_start = len(self.input_src) 226 | input_e_end = len(self.input_src) 227 | 228 | # Use a dummy node for representing the initial empty DAG, which 229 | # will be model input. 230 | self.input_x_n.append(dag_dataset.dummy_category) 231 | input_n_end += 1 232 | src = src + 1 233 | dst = dst + 1 234 | # For recording indices of the node attributes in self.input_x_n, 235 | # which will be ground truth labels for model predictions. 236 | label_start = len(self.input_x_n) 237 | 238 | # Layer ID 239 | level = 0 240 | self.input_level.append(level) 241 | 242 | num_nodes = len(x_n) + 1 243 | in_deg = self.get_in_deg(dst, num_nodes) 244 | 245 | src = src.tolist() 246 | dst = dst.tolist() 247 | x_n = x_n.tolist() 248 | out_adj_list = self.get_out_adj_list(src, dst) 249 | in_adj_list = self.get_in_adj_list(src, dst) 250 | 251 | frontiers = [ 252 | u for u in range(1, num_nodes) if in_deg[u] == 0 253 | ] 254 | frontier_size = len(frontiers) 255 | while frontier_size > 0: 256 | # There is another layer. 
257 | level += 1 258 | 259 | # Record indices for retrieving edges in the previous layers 260 | # for model input. 261 | self.input_e_start.append(input_e_start) 262 | self.input_e_end.append(input_e_end) 263 | 264 | # Record indices for retrieving node attributes in the previous 265 | # layers for model input. 266 | self.input_n_start.append(input_n_start) 267 | self.input_n_end.append(input_n_end) 268 | 269 | if conditional: 270 | # Record the index for retrieving graph-level conditional 271 | # information for model input. 272 | self.input_g.append(input_g) 273 | 274 | # Record indices for retrieving node attributes of the new 275 | # layer for model predictions. 276 | self.label_start.append(label_start) 277 | label_end = label_start + frontier_size 278 | self.label_end.append(label_end) 279 | label_start = label_end 280 | 281 | # (1) Add the node attributes/edges for the current layer. 282 | # (2) Get the next layer. 283 | next_frontiers = [] 284 | for u in frontiers: 285 | # -1 for the initial dummy node 286 | self.input_x_n.append(x_n[u - 1]) 287 | self.input_level.append(level) 288 | 289 | for t in in_adj_list[u]: 290 | self.input_src.append(t) 291 | self.input_dst.append(u) 292 | input_e_end += 1 293 | 294 | for v in out_adj_list[u]: 295 | in_deg[v] -= 1 296 | if in_deg[v] == 0: 297 | next_frontiers.append(v) 298 | input_n_end += frontier_size 299 | 300 | frontiers = next_frontiers 301 | frontier_size = len(frontiers) 302 | 303 | self.base_postprocess() 304 | self.label_start = torch.LongTensor(self.label_start) 305 | self.label_end = torch.LongTensor(self.label_end) 306 | 307 | if get_marginal: 308 | # Case 1 (a single node attribute): self.input_x_n is of shape (N). 309 | # Case 2 (multiple node attributes): self.input_x_n is of shape (N, F). 310 | input_x_n = self.input_x_n 311 | if input_x_n.ndim == 1: 312 | input_x_n = input_x_n.unsqueeze(-1) 313 | 314 | num_feats = input_x_n.shape[-1] 315 | x_n_marginal = [] 316 | for f in range(num_feats): 317 | input_x_n_f = input_x_n[:, f] 318 | unique_x_n_f, x_n_count_f = input_x_n_f.unique(return_counts=True) 319 | assert unique_x_n_f.max().item() == len(x_n_count_f) - 1,\ 320 | 'Need to re-label node types to be consecutive integers starting from 0' 321 | 322 | # The last category is the dummy category. 323 | num_x_n_types_f = len(x_n_count_f) - 1 324 | x_n_marginal_f = torch.zeros(num_x_n_types_f) 325 | 326 | for c in range(len(x_n_count_f)): 327 | x_n_type_f_c = unique_x_n_f[c].item() 328 | # No need to include the dummy category for marginal computation. 329 | if x_n_type_f_c != num_x_n_types_f: 330 | x_n_marginal_f[x_n_type_f_c] = x_n_count_f[c].item() 331 | 332 | x_n_marginal_f /= (x_n_marginal_f.sum() + 1e-8) 333 | x_n_marginal.append(x_n_marginal_f) 334 | 335 | self.x_n_marginal = x_n_marginal 336 | 337 | def __len__(self): 338 | return len(self.label_start) 339 | 340 | def __getitem__(self, index): 341 | input_e_start = self.input_e_start[index] 342 | input_e_end = self.input_e_end[index] 343 | input_n_start = self.input_n_start[index] 344 | input_n_end = self.input_n_end[index] 345 | label_start = self.label_start[index] 346 | label_end = self.label_end[index] 347 | 348 | # Absolute and relative (with respect to the new layer) layer idx 349 | # for potential extra encodings. 
350 | input_abs_level = self.input_level[input_n_start:input_n_end] 351 | input_rel_level = input_abs_level.max() - input_abs_level 352 | 353 | z = self.input_x_n[label_start:label_end] 354 | t, z_t = self.node_diffusion.apply_noise(z) 355 | 356 | if self.conditional: 357 | input_g = self.input_g[index] 358 | input_y = self.input_y[input_g].item() 359 | 360 | return self.input_src[input_e_start:input_e_end],\ 361 | self.input_dst[input_e_start:input_e_end],\ 362 | self.input_x_n[input_n_start:input_n_end],\ 363 | input_abs_level, input_rel_level, z_t, t, input_y, z 364 | else: 365 | return self.input_src[input_e_start:input_e_end],\ 366 | self.input_dst[input_e_start:input_e_end],\ 367 | self.input_x_n[input_n_start:input_n_end],\ 368 | input_abs_level, input_rel_level, z_t, t, z 369 | 370 | class LayerDAGEdgePredDataset(LayerDAGBaseDataset): 371 | def __init__(self, dag_dataset, conditional=False): 372 | super().__init__(conditional) 373 | 374 | self.query_src = [] 375 | self.query_dst = [] 376 | # Indices for retrieving the query node pairs 377 | self.query_start = [] 378 | self.query_end = [] 379 | self.label = [] 380 | 381 | num_edges = 0 382 | num_nonsrc_nodes = 0 383 | for i in range(len(dag_dataset)): 384 | data_i = dag_dataset[i] 385 | 386 | if conditional: 387 | src, dst, x_n, y = data_i 388 | # Index of y in self.input_y 389 | input_g = len(self.input_y) 390 | self.input_y.append(y) 391 | else: 392 | src, dst, x_n = data_i 393 | 394 | # For recording indices of the node attributes in self.input_x_n, 395 | # which will be model input. 396 | input_n_start = len(self.input_x_n) 397 | input_n_end = len(self.input_x_n) 398 | 399 | # For recording indices of the edges in self.input_src/self.input_dst, 400 | # which will be model input. 401 | input_e_start = len(self.input_src) 402 | input_e_end = len(self.input_src) 403 | 404 | # For recording indices of the query node pairs in 405 | # self.query_src/self.query_dst for model predictions. 406 | query_start = len(self.query_src) 407 | query_end = len(self.query_src) 408 | 409 | # Use a dummy node for representing the initial empty DAG, which 410 | # will be model input. 
411 | self.input_x_n.append(dag_dataset.dummy_category) 412 | input_n_end += 1 413 | src = src + 1 414 | dst = dst + 1 415 | 416 | # Layer ID 417 | level = 0 418 | self.input_level.append(level) 419 | 420 | num_nodes = len(x_n) + 1 421 | in_deg = self.get_in_deg(dst, num_nodes) 422 | 423 | src = src.tolist() 424 | dst = dst.tolist() 425 | x_n = x_n.tolist() 426 | out_adj_list = self.get_out_adj_list(src, dst) 427 | in_adj_list = self.get_in_adj_list(src, dst) 428 | 429 | prev_frontiers = [ 430 | u for u in range(1, num_nodes) if in_deg[u] == 0 431 | ] 432 | current_frontiers = [] 433 | level += 1 434 | 435 | num_edges += len(src) 436 | num_nonsrc_nodes += len(x_n) - len(prev_frontiers) 437 | 438 | for u in prev_frontiers: 439 | self.input_x_n.append(x_n[u - 1]) 440 | self.input_level.append(level) 441 | 442 | for v in out_adj_list[u]: 443 | in_deg[v] -= 1 444 | if in_deg[v] == 0: 445 | current_frontiers.append(v) 446 | input_n_end += len(prev_frontiers) 447 | 448 | src_candidates = prev_frontiers 449 | 450 | while len(current_frontiers) > 0: 451 | level += 1 452 | 453 | next_frontiers = [] 454 | temp_edge_count = 0 455 | for u in current_frontiers: 456 | self.input_x_n.append(x_n[u - 1]) 457 | self.input_level.append(level) 458 | 459 | self.query_src.extend(src_candidates) 460 | self.query_dst.extend([u] * len(src_candidates)) 461 | query_end += len(src_candidates) 462 | for t in src_candidates: 463 | if t in in_adj_list[u]: 464 | self.input_src.append(t) 465 | self.input_dst.append(u) 466 | temp_edge_count += 1 467 | self.label.append(1) 468 | else: 469 | self.label.append(0) 470 | 471 | for v in out_adj_list[u]: 472 | in_deg[v] -= 1 473 | if in_deg[v] == 0: 474 | next_frontiers.append(v) 475 | 476 | input_n_end += len(current_frontiers) 477 | 478 | # Record indices for retrieving edges in the previous layers 479 | # for model input. 480 | self.input_e_start.append(input_e_start) 481 | self.input_e_end.append(input_e_end) 482 | 483 | # Record indices for retrieving node attributes in the previous 484 | # layers for model input. 485 | self.input_n_start.append(input_n_start) 486 | self.input_n_end.append(input_n_end) 487 | 488 | if conditional: 489 | # Record the index for retrieving graph-level conditional 490 | # information for model input. 491 | self.input_g.append(input_g) 492 | 493 | # Record indices for retrieving query node pairs 494 | # for model predictions. 
495 | self.query_start.append(query_start) 496 | self.query_end.append(query_end) 497 | 498 | src_candidates.extend(current_frontiers) 499 | prev_frontiers = current_frontiers 500 | current_frontiers = next_frontiers 501 | input_e_end += temp_edge_count 502 | query_start = query_end 503 | 504 | self.base_postprocess() 505 | self.query_src = torch.LongTensor(self.query_src) 506 | self.query_dst = torch.LongTensor(self.query_dst) 507 | self.query_start = torch.LongTensor(self.query_start) 508 | self.query_end = torch.LongTensor(self.query_end) 509 | self.label = torch.LongTensor(self.label) 510 | 511 | self.avg_in_deg = num_edges / num_nonsrc_nodes 512 | 513 | def __len__(self): 514 | return len(self.query_start) 515 | 516 | def __getitem__(self, index): 517 | input_e_start = self.input_e_start[index] 518 | input_e_end = self.input_e_end[index] 519 | input_src = self.input_src[input_e_start:input_e_end] 520 | input_dst = self.input_dst[input_e_start:input_e_end] 521 | 522 | input_n_start = self.input_n_start[index] 523 | input_n_end = self.input_n_end[index] 524 | input_x_n = self.input_x_n[input_n_start:input_n_end] 525 | 526 | # Absolute and relative (with respect to the new layer) layer idx 527 | # for potential extra encodings. 528 | input_abs_level = self.input_level[input_n_start:input_n_end] 529 | input_rel_level = input_abs_level.max() - input_abs_level 530 | 531 | query_start = self.query_start[index] 532 | query_end = self.query_end[index] 533 | query_src = self.query_src[query_start:query_end] 534 | query_dst = self.query_dst[query_start:query_end] 535 | label = self.label[query_start:query_end] 536 | 537 | unique_src = torch.unique(query_src, sorted=False) 538 | unique_dst = torch.unique(query_dst, sorted=False) 539 | label_adj = label.reshape(len(unique_dst), len(unique_src)) 540 | 541 | t, label_t = self.edge_diffusion.apply_noise(label_adj) 542 | 543 | mask = (label_t == 1) 544 | noisy_src = query_src[mask] 545 | noisy_dst = query_dst[mask] 546 | 547 | if self.conditional: 548 | input_g = self.input_g[index] 549 | input_y = self.input_y[input_g].item() 550 | 551 | return input_src, input_dst, noisy_src, noisy_dst, input_x_n,\ 552 | input_abs_level, input_rel_level, t, input_y, query_src, query_dst, label 553 | else: 554 | return input_src, input_dst, noisy_src, noisy_dst, input_x_n,\ 555 | input_abs_level, input_rel_level, t, query_src, query_dst, label 556 | 557 | def collate_common(src, dst, x_n, abs_level, rel_level): 558 | num_nodes_cumsum = torch.cumsum(torch.tensor( 559 | [0] + [len(x_n_i) for x_n_i in x_n]), dim=0) 560 | 561 | batch_size = len(x_n) 562 | src_ = [] 563 | dst_ = [] 564 | for i in range(batch_size): 565 | src_.append(src[i] + num_nodes_cumsum[i]) 566 | dst_.append(dst[i] + num_nodes_cumsum[i]) 567 | src = torch.cat(src_, dim=0) 568 | dst = torch.cat(dst_, dim=0) 569 | edge_index = torch.stack([dst, src]) 570 | 571 | x_n = torch.cat(x_n, dim=0).long() 572 | abs_level = torch.cat(abs_level, dim=0).float().unsqueeze(-1) 573 | rel_level = torch.cat(rel_level, dim=0).float().unsqueeze(-1) 574 | 575 | # Prepare edge index for node to graph mapping 576 | nids = [] 577 | gids = [] 578 | for i in range(batch_size): 579 | nids.append(torch.arange(num_nodes_cumsum[i], num_nodes_cumsum[i+1]).long()) 580 | gids.append(torch.ones(num_nodes_cumsum[i+1] - num_nodes_cumsum[i]).fill_(i).long()) 581 | nids = torch.cat(nids, dim=0) 582 | gids = torch.cat(gids, dim=0) 583 | n2g_index = torch.stack([gids, nids]) 584 | 585 | return batch_size, edge_index, x_n, abs_level, 
rel_level, n2g_index 586 | 587 | def collate_node_count(data): 588 | if len(data[0]) == 7: 589 | batch_src, batch_dst, batch_x_n, batch_abs_level, batch_rel_level, batch_y, batch_label = map(list, zip(*data)) 590 | 591 | y_ = [] 592 | for i in range(len(batch_x_n)): 593 | y_.extend([batch_y[i]] * len(batch_x_n[i])) 594 | batch_y = torch.tensor(y_).unsqueeze(-1) 595 | else: 596 | batch_src, batch_dst, batch_x_n, batch_abs_level, batch_rel_level, batch_label = map( 597 | list, zip(*data)) 598 | 599 | batch_size, batch_edge_index, batch_x_n, batch_abs_level, batch_rel_level,\ 600 | batch_n2g_index = collate_common( 601 | batch_src, batch_dst, batch_x_n, batch_abs_level, batch_rel_level) 602 | 603 | batch_label = torch.stack(batch_label) 604 | 605 | if len(data[0]) == 7: 606 | return batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 607 | batch_rel_level, batch_y, batch_n2g_index, batch_label 608 | else: 609 | return batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 610 | batch_rel_level, batch_n2g_index, batch_label 611 | 612 | def collate_node_pred(data): 613 | if len(data[0]) == 8: 614 | batch_src, batch_dst, batch_x_n, batch_abs_level, batch_rel_level,\ 615 | batch_z_t, batch_t, batch_z = map(list, zip(*data)) 616 | else: 617 | batch_src, batch_dst, batch_x_n, batch_abs_level, batch_rel_level,\ 618 | batch_z_t, batch_t, batch_y, batch_z = map(list, zip(*data)) 619 | # Broadcast graph-level conditional information to nodes. 620 | y_ = [] 621 | for i in range(len(batch_x_n)): 622 | y_.extend([batch_y[i]] * len(batch_x_n[i])) 623 | batch_y = torch.tensor(y_).unsqueeze(-1) 624 | 625 | batch_size, batch_edge_index, batch_x_n, batch_abs_level, batch_rel_level,\ 626 | batch_n2g_index = collate_common( 627 | batch_src, batch_dst, batch_x_n, batch_abs_level, batch_rel_level) 628 | 629 | num_query_cumsum = torch.cumsum(torch.tensor( 630 | [0] + [len(z_t_i) for z_t_i in batch_z_t]), dim=0) 631 | query2g = [] 632 | for i in range(batch_size): 633 | query2g.append(torch.ones(num_query_cumsum[i+1] - num_query_cumsum[i]).fill_(i).long()) 634 | query2g = torch.cat(query2g) 635 | 636 | batch_z_t = torch.cat(batch_z_t) 637 | batch_t = torch.cat(batch_t).unsqueeze(-1) 638 | batch_z = torch.cat(batch_z) 639 | 640 | if batch_z.ndim == 1: 641 | batch_z = batch_z.unsqueeze(-1) 642 | 643 | if len(data[0]) == 8: 644 | return batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 645 | batch_rel_level, batch_n2g_index, batch_z_t, batch_t, query2g,\ 646 | num_query_cumsum, batch_z 647 | else: 648 | return batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 649 | batch_rel_level, batch_n2g_index, batch_z_t, batch_t, batch_y,\ 650 | query2g, num_query_cumsum, batch_z 651 | 652 | def collate_edge_pred(data): 653 | if len(data[0]) == 11: 654 | batch_src, batch_dst, batch_noisy_src, batch_noisy_dst, batch_x_n,\ 655 | batch_abs_level, batch_rel_level, batch_t, batch_query_src,\ 656 | batch_query_dst, batch_label = map(list, zip(*data)) 657 | else: 658 | batch_src, batch_dst, batch_noisy_src, batch_noisy_dst, batch_x_n,\ 659 | batch_abs_level, batch_rel_level, batch_t, batch_y,\ 660 | batch_query_src, batch_query_dst, batch_label = map(list, zip(*data)) 661 | # Broadcast graph-level conditional information to nodes. 
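# For example, a graph with 3 nodes and y = 0.7 contributes [0.7, 0.7, 0.7]
# to y_, and the trailing unsqueeze gives batch_y shape (V, 1).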
662 | y_ = [] 663 | for i in range(len(batch_x_n)): 664 | y_.extend([batch_y[i]] * len(batch_x_n[i])) 665 | batch_y = torch.tensor(y_).unsqueeze(-1) 666 | 667 | num_nodes_cumsum = torch.cumsum(torch.tensor( 668 | [0] + [len(x_n_i) for x_n_i in batch_x_n]), dim=0) 669 | 670 | batch_size = len(batch_x_n) 671 | src_ = [] 672 | dst_ = [] 673 | noisy_src_ = [] 674 | noisy_dst_ = [] 675 | query_src_ = [] 676 | query_dst_ = [] 677 | t_ = [] 678 | for i in range(batch_size): 679 | src_.append(batch_src[i] + num_nodes_cumsum[i]) 680 | dst_.append(batch_dst[i] + num_nodes_cumsum[i]) 681 | noisy_src_.append(batch_noisy_src[i] + num_nodes_cumsum[i]) 682 | noisy_dst_.append(batch_noisy_dst[i] + num_nodes_cumsum[i]) 683 | query_src_.append(batch_query_src[i] + num_nodes_cumsum[i]) 684 | query_dst_.append(batch_query_dst[i] + num_nodes_cumsum[i]) 685 | t_.append(batch_t[i].expand(len(batch_query_src[i]), -1)) 686 | 687 | src = torch.cat(src_, dim=0) 688 | dst = torch.cat(dst_, dim=0) 689 | edge_index = torch.stack([dst, src]) 690 | noisy_src = torch.cat(noisy_src_, dim=0) 691 | noisy_dst = torch.cat(noisy_dst_, dim=0) 692 | noisy_edge_index = torch.stack([noisy_dst, noisy_src]) 693 | query_src = torch.cat(query_src_) 694 | query_dst = torch.cat(query_dst_) 695 | t = torch.cat(t_) 696 | 697 | batch_x_n = torch.cat(batch_x_n, dim=0).long() 698 | batch_abs_level = torch.cat(batch_abs_level, dim=0).float().unsqueeze(-1) 699 | batch_rel_level = torch.cat(batch_rel_level, dim=0).float().unsqueeze(-1) 700 | 701 | batch_label = torch.cat(batch_label) 702 | 703 | if len(data[0]) == 11: 704 | return edge_index, noisy_edge_index, batch_x_n, batch_abs_level,\ 705 | batch_rel_level, t, query_src, query_dst, batch_label 706 | else: 707 | return edge_index, noisy_edge_index, batch_x_n, batch_abs_level,\ 708 | batch_rel_level, t, batch_y, query_src, query_dst, batch_label 709 | -------------------------------------------------------------------------------- /src/dataset/tpu_tile.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from .general import DAGDataset 5 | 6 | def to_dag_dataset(data_dict, num_categories): 7 | dataset = DAGDataset(num_categories=num_categories, label=True) 8 | 9 | src_list = data_dict['src_list'] 10 | dst_list = data_dict['dst_list'] 11 | x_n_list = data_dict['x_n_list'] 12 | y_list = data_dict['y_list'] 13 | 14 | num_g = len(src_list) 15 | for i in range(num_g): 16 | dataset.add_data(src_list[i], 17 | dst_list[i], 18 | x_n_list[i], 19 | y_list[i]) 20 | 21 | return dataset 22 | 23 | def get_tpu_tile(): 24 | root_path = os.path.dirname(os.path.abspath(__file__)) 25 | root_path = os.path.join(root_path, '../../data_files/tpu_tile_processed') 26 | 27 | train_path = os.path.join(root_path, 'train.pth') 28 | val_path = os.path.join(root_path, 'val.pth') 29 | test_path = os.path.join(root_path, 'test.pth') 30 | 31 | print('Loading TPU Tile dataset...') 32 | # Load the pre-processed TPU Tile dataset, where for each kernel graph, we 33 | # average the normalized runtime over multiple compiler configurations. 
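# Each .pth file is expected to hold a dict with the keys consumed by
# to_dag_dataset above: 'src_list', 'dst_list', 'x_n_list', 'y_list'.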
34 | train_set = torch.load(train_path) 35 | val_set = torch.load(val_path) 36 | test_set = torch.load(test_path) 37 | 38 | num_categories = torch.cat(train_set['x_n_list']).max().item() + 1 39 | train_set = to_dag_dataset(train_set, num_categories) 40 | val_set = to_dag_dataset(val_set, num_categories) 41 | test_set = to_dag_dataset(test_set, num_categories) 42 | 43 | return train_set, val_set, test_set 44 | -------------------------------------------------------------------------------- /src/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .discriminator import * 2 | from .tpu_tile import * -------------------------------------------------------------------------------- /src/eval/discriminator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .mpnn import * -------------------------------------------------------------------------------- /src/eval/discriminator/base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import os 3 | import torch 4 | 5 | __all__ = ['BaseEvaluator'] 6 | 7 | class BaseTrainer: 8 | def __init__(self, 9 | hyper_space, 10 | search_priority_increasing): 11 | """Base class for training a discriminative model. 12 | 13 | Parameters 14 | ---------- 15 | search_priority_increasing : list of str 16 | The priority of hyperparameters to search, from lowest to highest. 17 | """ 18 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 19 | self.device = torch.device(device) 20 | 21 | self.hyper_space = hyper_space 22 | self.search_priority_increasing = search_priority_increasing 23 | 24 | def get_config_list(self): 25 | vals = [self.hyper_space[k] for k in self.search_priority_increasing] 26 | 27 | config_list = [] 28 | for items in itertools.product(*vals): 29 | items_dict = dict(zip(self.search_priority_increasing, items)) 30 | config_list.append(items_dict) 31 | 32 | return config_list 33 | 34 | def save_model(self, model_path): 35 | torch.save({ 36 | "model_state_dict": self.model.state_dict(), 37 | "model_config": self.best_model_config 38 | }, model_path) 39 | 40 | class BaseEvaluator: 41 | def __init__(self, 42 | Trainer, 43 | model_path, 44 | real_train_set, 45 | real_val_set, 46 | real_test_set): 47 | self.Trainer = Trainer 48 | self.real_test_set = real_test_set 49 | 50 | self.model_real = Trainer() 51 | if (model_path is not None) and (os.path.exists(model_path)): 52 | self.model_real.load_model(model_path) 53 | else: 54 | self.model_real.fit(real_train_set, 55 | real_val_set) 56 | if model_path is not None: 57 | self.model_real.save_model(model_path) 58 | 59 | self.real_pearson_coeff, self.real_spearman_coeff, self.real_mae = self.model_real.predict(real_test_set) 60 | 61 | def eval(self, train_syn_set, val_syn_set): 62 | model_syn = self.Trainer() 63 | model_syn.fit(train_syn_set, val_syn_set) 64 | self.syn_pearson_coeff, self.syn_spearman_coeff, self.syn_mae = model_syn.predict(self.real_test_set) 65 | -------------------------------------------------------------------------------- /src/eval/discriminator/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def collate_fn(data): 4 | batch_src, batch_dst, batch_x_n, batch_label = map(list, zip(*data)) 5 | 6 | # (B + 1), B for batch size 7 | num_nodes_cumsum = torch.cumsum(torch.tensor( 8 | [0] + [len(x_n) for x_n in batch_x_n] 9 | ), dim=0) 10 | 11 
| batch_size = len(batch_src) 12 | batch_src_ = [] 13 | batch_dst_ = [] 14 | for i in range(batch_size): 15 | batch_src_.append(batch_src[i] + num_nodes_cumsum[i]) 16 | batch_dst_.append(batch_dst[i] + num_nodes_cumsum[i]) 17 | 18 | batch_src = torch.cat(batch_src_, dim=0) # (E) 19 | batch_dst = torch.cat(batch_dst_, dim=0) # (E) 20 | batch_edge_index = torch.stack([batch_dst, batch_src]) # (2, E) 21 | 22 | # (V) 23 | batch_x_n = torch.cat(batch_x_n, dim=0).long() 24 | # (B, F_G) 25 | batch_label = torch.tensor(batch_label).unsqueeze(-1) 26 | 27 | return batch_edge_index, batch_x_n, num_nodes_cumsum, batch_label 28 | -------------------------------------------------------------------------------- /src/eval/discriminator/mpnn.py: -------------------------------------------------------------------------------- 1 | import dgl.sparse as dglsp 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from copy import deepcopy 7 | from scipy import stats 8 | from torch.optim import Adam 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from .base import BaseTrainer 13 | from .data_utils import collate_fn 14 | 15 | __all__ = ['MPNNTrainer'] 16 | 17 | class MultiEmbedding(nn.Module): 18 | def __init__(self, num_x_n_cat, hidden_size): 19 | super().__init__() 20 | 21 | self.emb_list = nn.ModuleList([ 22 | nn.Embedding(num_x_n_cat_i, hidden_size) 23 | for num_x_n_cat_i in num_x_n_cat 24 | ]) 25 | 26 | def forward(self, x_n_cat): 27 | if len(x_n_cat.shape) == 1: 28 | x_n_emb = self.emb_list[0](x_n_cat) 29 | else: 30 | x_n_emb = torch.cat([ 31 | self.emb_list[i](x_n_cat[:, i]) for i in range(len(self.emb_list)) 32 | ], dim=1) 33 | 34 | return x_n_emb 35 | 36 | class MPNNLayer(nn.Module): 37 | def __init__(self, hidden_size): 38 | super().__init__() 39 | 40 | self.W = nn.Linear(hidden_size, hidden_size) 41 | self.W_self = nn.Linear(hidden_size, hidden_size) 42 | self.W_trans = nn.Linear(hidden_size, hidden_size) 43 | 44 | def forward(self, A, h_n): 45 | if A.nnz == 0: 46 | h_n_out = self.W_self(h_n) 47 | else: 48 | h_n_out = A @ self.W(h_n) + self.W_self(h_n) + A.T @ self.W_trans(h_n) 49 | # h_n_out = A @ self.W(h_n) + self.W_self(h_n) 50 | return F.relu(h_n_out) 51 | 52 | class MPNN(nn.Module): 53 | def __init__(self, 54 | hidden_size, 55 | num_mpnn_layers, 56 | num_x_n_cat=None): 57 | super().__init__() 58 | 59 | # Backward compatbility with earlier TPU Tile checkpoints 60 | if num_x_n_cat is None: 61 | num_x_n_cat = [47] 62 | 63 | self.x_n_emb = MultiEmbedding(num_x_n_cat, hidden_size) 64 | hidden_size_total = len(num_x_n_cat) * hidden_size 65 | self.mpnn_layers = nn.ModuleList() 66 | for _ in range(num_mpnn_layers): 67 | self.mpnn_layers.append(MPNNLayer(hidden_size_total)) 68 | 69 | self.out_proj = nn.Sequential( 70 | nn.Linear((num_mpnn_layers + 1) * hidden_size_total, hidden_size_total), 71 | nn.ReLU(), 72 | ) 73 | 74 | self.pred = nn.Sequential( 75 | nn.Linear(hidden_size_total, hidden_size_total), 76 | nn.ReLU(), 77 | nn.Linear(hidden_size_total, 1) 78 | ) 79 | 80 | def forward(self, A, x_n, A_n_g): 81 | # A = A + A.T 82 | h_n = self.x_n_emb(x_n) 83 | 84 | h_n_cat = [h_n] 85 | for layer in self.mpnn_layers: 86 | h_n = layer(A, h_n) 87 | h_n_cat.append(h_n) 88 | h_n = torch.cat(h_n_cat, dim=-1) 89 | h_n = self.out_proj(h_n) 90 | h_g = A_n_g @ h_n 91 | 92 | return self.pred(h_g) 93 | 94 | class MPNNTrainer(BaseTrainer): 95 | def __init__(self, 96 | hyper_space='tpu_tile', 97 | search_priority_increasing=None): 98 | if hyper_space == 
'tpu_tile': 99 | hyper_space = { 100 | "lr": [1e-3], 101 | "num_mpnn_layers": [4], 102 | "hidden_size": [128], 103 | "num_x_n_cat": [[47]], 104 | "num_epochs": [500] 105 | } 106 | elif hyper_space == 'hls_dsp': 107 | hyper_space = { 108 | "lr": [1e-3], 109 | "num_mpnn_layers": [1], 110 | "hidden_size": [32], 111 | "num_x_n_cat": [[3, 107, 7, 45, 2, 2, 21]], 112 | "num_epochs": [500] 113 | } 114 | elif hyper_space == 'hls_lut': 115 | hyper_space = { 116 | "lr": [1e-3], 117 | "num_mpnn_layers": [3], 118 | "hidden_size": [64], 119 | "num_x_n_cat": [[3, 107, 7, 45, 2, 2, 21]], 120 | "num_epochs": [500] 121 | } 122 | elif hyper_space == 'nas_cpu': 123 | hyper_space = { 124 | "lr": [1e-3], 125 | "num_mpnn_layers": [1], 126 | "hidden_size": [1], 127 | "num_x_n_cat": [[9, 2, 5, 4, 5, 4, 4, 5, 4, 3, 3, 5, 4, 5]], 128 | "num_epochs": [500] 129 | } 130 | 131 | if search_priority_increasing is None: 132 | search_priority_increasing = ["lr", "num_mpnn_layers", "hidden_size", "num_x_n_cat", "num_epochs"] 133 | 134 | super().__init__(hyper_space=hyper_space, 135 | search_priority_increasing=search_priority_increasing) 136 | 137 | def preprocess(self, edge_index, x_n, num_nodes_cumsum, label): 138 | N = int(num_nodes_cumsum[-1]) 139 | A = dglsp.spmatrix(edge_index, shape=(N, N)).to(self.device) 140 | x_n = x_n.to(self.device) 141 | label = label.to(self.device) 142 | 143 | batch_size = len(num_nodes_cumsum) - 1 144 | src = [] 145 | dst = [] 146 | for i in range(batch_size): 147 | num_nodes_i = num_nodes_cumsum[i + 1] - num_nodes_cumsum[i] 148 | dst.extend([i] * num_nodes_i) 149 | src.extend(list(range(num_nodes_cumsum[i], num_nodes_cumsum[i + 1]))) 150 | 151 | src = torch.LongTensor(src) 152 | dst = torch.LongTensor(dst) 153 | n_g_edge_index = torch.stack([dst, src]) 154 | A_n_g = dglsp.spmatrix(n_g_edge_index, shape=(len(label), len(x_n))).to(self.device) 155 | 156 | return A, x_n, A_n_g, label 157 | 158 | def train_epoch(self, train_loader, model, optimizer): 159 | model.train() 160 | for batch_data in train_loader: 161 | batch_edge_index, batch_x_n, num_nodes_cumsum, batch_label = batch_data 162 | A, x_n, A_n_g, label = self.preprocess( 163 | batch_edge_index, batch_x_n, num_nodes_cumsum, batch_label) 164 | pred = model(A, x_n, A_n_g) 165 | loss = F.smooth_l1_loss(pred, label) 166 | 167 | optimizer.zero_grad() 168 | loss.backward() 169 | optimizer.step() 170 | 171 | @torch.no_grad() 172 | def eval_epoch(self, data_loader, model, spearman=False): 173 | model.eval() 174 | 175 | full_label = [] 176 | full_pred = [] 177 | for batch_data in data_loader: 178 | batch_edge_index, batch_x_n, num_nodes_cumsum, batch_label = batch_data 179 | A, x_n, A_n_g, label = self.preprocess( 180 | batch_edge_index, batch_x_n, num_nodes_cumsum, batch_label) 181 | pred = model(A, x_n, A_n_g) 182 | full_label.append(label.cpu()) 183 | full_pred.append(pred.cpu()) 184 | 185 | full_label = torch.cat(full_label).squeeze(-1) 186 | full_pred = torch.cat(full_pred).squeeze(-1) 187 | coef = stats.pearsonr(full_label.numpy(), full_pred.numpy())[0] 188 | 189 | mae = F.l1_loss(full_label, full_pred) 190 | 191 | if spearman: 192 | spearman_coef = stats.spearmanr(full_label.numpy(), full_pred.numpy())[0] 193 | 194 | return coef, spearman_coef, mae 195 | else: 196 | return coef, mae 197 | 198 | def fit_trial(self, 199 | train_set, 200 | val_set, 201 | hidden_size, 202 | num_mpnn_layers, 203 | num_x_n_cat, 204 | lr, 205 | num_epochs, 206 | batch_size=256, 207 | num_workers=0, 208 | patience_limit=100): 209 | torch.set_num_threads(20) 210 | 
train_loader = DataLoader(train_set, 211 | batch_size=batch_size, 212 | num_workers=num_workers, 213 | collate_fn=collate_fn, 214 | shuffle=True) 215 | 216 | val_loader = DataLoader(val_set, 217 | batch_size=batch_size, 218 | num_workers=num_workers, 219 | collate_fn=collate_fn) 220 | 221 | model = MPNN(num_x_n_cat=num_x_n_cat, 222 | hidden_size=hidden_size, 223 | num_mpnn_layers=num_mpnn_layers).to(self.device) 224 | optimizer = Adam(model.parameters(), lr=lr) 225 | 226 | best_val_coef = float('-inf') 227 | patience = 0 228 | best_model_state_dict = deepcopy(model.state_dict()) 229 | for epoch in tqdm(range(num_epochs)): 230 | self.train_epoch(train_loader, model, optimizer) 231 | val_coef, val_mae = self.eval_epoch(val_loader, model) 232 | if val_coef > best_val_coef: 233 | patience = 0 234 | best_val_coef = val_coef 235 | best_model_state_dict = deepcopy(model.state_dict()) 236 | else: 237 | patience += 1 238 | 239 | print(f'Epoch {epoch} | Best Val coef: {best_val_coef:.4f} | Val coef: {val_coef:.4f} | Val mae: {val_mae:.4f}') 240 | 241 | if patience == patience_limit: 242 | break 243 | 244 | model.load_state_dict(best_model_state_dict) 245 | return best_val_coef, model 246 | 247 | def fit(self, 248 | train_set, 249 | val_set): 250 | config_list = self.get_config_list() 251 | 252 | best_coef = float('-inf') 253 | with tqdm(config_list) as tconfig: 254 | tconfig.set_description("Training MPNN discriminator") 255 | 256 | for config in tconfig: 257 | trial_coef, trial_model = self.fit_trial( 258 | train_set, val_set, **config) 259 | if trial_coef > best_coef: 260 | best_coef = trial_coef 261 | best_model = trial_model 262 | best_model_config = { 263 | "hidden_size": config["hidden_size"], 264 | "num_mpnn_layers": config["num_mpnn_layers"], 265 | "num_x_n_cat": config["num_x_n_cat"] 266 | } 267 | tconfig.set_postfix(pearson=best_coef) 268 | 269 | if trial_coef == 1.0: 270 | break 271 | 272 | self.model = best_model 273 | self.best_model_config = best_model_config 274 | 275 | def predict(self, 276 | test_set, 277 | batch_size=256, 278 | num_workers=0): 279 | test_loader = DataLoader(test_set, 280 | batch_size=batch_size, 281 | num_workers=num_workers, 282 | collate_fn=collate_fn) 283 | return self.eval_epoch(test_loader, self.model, spearman=True) 284 | 285 | def load_model(self, model_path): 286 | cpt = torch.load(model_path) 287 | model = MPNN(**cpt["model_config"]).to(self.device) 288 | model.load_state_dict(cpt["model_state_dict"]) 289 | self.model = model 290 | -------------------------------------------------------------------------------- /src/eval/tpu_tile.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .discriminator import BaseEvaluator, MPNNTrainer 4 | from ..dataset import load_dataset 5 | 6 | __all__ = ['TPUTileEvaluator'] 7 | 8 | class TPUTileEvaluator: 9 | def __init__(self): 10 | train_set, val_set, test_set = load_dataset('tpu_tile') 11 | 12 | cpt_path = "tpu_tile_cpts" 13 | os.makedirs(cpt_path, exist_ok=True) 14 | self.mpnn_evaluator = BaseEvaluator(MPNNTrainer, 15 | os.path.join("tpu_tile_cpts", "mpnn.pth"), 16 | train_set, 17 | val_set, 18 | test_set) 19 | 20 | def eval(self, train_syn_set, val_syn_set): 21 | self.mpnn_evaluator.eval(train_syn_set, val_syn_set) 22 | self.summary() 23 | 24 | def summary(self): 25 | print('\n') 26 | print('MPNN Discriminator') 27 | print('------------------') 28 | print('\n') 29 | 30 | print('Real') 31 | print('------------------') 32 | print(f'Pearson Coeff: 
{self.mpnn_evaluator.real_pearson_coeff}') 33 | print('\n') 34 | 35 | print('Synthetic') 36 | print('------------------') 37 | print(f'Pearson Coeff: {self.mpnn_evaluator.syn_pearson_coeff}') 38 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import * 2 | from .layer_dag import * 3 | -------------------------------------------------------------------------------- /src/model/diffusion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | __all__ = [ 7 | 'DiscreteDiffusion', 8 | 'EdgeDiscreteDiffusion' 9 | ] 10 | 11 | class DiscreteDiffusion(nn.Module): 12 | def __init__(self, 13 | marginal_list, 14 | T, 15 | s=0.008): 16 | """ 17 | Parameters 18 | ---------- 19 | marginal_list : list of torch.Tensor 20 | marginal_list[d] is the marginal distribution of the d-th attribute 21 | s : float 22 | Constant in noise schedule 23 | """ 24 | super().__init__() 25 | 26 | if not isinstance(marginal_list, list): 27 | marginal_list = [marginal_list] 28 | 29 | self.num_classes_list = [] 30 | self.I_list = nn.ParameterList([]) 31 | self.m_list = nn.ParameterList([]) 32 | 33 | for marginal_d in marginal_list: 34 | num_classes_d = len(marginal_d) 35 | self.num_classes_list.append(num_classes_d) 36 | self.I_list.append(nn.Parameter( 37 | torch.eye(num_classes_d), requires_grad=False)) 38 | marginal_d = marginal_d.unsqueeze(0).expand( 39 | num_classes_d, -1).clone() 40 | self.m_list.append(nn.Parameter(marginal_d, requires_grad=False)) 41 | 42 | self.T = T 43 | # Cosine schedule as proposed in 44 | # https://arxiv.org/abs/2102.09672 45 | num_steps = T + 2 46 | t = np.linspace(0, num_steps, num_steps) 47 | # Schedule for \bar{alpha}_t = alpha_1 * ... * alpha_t 48 | alpha_bars = np.cos(0.5 * np.pi * ((t / num_steps) + s) / (1 + s)) ** 2 49 | # Make the largest value 1. 50 | alpha_bars = alpha_bars / alpha_bars[0] 51 | alphas = alpha_bars[1:] / alpha_bars[:-1] 52 | 53 | self.betas = torch.from_numpy(1 - alphas).float() 54 | self.alphas = 1 - torch.clamp(self.betas, min=0, max=0.9999) 55 | 56 | log_alphas = torch.log(self.alphas) 57 | log_alpha_bars = torch.cumsum(log_alphas, dim=0) 58 | self.alpha_bars = torch.exp(log_alpha_bars) 59 | 60 | self.betas = nn.Parameter(self.betas, requires_grad=False) 61 | self.alphas = nn.Parameter(self.alphas, requires_grad=False) 62 | self.alpha_bars = nn.Parameter(self.alpha_bars, requires_grad=False) 63 | 64 | def get_Q(self, alpha, d): 65 | """ 66 | Parameters 67 | ---------- 68 | d : int 69 | Index for the attribute 70 | """ 71 | return alpha * self.I_list[d] + (1 - alpha) * self.m_list[d] 72 | 73 | def apply_noise(self, z, t=None): 74 | if t is None: 75 | # Sample a timestep t uniformly from 0 to self.T. 76 | # Note that the notation is slightly inconsistent with the paper. 77 | # t=0 corresponds to t=1 in the paper, where corruption has already taken place. 
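# randint with high=self.T + 1 samples t uniformly from {0, ..., T};
# self.alpha_bars has T + 1 entries, so indexing it with t is always valid.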
78 | t = torch.randint(low=0, high=self.T + 1, size=(1,)) 79 | 80 | alpha_bar_t = self.alpha_bars[t.item()] 81 | 82 | if z.ndim == 1: 83 | z = z.unsqueeze(-1) 84 | 85 | _, D = z.shape 86 | z_t_list = [] 87 | for d in range(D): 88 | Q_bar_t_d = self.get_Q(alpha_bar_t, d) 89 | z_one_hot_d = F.one_hot(z[:, d], num_classes=self.num_classes_list[d]).float() 90 | prob_z_t_d = z_one_hot_d @ Q_bar_t_d 91 | z_t_d = prob_z_t_d.multinomial(1).squeeze(-1) 92 | z_t_list.append(z_t_d) 93 | 94 | if D == 1: 95 | z_t = z_t_list[0] 96 | else: 97 | z_t = torch.stack(z_t_list, dim=1) 98 | 99 | return t, z_t 100 | 101 | class EdgeDiscreteDiffusion(nn.Module): 102 | def __init__(self, 103 | avg_in_deg, 104 | T, 105 | s=0.008): 106 | super().__init__() 107 | 108 | self.avg_in_deg = avg_in_deg 109 | 110 | self.T = T 111 | # Cosine schedule as proposed in 112 | # https://arxiv.org/abs/2102.09672 113 | num_steps = T + 2 114 | t = np.linspace(0, num_steps, num_steps) 115 | # Schedule for \bar{alpha}_t = alpha_1 * ... * alpha_t 116 | alpha_bars = np.cos(0.5 * np.pi * ((t / num_steps) + s) / (1 + s)) ** 2 117 | # Make the largest value 1. 118 | alpha_bars = alpha_bars / alpha_bars[0] 119 | alphas = alpha_bars[1:] / alpha_bars[:-1] 120 | 121 | self.betas = torch.from_numpy(1 - alphas).float() 122 | self.alphas = 1 - torch.clamp(self.betas, min=0, max=0.9999) 123 | 124 | log_alphas = torch.log(self.alphas) 125 | log_alpha_bars = torch.cumsum(log_alphas, dim=0) 126 | self.alpha_bars = torch.exp(log_alpha_bars) 127 | 128 | self.betas = nn.Parameter(self.betas, requires_grad=False) 129 | self.alphas = nn.Parameter(self.alphas, requires_grad=False) 130 | self.alpha_bars = nn.Parameter(self.alpha_bars, requires_grad=False) 131 | 132 | def apply_noise(self, z, t=None): 133 | """ 134 | Parameters 135 | ---------- 136 | z : torch.Tensor of shape (A, B) 137 | Adjacency matrix. 138 | A is the number of candidate destination nodes. 139 | B is the number of candidate source nodes. 140 | 141 | Returns 142 | ------- 143 | z_t : torch.Tensor of shape (A * B) 144 | """ 145 | if t is None: 146 | # Sample a timestep t uniformly from 0 to self.T. 147 | # Note that the notation is slightly inconsistent with the paper. 148 | # t=0 corresponds to t=1 in the paper, where corruption has already taken place. 149 | t = torch.randint(low=0, high=self.T + 1, size=(1,)) 150 | 151 | # TODO: Better doc 152 | alpha_bar_t = self.alpha_bars[t.item()] 153 | # Marginal probability for an edge to exist. 154 | mean_in_deg = min(self.avg_in_deg, z.shape[1]) 155 | m_z_t = torch.ones(z.shape) * (mean_in_deg / z.shape[1]) 156 | prob_z_t = alpha_bar_t * z + (1 - alpha_bar_t) * m_z_t 157 | z_t = torch.bernoulli(prob_z_t) 158 | 159 | # Make sure each node has at least one edge. 
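# For any new node whose noisy row is all zeros, turn on the candidate edge
# with the highest corruption probability so no node is left isolated.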
160 | isolated_mask = (z_t.sum(dim=1) == 0).bool() 161 | if isolated_mask.any(): 162 | z_t[isolated_mask, prob_z_t[isolated_mask].argmax(dim=1)] = 1 163 | 164 | z_t = z_t.reshape(-1) 165 | 166 | return t, z_t 167 | 168 | def get_Qs(self, 169 | alpha_t, 170 | alpha_bar_s, 171 | alpha_bar_t, 172 | marginal): 173 | M = torch.zeros(2) 174 | M = torch.tensor([ 175 | 1 - marginal, marginal 176 | ]) 177 | M = M.unsqueeze(0).expand(2, -1) 178 | I = torch.eye(2) 179 | 180 | Q_t = alpha_t * I + (1 - alpha_t) * M 181 | Q_bar_s = alpha_bar_s * I + (1 - alpha_bar_s) * M 182 | Q_bar_t = alpha_bar_t * I + (1 - alpha_bar_t) * M 183 | 184 | return Q_t, Q_bar_s, Q_bar_t 185 | -------------------------------------------------------------------------------- /src/model/layer_dag.py: -------------------------------------------------------------------------------- 1 | import dgl.sparse as dglsp 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | __all__ = [ 10 | 'LayerDAG' 11 | ] 12 | 13 | class SinusoidalPE(nn.Module): 14 | def __init__(self, pe_size): 15 | super().__init__() 16 | 17 | self.pe_size = pe_size 18 | if pe_size > 0: 19 | self.div_term = torch.exp(torch.arange(0, pe_size, 2) * 20 | (-math.log(10000.0) / pe_size)) 21 | self.div_term = nn.Parameter(self.div_term, requires_grad=False) 22 | 23 | def forward(self, position): 24 | if self.pe_size == 0: 25 | return torch.zeros(len(position), 0).to(position.device) 26 | 27 | return torch.cat([ 28 | torch.sin(position * self.div_term), 29 | torch.cos(position * self.div_term) 30 | ], dim=-1) 31 | 32 | class BiMPNNLayer(nn.Module): 33 | def __init__(self, in_size, out_size): 34 | super().__init__() 35 | 36 | self.W = nn.Linear(in_size, out_size) 37 | self.W_trans = nn.Linear(in_size, out_size) 38 | self.W_self = nn.Linear(in_size, out_size) 39 | 40 | def forward(self, A, A_T, h_n): 41 | if A.nnz == 0: 42 | h_n_out = self.W_self(h_n) 43 | else: 44 | h_n_out = A @ self.W(h_n) + A_T @ self.W_trans(h_n) +\ 45 | self.W_self(h_n) 46 | return F.gelu(h_n_out) 47 | 48 | class OneHotPE(nn.Module): 49 | def __init__(self, pe_size): 50 | super().__init__() 51 | 52 | self.pe_size = pe_size 53 | 54 | def forward(self, position): 55 | if self.pe_size == 0: 56 | return torch.zeros(len(position), 0).to(position.device) 57 | 58 | return F.one_hot(position.clamp(max=self.pe_size - 1).long().squeeze(-1), 59 | num_classes=self.pe_size) 60 | 61 | class MultiEmbedding(nn.Module): 62 | def __init__(self, num_x_n_cat, hidden_size): 63 | super().__init__() 64 | 65 | self.emb_list = nn.ModuleList([ 66 | nn.Embedding(num_x_n_cat_i, hidden_size) 67 | for num_x_n_cat_i in num_x_n_cat.tolist() 68 | ]) 69 | 70 | def forward(self, x_n_cat): 71 | if len(x_n_cat.shape) == 1: 72 | x_n_emb = self.emb_list[0](x_n_cat) 73 | else: 74 | x_n_emb = torch.cat([ 75 | self.emb_list[i](x_n_cat[:, i]) for i in range(len(self.emb_list)) 76 | ], dim=1) 77 | 78 | return x_n_emb 79 | 80 | class BiMPNNEncoder(nn.Module): 81 | def __init__(self, 82 | num_x_n_cat, 83 | x_n_emb_size, 84 | pe_emb_size, 85 | hidden_size, 86 | num_mpnn_layers, 87 | pe=None, 88 | y_emb_size=0, 89 | pool=None): 90 | super().__init__() 91 | 92 | self.pe = pe 93 | if self.pe in ['relative_level', 'abs_level']: 94 | self.level_emb = SinusoidalPE(pe_emb_size) 95 | elif self.pe == 'relative_level_one_hot': 96 | self.level_emb = OneHotPE(pe_emb_size) 97 | 98 | self.x_n_emb = MultiEmbedding(num_x_n_cat, x_n_emb_size) 99 | self.y_emb = 
SinusoidalPE(y_emb_size) 100 | 101 | self.proj_input = nn.Sequential( 102 | nn.Linear(hidden_size, hidden_size), 103 | nn.GELU(), 104 | nn.Linear(hidden_size, hidden_size) 105 | ) 106 | 107 | self.mpnn_layers = nn.ModuleList() 108 | for _ in range(num_mpnn_layers): 109 | self.mpnn_layers.append(BiMPNNLayer(hidden_size, hidden_size)) 110 | 111 | self.project_output_n = nn.Sequential( 112 | nn.Linear((num_mpnn_layers + 1) * hidden_size, hidden_size), 113 | nn.GELU(), 114 | nn.Linear(hidden_size, hidden_size) 115 | ) 116 | 117 | self.pool = pool 118 | if pool is not None: 119 | self.bn_g = nn.BatchNorm1d(hidden_size) 120 | 121 | def forward(self, A, x_n, abs_level, rel_level, y=None, A_n2g=None): 122 | A_T = A.T 123 | h_n = self.x_n_emb(x_n) 124 | 125 | if self.pe == 'abs_level': 126 | node_pe = self.level_emb(abs_level) 127 | 128 | if self.pe in ['relative_level', 'relative_level_one_hot']: 129 | node_pe = self.level_emb(rel_level) 130 | 131 | if self.pe is not None: 132 | h_n = torch.cat([h_n, node_pe], dim=-1) 133 | 134 | if y is not None: 135 | h_y = self.y_emb(y) 136 | h_n = torch.cat([h_n, h_y], dim=-1) 137 | 138 | h_n = self.proj_input(h_n) 139 | h_n_cat = [h_n] 140 | for layer in self.mpnn_layers: 141 | h_n = layer(A, A_T, h_n) 142 | h_n_cat.append(h_n) 143 | h_n = torch.cat(h_n_cat, dim=-1) 144 | h_n = self.project_output_n(h_n) 145 | 146 | if self.pool is None: 147 | return h_n 148 | elif self.pool == 'sum': 149 | h_g = A_n2g @ h_n 150 | return self.bn_g(h_g) 151 | elif self.pool == 'mean': 152 | h_g = A_n2g @ h_n 153 | h_g = h_g / A_n2g.sum(dim=1).unsqueeze(-1) 154 | return self.bn_g(h_g) 155 | 156 | class GraphClassifier(nn.Module): 157 | def __init__(self, 158 | graph_encoder, 159 | emb_size, 160 | num_classes): 161 | super().__init__() 162 | 163 | self.graph_encoder = graph_encoder 164 | self.predictor = nn.Sequential( 165 | nn.Linear(emb_size, emb_size), 166 | nn.GELU(), 167 | nn.Linear(emb_size, num_classes) 168 | ) 169 | 170 | def forward(self, A, x_n, abs_level, rel_level, A_n2g, y=None): 171 | h_g = self.graph_encoder(A, x_n, abs_level, rel_level, y, A_n2g) 172 | pred_g = self.predictor(h_g) 173 | 174 | return pred_g 175 | 176 | class TransformerLayer(nn.Module): 177 | def __init__(self, 178 | hidden_size, 179 | num_heads, 180 | dropout): 181 | super().__init__() 182 | 183 | self.to_v = nn.Linear(hidden_size, hidden_size) 184 | self.to_qk = nn.Linear(hidden_size, hidden_size * 2) 185 | 186 | self._reset_parameters() 187 | 188 | self.num_heads = num_heads 189 | head_dim = hidden_size // num_heads 190 | assert head_dim * num_heads == hidden_size, "hidden_size must be divisible by num_heads" 191 | self.scale = head_dim ** -0.5 192 | 193 | self.proj_new = nn.Sequential( 194 | nn.Linear(hidden_size, hidden_size), 195 | nn.Dropout(dropout) 196 | ) 197 | self.norm1 = nn.LayerNorm(hidden_size) 198 | 199 | self.out_proj = nn.Sequential( 200 | nn.Linear(hidden_size, 4 * hidden_size), 201 | nn.GELU(), 202 | nn.Linear(4 * hidden_size, hidden_size), 203 | nn.Dropout(dropout) 204 | ) 205 | self.norm2 = nn.LayerNorm(hidden_size) 206 | 207 | def _reset_parameters(self): 208 | nn.init.xavier_uniform_(self.to_v.weight) 209 | nn.init.xavier_uniform_(self.to_qk.weight) 210 | 211 | def attn(self, q, k, v, num_query_cumsum): 212 | """ 213 | Parameters 214 | ---------- 215 | q : torch.Tensor of shape (N, F) 216 | Query matrix for node representations. 217 | k : torch.Tensor of shape (N, F) 218 | Key matrix for node representations. 
219 | v : torch.Tensor of shape (N, F) 220 | Value matrix for node representations. 221 | num_query_cumsum : torch.Tensor of shape (B + 1) 222 | num_query_cumsum[0] is 0, num_query_cumsum[i] is the number of queries 223 | for the first i graphs in the batch for i > 0. 224 | 225 | Returns 226 | ------- 227 | torch.Tensor of shape (N, F) 228 | Updated hidden representations of query nodes for the batch of graphs. 229 | """ 230 | # Handle different numbers of query nodes in the batch with padding. 231 | batch_size = len(num_query_cumsum) - 1 232 | num_query_nodes = torch.diff(num_query_cumsum) 233 | max_num_nodes = num_query_nodes.max().item() 234 | 235 | q_padded = q.new_zeros(batch_size, max_num_nodes, q.shape[-1]) 236 | k_padded = k.new_zeros(batch_size, max_num_nodes, k.shape[-1]) 237 | v_padded = v.new_zeros(batch_size, max_num_nodes, v.shape[-1]) 238 | pad_mask = q.new_zeros(batch_size, max_num_nodes).bool() 239 | 240 | for i in range(batch_size): 241 | q_padded[i, :num_query_nodes[i]] = q[num_query_cumsum[i]:num_query_cumsum[i + 1]] 242 | k_padded[i, :num_query_nodes[i]] = k[num_query_cumsum[i]:num_query_cumsum[i + 1]] 243 | v_padded[i, :num_query_nodes[i]] = v[num_query_cumsum[i]:num_query_cumsum[i + 1]] 244 | pad_mask[i, num_query_nodes[i]:] = True 245 | 246 | # Split F into H * D, where H is the number of heads 247 | # D is the dimension per head. 248 | 249 | # (B, H, max_num_nodes, D) 250 | q_padded = rearrange(q_padded, 'b n (h d) -> b h n d', h=self.num_heads) 251 | # (B, H, max_num_nodes, D) 252 | k_padded = rearrange(k_padded, 'b n (h d) -> b h n d', h=self.num_heads) 253 | # (B, H, max_num_nodes, D) 254 | v_padded = rearrange(v_padded, 'b n (h d) -> b h n d', h=self.num_heads) 255 | 256 | # Q * K^T / sqrt(D) 257 | # (B, H, max_num_nodes, max_num_nodes) 258 | dot = torch.matmul(q_padded, k_padded.transpose(-1, -2)) * self.scale 259 | # Mask unnormalized attention logits for non-existent nodes. 260 | dot = dot.masked_fill( 261 | pad_mask.unsqueeze(1).unsqueeze(2), 262 | float('-inf'), 263 | ) 264 | 265 | attn_scores = F.softmax(dot, dim=-1) 266 | # (B, H, max_num_nodes, D) 267 | h_n_padded = torch.matmul(attn_scores, v_padded) 268 | # (B * max_num_nodes, H * D) = (B * max_num_nodes, F) 269 | h_n_padded = rearrange(h_n_padded, 'b h n d -> (b n) (h d)') 270 | 271 | # Unpad the aggregation results. 
272 | # (N, F) 273 | pad_mask = (~pad_mask).reshape(-1) 274 | return h_n_padded[pad_mask] 275 | 276 | def forward(self, h_n, num_query_cumsum): 277 | # Compute value matrix 278 | v_n = self.to_v(h_n) 279 | 280 | # Compute query and key matrices 281 | q_n, k_n = self.to_qk(h_n).chunk(2, dim=-1) 282 | 283 | h_n_new = self.attn(q_n, k_n, v_n, num_query_cumsum) 284 | h_n_new = self.proj_new(h_n_new) 285 | 286 | # Add & Norm 287 | h_n = self.norm1(h_n + h_n_new) 288 | h_n = self.norm2(h_n + self.out_proj(h_n)) 289 | 290 | return h_n 291 | 292 | class NodePredModel(nn.Module): 293 | def __init__(self, 294 | graph_encoder, 295 | num_x_n_cat, 296 | x_n_emb_size, 297 | t_emb_size, 298 | in_hidden_size, 299 | out_hidden_size, 300 | num_transformer_layers, 301 | num_heads, 302 | dropout): 303 | super().__init__() 304 | 305 | self.graph_encoder = graph_encoder 306 | num_real_classes = num_x_n_cat - 1 307 | self.x_n_emb = MultiEmbedding(num_real_classes, x_n_emb_size) 308 | self.t_emb = SinusoidalPE(t_emb_size) 309 | in_hidden_size = in_hidden_size + t_emb_size + len(num_real_classes) * x_n_emb_size 310 | self.project_h_n = nn.Sequential( 311 | nn.Linear(in_hidden_size, out_hidden_size), 312 | nn.GELU() 313 | ) 314 | 315 | self.trans_layers = nn.ModuleList() 316 | for _ in range(num_transformer_layers): 317 | self.trans_layers.append(TransformerLayer( 318 | out_hidden_size, num_heads, dropout 319 | )) 320 | 321 | self.pred_list = nn.ModuleList([]) 322 | num_real_classes = num_real_classes.tolist() 323 | for num_classes_f in num_real_classes: 324 | self.pred_list.append(nn.Sequential( 325 | nn.Linear(out_hidden_size, out_hidden_size), 326 | nn.GELU(), 327 | nn.Linear(out_hidden_size, num_classes_f) 328 | )) 329 | 330 | def forward_with_h_g(self, h_g, x_n_t, 331 | t, query2g, num_query_cumsum): 332 | h_t = self.t_emb(t) 333 | h_g = torch.cat([h_g, h_t], dim=1) 334 | 335 | h_n_t = self.x_n_emb(x_n_t) 336 | h_n_t = torch.cat([h_n_t, h_g[query2g]], dim=1) 337 | h_n_t = self.project_h_n(h_n_t) 338 | 339 | for trans_layer in self.trans_layers: 340 | h_n_t = trans_layer(h_n_t, num_query_cumsum) 341 | 342 | pred = [] 343 | for d in range(len(self.pred_list)): 344 | pred.append(self.pred_list[d](h_n_t)) 345 | 346 | return pred 347 | 348 | def forward(self, A, x_n, abs_level, rel_level, A_n2g, x_n_t, 349 | t, query2g, num_query_cumsum, y=None): 350 | """ 351 | Parameters 352 | ---------- 353 | x_n_t : torch.LongTensor of shape (Q) 354 | t : torch.LongTensor of shape (B, 1) 355 | query2g : torch.LongTensor of shape (Q) 356 | num_query_cumsum : torch.LongTensor of shape (B + 1) 357 | """ 358 | h_g = self.graph_encoder(A, x_n, abs_level, 359 | rel_level, y=y, A_n2g=A_n2g) 360 | return self.forward_with_h_g(h_g, x_n_t, t, query2g, 361 | num_query_cumsum) 362 | 363 | class EdgePredModel(nn.Module): 364 | def __init__(self, 365 | graph_encoder, 366 | t_emb_size, 367 | in_hidden_size, 368 | out_hidden_size): 369 | super().__init__() 370 | 371 | self.graph_encoder = graph_encoder 372 | self.t_emb = SinusoidalPE(t_emb_size) 373 | self.pred = nn.Sequential( 374 | nn.Linear(2 * in_hidden_size + t_emb_size, out_hidden_size), 375 | nn.GELU(), 376 | nn.Linear(out_hidden_size, 2) 377 | ) 378 | 379 | def forward(self, A, x_n, abs_level, rel_level, t, 380 | query_src, query_dst, y=None): 381 | """ 382 | t : torch.tensor of shape (num_queries, 1) 383 | """ 384 | h_n = self.graph_encoder(A, x_n, abs_level, rel_level, y=y) 385 | 386 | h_e = torch.cat([ 387 | self.t_emb(t), 388 | h_n[query_src], 389 | h_n[query_dst] 390 | ], dim=-1) 
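# h_e concatenates the timestep embedding with the encoder states of the
# candidate source and destination nodes; self.pred maps it to two logits
# (edge absent / edge present).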
391 | 392 | return self.pred(h_e) 393 | 394 | class LayerDAG(nn.Module): 395 | def __init__(self, 396 | device, 397 | num_x_n_cat, 398 | node_count_encoder_config, 399 | max_layer_size, 400 | node_diffusion, 401 | node_pred_graph_encoder_config, 402 | node_predictor_config, 403 | edge_diffusion, 404 | edge_pred_graph_encoder_config, 405 | edge_predictor_config, 406 | max_level=None): 407 | """ 408 | Parameters 409 | ---------- 410 | num_x_n_cat : 411 | Case1: int 412 | Case2: torch.LongTensor of shape (num_feats) 413 | """ 414 | super().__init__() 415 | 416 | if isinstance(num_x_n_cat, int): 417 | num_x_n_cat = torch.LongTensor([num_x_n_cat]) 418 | 419 | self.dummy_x_n = num_x_n_cat - 1 420 | hidden_size = len(num_x_n_cat) * node_count_encoder_config['x_n_emb_size'] +\ 421 | node_count_encoder_config['pe_emb_size'] +\ 422 | node_count_encoder_config['y_emb_size'] 423 | node_count_encoder = BiMPNNEncoder(num_x_n_cat, 424 | hidden_size=hidden_size, 425 | **node_count_encoder_config).to(device) 426 | self.node_count_model = GraphClassifier( 427 | node_count_encoder, 428 | emb_size=hidden_size, 429 | num_classes=max_layer_size+1).to(device) 430 | 431 | self.node_diffusion = node_diffusion 432 | hidden_size = len(num_x_n_cat) * node_pred_graph_encoder_config['x_n_emb_size'] +\ 433 | node_pred_graph_encoder_config['pe_emb_size'] +\ 434 | node_pred_graph_encoder_config['y_emb_size'] 435 | node_pred_graph_encoder = BiMPNNEncoder(num_x_n_cat, hidden_size=hidden_size, 436 | **node_pred_graph_encoder_config).to(device) 437 | self.node_pred_model = NodePredModel(node_pred_graph_encoder, 438 | num_x_n_cat, 439 | node_pred_graph_encoder_config['x_n_emb_size'], 440 | in_hidden_size=hidden_size, 441 | **node_predictor_config).to(device) 442 | 443 | self.edge_diffusion = edge_diffusion 444 | hidden_size = len(num_x_n_cat) * edge_pred_graph_encoder_config['x_n_emb_size'] +\ 445 | edge_pred_graph_encoder_config['pe_emb_size'] +\ 446 | edge_pred_graph_encoder_config['y_emb_size'] 447 | edge_pred_graph_encoder = BiMPNNEncoder(num_x_n_cat, hidden_size=hidden_size, 448 | **edge_pred_graph_encoder_config).to(device) 449 | self.edge_pred_model = EdgePredModel(edge_pred_graph_encoder, 450 | in_hidden_size=hidden_size, 451 | **edge_predictor_config).to(device) 452 | 453 | self.max_level = max_level 454 | 455 | def posterior(self, Z_t, Q_t, Q_bar_s, Q_bar_t, Z_0): 456 | # (num_rows, num_classes) 457 | left_term = Z_t @ torch.transpose(Q_t, -1, -2) 458 | # (num_rows, 1, num_classes) 459 | left_term = left_term.unsqueeze(dim=-2) 460 | # (1, num_classes, num_classes) 461 | right_term = Q_bar_s.unsqueeze(dim=-3) 462 | # (num_rows, num_classes, num_classes) 463 | numerator = left_term * right_term 464 | 465 | # (num_classes, num_rows) 466 | prod = Q_bar_t @ torch.transpose(Z_t, -1, -2) 467 | # (num_rows, num_classes) 468 | prod = torch.transpose(prod, -1, -2) 469 | # (num_rows, num_classes, 1) 470 | denominator = prod.unsqueeze(-1) 471 | denominator[denominator == 0.] = 1. 
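# Zero entries are replaced by 1 to avoid division by zero in the next step.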
472 | # (num_rows, num_classes, num_classes) 473 | out = numerator / denominator 474 | 475 | # (num_rows, num_classes, num_classes) 476 | prob = Z_0.unsqueeze(-1) * out 477 | # (num_rows, num_classes) 478 | prob = prob.sum(dim=-2) 479 | 480 | return prob 481 | 482 | def posterior_edge(self, 483 | Z_t, 484 | alpha_t, 485 | alpha_bar_s, 486 | alpha_bar_t, 487 | Z_0, 488 | marginal_list, 489 | num_new_nodes_list, 490 | num_query_list): 491 | batch_size = len(num_new_nodes_list) 492 | Z_t_list = torch.split(Z_t, num_query_list, dim=0) 493 | Z_0_list = torch.split(Z_0, num_query_list, dim=0) 494 | device = Z_t.device 495 | e_mask_list = [] 496 | 497 | for i in range(batch_size): 498 | Z_t_i = Z_t_list[i] 499 | Z_0_i = Z_0_list[i] 500 | 501 | Q_t_i, Q_bar_s_i, Q_bar_t_i = self.edge_diffusion.get_Qs( 502 | alpha_t, alpha_bar_s, alpha_bar_t, marginal_list[i]) 503 | Q_t_i = Q_t_i.to(device) 504 | Q_bar_s_i = Q_bar_s_i.to(device) 505 | Q_bar_t_i = Q_bar_t_i.to(device) 506 | 507 | # (num_rows, num_classes) 508 | left_term_i = Z_t_i @ torch.transpose(Q_t_i, -1, -2) 509 | # (num_rows, 1, num_classes) 510 | left_term_i = left_term_i.unsqueeze(dim=-2) 511 | # (1, num_classes, num_classes) 512 | right_term_i = Q_bar_s_i.unsqueeze(dim=-3) 513 | # (num_rows, num_classes, num_classes) 514 | numerator_i = left_term_i * right_term_i 515 | 516 | # (num_classes, num_rows) 517 | prod_i = Q_bar_t_i @ torch.transpose(Z_t_i, -1, -2) 518 | # (num_rows, num_classes) 519 | prod_i = torch.transpose(prod_i, -1, -2) 520 | # (num_rows, num_classes, 1) 521 | denominator_i = prod_i.unsqueeze(-1) 522 | denominator_i[denominator_i == 0.] = 1. 523 | # (num_rows, num_classes, num_classes) 524 | out_i = numerator_i / denominator_i 525 | 526 | # (num_rows, num_classes, num_classes) 527 | prob_i = Z_0_i.unsqueeze(-1) * out_i 528 | # (num_rows, num_classes) 529 | prob_i = prob_i.sum(dim=-2) 530 | prob_i = prob_i / (prob_i.sum(dim=-1, keepdim=True) + 1e-6) 531 | 532 | # Get the probabilities for edge existence. 533 | prob_i = prob_i[:, 1] 534 | prob_i = prob_i.reshape(num_new_nodes_list[i], -1) 535 | e_mask_i = torch.bernoulli(prob_i) 536 | 537 | isolated_mask_i = (e_mask_i.sum(dim=1) == 0).bool() 538 | if isolated_mask_i.any(): 539 | e_mask_i[isolated_mask_i, prob_i[isolated_mask_i].argmax(dim=1)] = 1 540 | e_mask_list.append(e_mask_i.reshape(-1)) 541 | 542 | return torch.cat(e_mask_list).bool() 543 | 544 | @torch.no_grad() 545 | def sample_node_layer(self, 546 | A, 547 | x_n, 548 | abs_level, 549 | rel_level, 550 | A_n2g, 551 | curr_level=None, 552 | y=None, 553 | min_num_steps_n=None, 554 | max_num_steps_n=None): 555 | device = A.device 556 | 557 | node_count_logits = self.node_count_model(A, x_n, abs_level, 558 | rel_level, A_n2g=A_n2g, y=y) 559 | 560 | # For the first layer, the layer size must be nonzero. 
561 | if curr_level == 0: 562 | node_count_logits[:, 0] = float('-inf') 563 | 564 | node_count_probs = node_count_logits.softmax(dim=-1) 565 | num_new_nodes = node_count_probs.multinomial(1) 566 | 567 | num_new_nodes_total = num_new_nodes.sum().item() 568 | batch_size = num_new_nodes.shape[0] 569 | if num_new_nodes_total == 0: 570 | return [torch.LongTensor([]).to(device) 571 | for _ in range(batch_size)] 572 | 573 | num_classes_list = self.node_diffusion.num_classes_list 574 | marginal_list = self.node_diffusion.m_list 575 | D = len(num_classes_list) 576 | 577 | x_n_t = [] 578 | for d in range(D): 579 | marginal_d = marginal_list[d] 580 | prior_d = marginal_d[0][None, :].expand(num_new_nodes_total, -1) 581 | # (num_new_nodes_total) 582 | x_n_t_d = prior_d.multinomial(1).squeeze(-1) 583 | x_n_t.append(x_n_t_d) 584 | x_n_t = torch.stack(x_n_t, dim=1).to(device) 585 | 586 | # Iteratively sample p(D^s | D^t) for t = 1, ..., T, with s = t - 1. 587 | h_g = self.node_pred_model.graph_encoder(A, x_n, abs_level, rel_level, 588 | y=y, A_n2g=A_n2g) 589 | 590 | num_query_cumsum = torch.cumsum(torch.tensor( 591 | [0] + num_new_nodes.squeeze(-1).tolist()), dim=0).long().to(device) 592 | query2g = [] 593 | for i in range(batch_size): 594 | query2g.append(torch.ones(num_query_cumsum[i+1] - num_query_cumsum[i]).fill_(i).long()) 595 | query2g = torch.cat(query2g).to(device) 596 | 597 | T_x_n = self.node_diffusion.T 598 | if max_num_steps_n is not None: 599 | T_x_n = min(T_x_n, max_num_steps_n) 600 | 601 | time_x_n_list = list(reversed(range(0, T_x_n))) 602 | if min_num_steps_n is not None: 603 | num_steps_n = min_num_steps_n + int( 604 | (T_x_n - min_num_steps_n) * (curr_level / self.max_level) 605 | ) 606 | time_x_n_list = time_x_n_list[-num_steps_n:] 607 | 608 | for s_x_n in time_x_n_list: 609 | t_x_n = s_x_n + 1 610 | 611 | # Note that computing Q_bar_t from alpha_bar_t is the same 612 | # as computing Q_t from alpha_t. 
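# Both take the form alpha * I + (1 - alpha) * M, which is what
# self.node_diffusion.get_Q computes from either alpha_t or alpha_bar_t.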
613 | alpha_t = self.node_diffusion.alphas[t_x_n] 614 | alpha_bar_s = self.node_diffusion.alpha_bars[s_x_n] 615 | alpha_bar_t = self.node_diffusion.alpha_bars[t_x_n] 616 | 617 | t_x_n_tensor = torch.LongTensor([[t_x_n]]).expand(batch_size, -1).to(device) 618 | x_n_0_logits = self.node_pred_model.forward_with_h_g( 619 | h_g, x_n_t, t_x_n_tensor, query2g, 620 | num_query_cumsum) 621 | 622 | x_n_s = [] 623 | for d in range(D): 624 | Q_t_d = self.node_diffusion.get_Q(alpha_t, d).to(device) 625 | Q_bar_s_d = self.node_diffusion.get_Q(alpha_bar_s, d).to(device) 626 | Q_bar_t_d = self.node_diffusion.get_Q(alpha_bar_t, d).to(device) 627 | 628 | x_n_0_probs_d = x_n_0_logits[d].softmax(dim=-1) 629 | # (num_new_nodes, num_classes) 630 | x_n_t_one_hot_d = F.one_hot(x_n_t[:, d], num_classes=num_classes_list[d]).float() 631 | 632 | x_n_s_probs_d = self.posterior(x_n_t_one_hot_d, Q_t_d, Q_bar_s_d, 633 | Q_bar_t_d, x_n_0_probs_d) 634 | x_n_s_d = x_n_s_probs_d.multinomial(1).squeeze(-1) 635 | x_n_s.append(x_n_s_d) 636 | 637 | x_n_t = torch.stack(x_n_s, dim=1) 638 | 639 | return torch.split(x_n_t, num_new_nodes.squeeze(-1).tolist()) 640 | 641 | @torch.no_grad() 642 | def sample_edge_layer(self, num_nodes_cumsum, edge_index_list, 643 | batch_x_n, batch_abs_level, batch_rel_level, 644 | num_new_nodes_list, batch_query_src, batch_query_dst, 645 | query_src_list, query_dst_list, 646 | batch_y=None, 647 | curr_level=None, 648 | min_num_steps_e=None, 649 | max_num_steps_e=None): 650 | device = batch_x_n.device 651 | 652 | e_t_mask_list = [] 653 | batch_size = len(num_new_nodes_list) 654 | marginal_list = [] 655 | num_query_list = [] 656 | for i in range(batch_size): 657 | num_query_i = len(query_src_list[i]) 658 | num_query_list.append(num_query_i) 659 | 660 | num_new_nodes_i = num_new_nodes_list[i] 661 | prior_i = torch.ones(num_query_i).reshape(num_new_nodes_i, -1) 662 | mean_in_deg_i = min(self.edge_diffusion.avg_in_deg, prior_i.shape[1]) 663 | marginal_i = mean_in_deg_i / prior_i.shape[1] 664 | marginal_list.append(marginal_i) 665 | prior_i = prior_i * marginal_i 666 | e_t_mask_i = torch.bernoulli(prior_i) 667 | isolated_mask = (e_t_mask_i.sum(dim=1) == 0).bool() 668 | if isolated_mask.any(): 669 | e_t_mask_i[isolated_mask, torch.zeros(int(isolated_mask.sum().item())).long()] = 1 670 | e_t_mask_list.append(e_t_mask_i.reshape(-1)) 671 | 672 | e_t_mask = torch.cat(e_t_mask_list).bool().to(device) 673 | 674 | num_nodes = len(batch_x_n) 675 | num_queries = len(batch_query_src) 676 | 677 | batch_edge_index = self.get_batch_A( 678 | num_nodes_cumsum, edge_index_list, device, 679 | return_edge_index=True) 680 | 681 | # Iteratively sample p(D^s | D^t) for t = 1, ..., T, with s = t - 1. 682 | T_x_e = self.edge_diffusion.T 683 | if max_num_steps_e is not None: 684 | T_x_e = min(T_x_e, max_num_steps_e) 685 | 686 | time_x_e_list = list(reversed(range(0, T_x_e))) 687 | if min_num_steps_e is not None: 688 | num_steps_e = min_num_steps_e + int( 689 | (T_x_e - min_num_steps_e) * (curr_level / self.max_level) 690 | ) 691 | time_x_e_list = time_x_e_list[-num_steps_e:] 692 | 693 | for s_x_e in time_x_e_list: 694 | t_x_e = s_x_e + 1 695 | 696 | # Note that computing Q_bar_t from alpha_bar_t is the same 697 | # as computing Q_t from alpha_t. 
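# The 2x2 transition matrices for the binary edge variable are built from
# these alphas and each graph's edge marginal inside posterior_edge, via
# self.edge_diffusion.get_Qs.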
698 | alpha_t = self.edge_diffusion.alphas[t_x_e] 699 | alpha_bar_s = self.edge_diffusion.alpha_bars[s_x_e] 700 | alpha_bar_t = self.edge_diffusion.alpha_bars[t_x_e] 701 | 702 | edge_index_t = torch.stack([ 703 | batch_query_dst[e_t_mask], 704 | batch_query_src[e_t_mask] 705 | ]).to(device) 706 | 707 | A = dglsp.spmatrix( 708 | torch.cat([batch_edge_index, edge_index_t], dim=1), 709 | shape=(num_nodes, num_nodes)).to(device) 710 | t_x_e_tensor = torch.LongTensor([t_x_e])[None, :].expand( 711 | num_queries, -1).to(device) 712 | e_0_logits = self.edge_pred_model( 713 | A, batch_x_n, batch_abs_level, batch_rel_level, t_x_e_tensor, 714 | batch_query_src, batch_query_dst, batch_y) 715 | e_0_probs = e_0_logits.softmax(dim=-1) 716 | # (num_queries, num_classes) 717 | e_t_one_hot = F.one_hot(e_t_mask.long(), num_classes=2).float() 718 | 719 | e_t_mask = self.posterior_edge(e_t_one_hot, 720 | alpha_t, 721 | alpha_bar_s, 722 | alpha_bar_t, 723 | e_0_probs, 724 | marginal_list, 725 | num_new_nodes_list, 726 | num_query_list) 727 | 728 | num_query_split = [len(query_src_i) for query_src_i in query_src_list] 729 | e_t_mask_split = torch.split(e_t_mask, num_query_split) 730 | 731 | edge_index_list_ = [] 732 | for i in range(len(edge_index_list)): 733 | edge_index_i = edge_index_list[i] 734 | e_t_mask_i = e_t_mask_split[i] 735 | edge_index_l_i = torch.stack([ 736 | query_dst_list[i][e_t_mask_i], 737 | query_src_list[i][e_t_mask_i] 738 | ]) 739 | edge_index_i = torch.cat([edge_index_i, edge_index_l_i], dim=1) 740 | edge_index_list_.append(edge_index_i) 741 | edge_index_list = edge_index_list_ 742 | 743 | return edge_index_list 744 | 745 | def get_batch_A(self, num_nodes_cumsum, edge_index_list, device, return_edge_index=False): 746 | batch_size = len(edge_index_list) 747 | edge_index_list_ = [] 748 | for i in range(batch_size): 749 | edge_index_list_.append(edge_index_list[i] + num_nodes_cumsum[i]) 750 | 751 | batch_edge_index = torch.cat(edge_index_list_, dim=1) 752 | 753 | if return_edge_index: 754 | return batch_edge_index 755 | 756 | N = num_nodes_cumsum[-1].item() 757 | batch_A = dglsp.spmatrix(batch_edge_index, shape=(N, N)).to(device) 758 | 759 | return batch_A 760 | 761 | def get_batch_A_n2g(self, num_nodes_cumsum, device): 762 | batch_size = len(num_nodes_cumsum) - 1 763 | nids = [] 764 | gids = [] 765 | for i in range(batch_size): 766 | nids.append(torch.arange(num_nodes_cumsum[i], num_nodes_cumsum[i+1]).long()) 767 | gids.append(torch.ones(num_nodes_cumsum[i+1] - num_nodes_cumsum[i]).fill_(i).long()) 768 | 769 | nids = torch.cat(nids, dim=0) 770 | gids = torch.cat(gids, dim=0) 771 | n2g_index = torch.stack([gids, nids]) 772 | 773 | N = num_nodes_cumsum[-1].item() 774 | batch_A_n2g = dglsp.spmatrix(n2g_index, shape=(batch_size, N)).to(device) 775 | 776 | return batch_A_n2g 777 | 778 | def get_batch_y(self, y_list, x_n_list, device): 779 | if y_list is None: 780 | return None 781 | 782 | y_list_ = [] 783 | for i in range(len(x_n_list)): 784 | y_list_.append(torch.zeros(len(x_n_list[i]), 1).fill_(y_list[i])) 785 | batch_y = torch.cat(y_list_).to(device) 786 | 787 | return batch_y 788 | 789 | @torch.no_grad() 790 | def sample(self, 791 | device, 792 | batch_size=1, 793 | y=None, 794 | min_num_steps_n=None, 795 | max_num_steps_n=None, 796 | min_num_steps_e=None, 797 | max_num_steps_e=None): 798 | if y is not None: 799 | assert batch_size == len(y) 800 | y_list = y 801 | 802 | edge_index_list = [ 803 | torch.LongTensor([[], []]).to(device) 804 | for _ in range(batch_size) 805 | ] 806 | 807 | if 
isinstance(self.dummy_x_n, int): 808 | init_x_n = torch.LongTensor([[self.dummy_x_n]]).to(device) 809 | elif isinstance(self.dummy_x_n, torch.Tensor): 810 | init_x_n = self.dummy_x_n.to(device).unsqueeze(0) 811 | else: 812 | raise NotImplementedError 813 | x_n_list = [init_x_n for _ in range(batch_size)] 814 | batch_x_n = torch.cat(x_n_list) 815 | batch_y = self.get_batch_y(y_list, x_n_list, device) 816 | 817 | level = 0. 818 | abs_level_list = [ 819 | torch.tensor([[level]]).to(device) 820 | for _ in range(batch_size) 821 | ] 822 | batch_abs_level = torch.cat(abs_level_list) 823 | batch_rel_level = batch_abs_level.max() - batch_abs_level 824 | 825 | edge_index_finished = [] 826 | x_n_finished = [] 827 | if y is not None: 828 | y_finished = [] 829 | 830 | num_nodes_cumsum = torch.cumsum(torch.tensor( 831 | [0] + [len(x_n_i) for x_n_i in x_n_list]), dim=0) 832 | while True: 833 | batch_A = self.get_batch_A(num_nodes_cumsum, edge_index_list, device) 834 | batch_A_n2g = self.get_batch_A_n2g(num_nodes_cumsum, device) 835 | x_n_l_list = self.sample_node_layer( 836 | batch_A, batch_x_n, batch_abs_level, batch_rel_level, 837 | batch_A_n2g, curr_level=level, 838 | y=batch_y, 839 | min_num_steps_n=min_num_steps_n, 840 | max_num_steps_n=max_num_steps_n) 841 | 842 | edge_index_list_ = [] 843 | x_n_list_ = [] 844 | abs_level_list_ = [] 845 | query_src_list = [] 846 | query_dst_list = [] 847 | num_new_nodes_list = [] 848 | batch_query_src = [] 849 | batch_query_dst = [] 850 | 851 | if y is not None: 852 | y_list_ = [] 853 | else: 854 | y_list_ = None 855 | 856 | level += 1 857 | node_count = 0 858 | for i, x_n_l_i in enumerate(x_n_l_list): 859 | if len(x_n_l_i) == 0: 860 | edge_index_finished.append(edge_index_list[i] - 1) 861 | x_n_finished.append(x_n_list[i][1:]) 862 | if y is not None: 863 | y_finished.append(y_list[i]) 864 | else: 865 | edge_index_list_.append(edge_index_list[i]) 866 | x_n_list_.append(torch.cat([x_n_list[i], x_n_l_i])) 867 | if y is not None: 868 | y_list_.append(y_list[i]) 869 | abs_level_list_.append( 870 | torch.cat([ 871 | abs_level_list[i], 872 | torch.zeros(len(x_n_l_i), 1).fill_(level).to(device) 873 | ]) 874 | ) 875 | 876 | N_old_i = len(x_n_list[i]) 877 | N_new_i = len(x_n_l_i) 878 | 879 | query_src_i = [] 880 | query_dst_i = [] 881 | 882 | src_candidates_i = list(range(1, N_old_i)) 883 | for dst_i in range(N_old_i, N_old_i + N_new_i): 884 | query_src_i.extend(src_candidates_i) 885 | query_dst_i.extend([dst_i] * len(src_candidates_i)) 886 | query_src_i = torch.LongTensor(query_src_i).to(device) 887 | query_dst_i = torch.LongTensor(query_dst_i).to(device) 888 | 889 | query_src_list.append(query_src_i) 890 | query_dst_list.append(query_dst_i) 891 | batch_query_src.append(query_src_i + node_count) 892 | batch_query_dst.append(query_dst_i + node_count) 893 | num_new_nodes_list.append(N_new_i) 894 | 895 | node_count = node_count + N_old_i + N_new_i 896 | 897 | edge_index_list = edge_index_list_ 898 | x_n_list = x_n_list_ 899 | y_list = y_list_ 900 | abs_level_list = abs_level_list_ 901 | 902 | if len(edge_index_list) == 0: 903 | break 904 | 905 | num_nodes_cumsum = torch.cumsum(torch.tensor( 906 | [0] + [len(x_n_i) for x_n_i in x_n_list]), dim=0) 907 | batch_x_n = torch.cat(x_n_list) 908 | batch_abs_level = torch.cat(abs_level_list) 909 | batch_rel_level = batch_abs_level.max() - batch_abs_level 910 | batch_y = self.get_batch_y(y_list, x_n_list, device) 911 | 912 | if level == 1: 913 | continue 914 | 915 | batch_query_src = torch.cat(batch_query_src) 916 | batch_query_dst = 
torch.cat(batch_query_dst) 917 | 918 | edge_index_list = self.sample_edge_layer( 919 | num_nodes_cumsum, edge_index_list, batch_x_n, batch_abs_level, 920 | batch_rel_level, num_new_nodes_list, batch_query_src, 921 | batch_query_dst, query_src_list, query_dst_list, batch_y, 922 | curr_level=level, 923 | min_num_steps_e=min_num_steps_e, 924 | max_num_steps_e=max_num_steps_e) 925 | 926 | if self.max_level is not None and level == self.max_level: 927 | break 928 | 929 | for i in range(len(edge_index_list)): 930 | edge_index_finished.append(edge_index_list[i] - 1) 931 | x_n_finished.append(x_n_list[i][1:]) 932 | 933 | if y is None: 934 | return edge_index_finished, x_n_finished 935 | else: 936 | y_finished.extend(y_list) 937 | return edge_index_finished, x_n_finished, y_finished 938 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import dgl.sparse as dglsp 2 | import pandas as pd 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import wandb 7 | 8 | from copy import deepcopy 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from setup_utils import set_seed, load_yaml 13 | from src.dataset import load_dataset, LayerDAGNodeCountDataset,\ 14 | LayerDAGNodePredDataset, LayerDAGEdgePredDataset, collate_node_count,\ 15 | collate_node_pred, collate_edge_pred 16 | from src.model import DiscreteDiffusion, EdgeDiscreteDiffusion, LayerDAG 17 | 18 | @torch.no_grad() 19 | def eval_node_count(device, val_loader, model): 20 | model.eval() 21 | total_nll = 0 22 | total_count = 0 23 | true_count = 0 24 | for batch_data in tqdm(val_loader): 25 | if len(batch_data) == 8: 26 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 27 | batch_rel_level, batch_y, batch_n2g_index, batch_label = batch_data 28 | batch_y = batch_y.to(device) 29 | else: 30 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 31 | batch_rel_level, batch_n2g_index, batch_label = batch_data 32 | batch_y = None 33 | 34 | num_nodes = len(batch_x_n) 35 | batch_A = dglsp.spmatrix( 36 | batch_edge_index, shape=(num_nodes, num_nodes)).to(device) 37 | batch_x_n = batch_x_n.to(device) 38 | batch_abs_level = batch_abs_level.to(device) 39 | batch_rel_level = batch_rel_level.to(device) 40 | batch_A_n2g = dglsp.spmatrix( 41 | batch_n2g_index, shape=(batch_size, num_nodes)).to(device) 42 | batch_label = batch_label.to(device) 43 | 44 | batch_logits = model(batch_A, batch_x_n, batch_abs_level, 45 | batch_rel_level, batch_A_n2g, batch_y) 46 | 47 | batch_nll = -batch_logits.log_softmax(dim=-1) 48 | # In case the max layer size in the validation set is larger than 49 | # that in the training set. 
50 | batch_label = batch_label.clamp(max=batch_nll.shape[-1] - 1) 51 | batch_nll = batch_nll[torch.arange(batch_size).to(device), batch_label] 52 | total_nll += batch_nll.sum().item() 53 | 54 | batch_probs = batch_logits.softmax(dim=-1) 55 | batch_preds = batch_probs.multinomial(1).squeeze(-1) 56 | true_count += (batch_preds == batch_label).sum().item() 57 | 58 | total_count += batch_size 59 | 60 | return total_nll / total_count, true_count / total_count 61 | 62 | def main_node_count(device, train_set, val_set, model, config, patience): 63 | train_loader = DataLoader(train_set, 64 | shuffle=True, 65 | collate_fn=collate_node_count, 66 | **config['loader'], 67 | drop_last=True) 68 | val_loader = DataLoader(val_set, 69 | shuffle=False, 70 | collate_fn=collate_node_count, 71 | **config['loader']) 72 | criterion = nn.CrossEntropyLoss() 73 | optimizer = torch.optim.Adam(model.parameters(), **config['optimizer']) 74 | 75 | best_val_nll = float('inf') 76 | best_val_acc = 0 77 | best_state_dict = deepcopy(model.state_dict()) 78 | num_patient_epochs = 0 79 | for epoch in range(config['num_epochs']): 80 | model.train() 81 | for batch_data in tqdm(train_loader): 82 | if len(batch_data) == 8: 83 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 84 | batch_rel_level, batch_y, batch_n2g_index, batch_label = batch_data 85 | batch_y = batch_y.to(device) 86 | else: 87 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 88 | batch_rel_level, batch_n2g_index, batch_label = batch_data 89 | batch_y = None 90 | 91 | num_nodes = len(batch_x_n) 92 | batch_A = dglsp.spmatrix(batch_edge_index, shape=(num_nodes, num_nodes)).to(device) 93 | batch_x_n = batch_x_n.to(device) 94 | batch_abs_level = batch_abs_level.to(device) 95 | batch_rel_level = batch_rel_level.to(device) 96 | batch_A_n2g = dglsp.spmatrix(batch_n2g_index, shape=(batch_size, num_nodes)).to(device) 97 | batch_label = batch_label.to(device) 98 | 99 | batch_pred = model(batch_A, batch_x_n, batch_abs_level, 100 | batch_rel_level, batch_A_n2g, batch_y) 101 | 102 | loss = criterion(batch_pred, batch_label) 103 | optimizer.zero_grad() 104 | loss.backward() 105 | optimizer.step() 106 | 107 | wandb.log({'node_count/loss': loss.item()}) 108 | 109 | val_nll, val_acc = eval_node_count(device, val_loader, model) 110 | if val_nll < best_val_nll: 111 | best_val_nll = val_nll 112 | if val_acc > best_val_acc: 113 | best_val_acc = val_acc 114 | best_state_dict = deepcopy(model.state_dict()) 115 | num_patient_epochs = 0 116 | else: 117 | num_patient_epochs += 1 118 | wandb.log({'node_count/epoch': epoch, 119 | 'node_count/val_nll': val_nll, 120 | 'node_count/best_val_nll': best_val_nll, 121 | 'node_count/val_acc': val_acc, 122 | 'node_count/best_val_acc': best_val_acc, 123 | 'node_count/num_patient_epochs': num_patient_epochs}) 124 | 125 | if (patience is not None) and (num_patient_epochs == patience): 126 | break 127 | 128 | return best_state_dict 129 | 130 | @torch.no_grad() 131 | def eval_node_pred(device, val_loader, model): 132 | model.eval() 133 | total_nll = 0 134 | total_count = 0 135 | for batch_data in tqdm(val_loader): 136 | if len(batch_data) == 11: 137 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 138 | batch_rel_level, batch_n2g_index, batch_z_t, batch_t, query2g,\ 139 | num_query_cumsum, batch_z = batch_data 140 | batch_y = None 141 | else: 142 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 143 | batch_rel_level, batch_n2g_index, batch_z_t, batch_t, batch_y,\ 144 | query2g, num_query_cumsum, batch_z = 
batch_data 145 | batch_y = batch_y.to(device) 146 | 147 | num_nodes = len(batch_x_n) 148 | batch_A = dglsp.spmatrix( 149 | batch_edge_index, shape=(num_nodes, num_nodes)).to(device) 150 | batch_x_n = batch_x_n.to(device) 151 | batch_abs_level = batch_abs_level.to(device) 152 | batch_rel_level = batch_rel_level.to(device) 153 | batch_A_n2g = dglsp.spmatrix( 154 | batch_n2g_index, shape=(batch_size, num_nodes)).to(device) 155 | batch_z_t = batch_z_t.to(device) 156 | batch_t = batch_t.to(device) 157 | query2g = query2g.to(device) 158 | num_query_cumsum = num_query_cumsum.to(device) 159 | batch_z = batch_z.to(device) 160 | 161 | batch_logits = model(batch_A, batch_x_n, batch_abs_level, 162 | batch_rel_level, batch_A_n2g, batch_z_t, batch_t, 163 | query2g, num_query_cumsum, batch_y) 164 | 165 | D = len(batch_logits) 166 | batch_num_queries = batch_logits[0].shape[0] 167 | for d in range(D): 168 | batch_logits_d = batch_logits[d] 169 | batch_nll_d = -batch_logits_d.log_softmax(dim=-1) 170 | batch_nll_d = batch_nll_d[torch.arange(batch_num_queries).to(device), batch_z[:, d]] 171 | total_nll += batch_nll_d.sum().item() 172 | total_count += batch_num_queries * D 173 | 174 | return total_nll / total_count 175 | 176 | def main_node_pred(device, train_set, val_set, model, config, patience): 177 | train_loader = DataLoader(train_set, 178 | shuffle=True, 179 | collate_fn=collate_node_pred, 180 | **config['loader']) 181 | val_loader = DataLoader(val_set, 182 | collate_fn=collate_node_pred, 183 | **config['loader']) 184 | criterion = nn.CrossEntropyLoss() 185 | optimizer = torch.optim.Adam(model.parameters(), **config['optimizer']) 186 | 187 | best_val_nll = float('inf') 188 | best_state_dict = deepcopy(model.state_dict()) 189 | num_patient_epochs = 0 190 | for epoch in range(config['num_epochs']): 191 | val_nll = eval_node_pred(device, val_loader, model) 192 | if val_nll < best_val_nll: 193 | best_val_nll = val_nll 194 | best_state_dict = deepcopy(model.state_dict()) 195 | num_patient_epochs = 0 196 | else: 197 | num_patient_epochs += 1 198 | 199 | wandb.log({'node_pred/epoch': epoch, 200 | 'node_pred/val_nll': val_nll, 201 | 'node_pred/best_val_nll': best_val_nll, 202 | 'node_pred/num_patient_epochs': num_patient_epochs}) 203 | 204 | if (patience is not None) and (num_patient_epochs == patience): 205 | break 206 | 207 | model.train() 208 | for batch_data in tqdm(train_loader): 209 | if len(batch_data) == 11: 210 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 211 | batch_rel_level, batch_n2g_index, batch_z_t, batch_t,\ 212 | query2g, num_query_cumsum, batch_z = batch_data 213 | batch_y = None 214 | else: 215 | batch_size, batch_edge_index, batch_x_n, batch_abs_level,\ 216 | batch_rel_level, batch_n2g_index, batch_z_t, batch_t,\ 217 | batch_y, query2g, num_query_cumsum, batch_z = batch_data 218 | batch_y = batch_y.to(device) 219 | 220 | num_nodes = len(batch_x_n) 221 | batch_A = dglsp.spmatrix( 222 | batch_edge_index, shape=(num_nodes, num_nodes)).to(device) 223 | batch_x_n = batch_x_n.to(device) 224 | batch_abs_level = batch_abs_level.to(device) 225 | batch_rel_level = batch_rel_level.to(device) 226 | batch_A_n2g = dglsp.spmatrix( 227 | batch_n2g_index, shape=(batch_size, num_nodes)).to(device) 228 | batch_z_t = batch_z_t.to(device) 229 | batch_t = batch_t.to(device) 230 | query2g = query2g.to(device) 231 | num_query_cumsum = num_query_cumsum.to(device) 232 | batch_z = batch_z.to(device) 233 | 234 | batch_pred = model(batch_A, batch_x_n, batch_abs_level, 235 | batch_rel_level, batch_A_n2g, 
batch_z_t, 236 | batch_t, query2g, num_query_cumsum, batch_y) 237 | 238 | loss = 0 239 | D = len(batch_pred) 240 | for d in range(D): 241 | loss = loss + criterion(batch_pred[d], batch_z[:, d]) 242 | loss /= D 243 | 244 | optimizer.zero_grad() 245 | loss.backward() 246 | optimizer.step() 247 | 248 | wandb.log({'node_pred/loss': loss.item()}) 249 | 250 | return best_state_dict 251 | 252 | @torch.no_grad() 253 | def eval_edge_pred(device, val_loader, model): 254 | model.eval() 255 | total_nll = 0 256 | total_count = 0 257 | for batch_data in tqdm(val_loader): 258 | if len(batch_data) == 9: 259 | batch_edge_index, batch_noisy_edge_index, batch_x_n,\ 260 | batch_abs_level, batch_rel_level, batch_t, batch_query_src,\ 261 | batch_query_dst, batch_label = batch_data 262 | batch_y = None 263 | else: 264 | batch_edge_index, batch_noisy_edge_index, batch_x_n,\ 265 | batch_abs_level, batch_rel_level, batch_t, batch_y,\ 266 | batch_query_src, batch_query_dst, batch_label = batch_data 267 | batch_y = batch_y.to(device) 268 | 269 | num_nodes = len(batch_x_n) 270 | batch_A = dglsp.spmatrix( 271 | torch.cat([batch_edge_index, batch_noisy_edge_index], dim=1), 272 | shape=(num_nodes, num_nodes)).to(device) 273 | batch_x_n = batch_x_n.to(device) 274 | batch_abs_level = batch_abs_level.to(device) 275 | batch_rel_level = batch_rel_level.to(device) 276 | batch_t = batch_t.to(device) 277 | batch_query_src = batch_query_src.to(device) 278 | batch_query_dst = batch_query_dst.to(device) 279 | batch_label = batch_label.to(device) 280 | 281 | batch_logits = model(batch_A, batch_x_n, batch_abs_level, 282 | batch_rel_level, batch_t, batch_query_src, 283 | batch_query_dst, batch_y) 284 | batch_nll = -batch_logits.log_softmax(dim=-1) 285 | batch_num_queries = batch_logits.shape[0] 286 | batch_nll = batch_nll[ 287 | torch.arange(batch_num_queries).to(device), batch_label] 288 | total_nll += batch_nll.sum().item() 289 | total_count += batch_num_queries 290 | 291 | return total_nll / total_count 292 | 293 | def main_edge_pred(device, train_set, val_set, model, config, patience): 294 | train_loader = DataLoader(train_set, 295 | shuffle=True, 296 | collate_fn=collate_edge_pred, 297 | **config['loader']) 298 | val_loader = DataLoader(val_set, 299 | collate_fn=collate_edge_pred, 300 | **config['loader']) 301 | criterion = nn.CrossEntropyLoss() 302 | optimizer = torch.optim.Adam(model.parameters(), **config['optimizer']) 303 | 304 | best_val_nll = float('inf') 305 | best_state_dict = deepcopy(model.state_dict()) 306 | num_patient_epochs = 0 307 | for epoch in range(config['num_epochs']): 308 | val_nll = eval_edge_pred(device, val_loader, model) 309 | if val_nll < best_val_nll: 310 | best_val_nll = val_nll 311 | best_state_dict = deepcopy(model.state_dict()) 312 | num_patient_epochs = 0 313 | else: 314 | num_patient_epochs += 1 315 | wandb.log({'edge_pred/epoch': epoch, 316 | 'edge_pred/val_nll': val_nll, 317 | 'edge_pred/best_val_nll': best_val_nll, 318 | 'edge_pred/num_patient_epochs': num_patient_epochs}) 319 | 320 | if (patience is not None) and (num_patient_epochs == patience): 321 | break 322 | 323 | model.train() 324 | for batch_data in tqdm(train_loader): 325 | if len(batch_data) == 9: 326 | batch_edge_index, batch_noisy_edge_index, batch_x_n,\ 327 | batch_abs_level, batch_rel_level, batch_t,\ 328 | batch_query_src, batch_query_dst, batch_label = batch_data 329 | batch_y = None 330 | else: 331 | batch_edge_index, batch_noisy_edge_index, batch_x_n,\ 332 | batch_abs_level, batch_rel_level, batch_t,\ 333 | batch_y, 
batch_query_src, batch_query_dst, batch_label = batch_data 334 | batch_y = batch_y.to(device) 335 | 336 | num_nodes = len(batch_x_n) 337 | batch_A = dglsp.spmatrix( 338 | torch.cat([batch_edge_index, batch_noisy_edge_index], dim=1), 339 | shape=(num_nodes, num_nodes)).to(device) 340 | batch_x_n = batch_x_n.to(device) 341 | batch_abs_level = batch_abs_level.to(device) 342 | batch_rel_level = batch_rel_level.to(device) 343 | batch_t = batch_t.to(device) 344 | batch_query_src = batch_query_src.to(device) 345 | batch_query_dst = batch_query_dst.to(device) 346 | batch_label = batch_label.to(device) 347 | 348 | batch_pred = model(batch_A, batch_x_n, batch_abs_level, 349 | batch_rel_level, batch_t, batch_query_src, 350 | batch_query_dst, batch_y) 351 | loss = criterion(batch_pred, batch_label) 352 | optimizer.zero_grad() 353 | loss.backward() 354 | optimizer.step() 355 | 356 | wandb.log({'edge_pred/loss': loss.item()}) 357 | 358 | return best_state_dict 359 | 360 | def main(args): 361 | torch.set_num_threads(args.num_threads) 362 | 363 | device_str = "cuda:0" if torch.cuda.is_available() else "cpu" 364 | device = torch.device(device_str) 365 | 366 | set_seed(args.seed) 367 | 368 | config = load_yaml(args.config_file) 369 | dataset = config['general']['dataset'] 370 | config_df = pd.json_normalize(config, sep='/') 371 | 372 | ts = time.strftime('%b%d-%H:%M:%S', time.gmtime()) 373 | 374 | wandb.init( 375 | project=f'LayerDAG_{dataset}', 376 | name=f'{ts}', 377 | config=config_df.to_dict(orient='records')[0] 378 | ) 379 | 380 | # For training the generative model, no need to use the test set. 381 | train_set, val_set, _ = load_dataset(dataset) 382 | 383 | train_node_count_dataset = LayerDAGNodeCountDataset(train_set, config['general']['conditional']) 384 | val_node_count_dataset = LayerDAGNodeCountDataset(val_set, config['general']['conditional']) 385 | 386 | train_node_pred_dataset = LayerDAGNodePredDataset(train_set, config['general']['conditional']) 387 | val_node_pred_dataset = LayerDAGNodePredDataset( 388 | val_set, config['general']['conditional'], get_marginal=False) 389 | 390 | node_diffusion_config = { 391 | 'marginal_list': train_node_pred_dataset.x_n_marginal, 392 | 'T': config['node_pred']['T'] 393 | } 394 | node_diffusion = DiscreteDiffusion(**node_diffusion_config) 395 | train_node_pred_dataset.node_diffusion = node_diffusion 396 | val_node_pred_dataset.node_diffusion = node_diffusion 397 | 398 | train_edge_pred_dataset = LayerDAGEdgePredDataset(train_set, config['general']['conditional']) 399 | val_edge_pred_dataset = LayerDAGEdgePredDataset(val_set, config['general']['conditional']) 400 | 401 | edge_diffusion_config = { 402 | 'avg_in_deg': train_edge_pred_dataset.avg_in_deg, 403 | 'T': config['edge_pred']['T'] 404 | } 405 | edge_diffusion = EdgeDiscreteDiffusion(**edge_diffusion_config) 406 | train_edge_pred_dataset.edge_diffusion = edge_diffusion 407 | val_edge_pred_dataset.edge_diffusion = edge_diffusion 408 | 409 | model_config = { 410 | 'num_x_n_cat': train_set.num_categories, 411 | 'node_count_encoder_config': config['node_count']['model'], 412 | 'max_layer_size': train_node_count_dataset.max_layer_size, 413 | 'node_pred_graph_encoder_config': config['node_pred']['graph_encoder'], 414 | 'node_predictor_config': config['node_pred']['predictor'], 415 | 'edge_pred_graph_encoder_config': config['edge_pred']['graph_encoder'], 416 | 'edge_predictor_config': config['edge_pred']['predictor'], 417 | 'max_level': max(train_node_pred_dataset.input_level.max().item(), 418 | 
val_node_pred_dataset.input_level.max().item()) 419 | } 420 | model = LayerDAG(device=device, 421 | node_diffusion=node_diffusion, 422 | edge_diffusion=edge_diffusion, 423 | **model_config) 424 | 425 | node_count_state_dict = main_node_count( 426 | device, train_node_count_dataset, val_node_count_dataset, 427 | model.node_count_model, config['node_count'], config['general']['patience']) 428 | model.node_count_model.load_state_dict(node_count_state_dict) 429 | 430 | node_pred_state_dict = main_node_pred( 431 | device, train_node_pred_dataset, val_node_pred_dataset, 432 | model.node_pred_model, config['node_pred'], config['general']['patience']) 433 | model.node_pred_model.load_state_dict(node_pred_state_dict) 434 | 435 | edge_pred_state_dict = main_edge_pred( 436 | device, train_edge_pred_dataset, val_edge_pred_dataset, 437 | model.edge_pred_model, config['edge_pred'], config['general']['patience']) 438 | model.edge_pred_model.load_state_dict(edge_pred_state_dict) 439 | 440 | save_path = f'model_{dataset}_{ts}.pth' 441 | torch.save({ 442 | 'dataset': dataset, 443 | 'node_diffusion_config': node_diffusion_config, 444 | 'edge_diffusion_config': edge_diffusion_config, 445 | 'model_config': model_config, 446 | 'model_state_dict': model.state_dict() 447 | }, save_path) 448 | 449 | if __name__ == '__main__': 450 | from argparse import ArgumentParser 451 | 452 | parser = ArgumentParser() 453 | parser.add_argument("--config_file", type=str, required=True) 454 | parser.add_argument("--num_threads", type=int, default=16) 455 | parser.add_argument("--seed", type=int, default=0) 456 | args = parser.parse_args() 457 | 458 | main(args) 459 | --------------------------------------------------------------------------------
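For reference, the pieces above compose as follows at sampling time: train.py saves a checkpoint holding the diffusion configs, the model config, and the state dict, and a downstream script (such as the repository's sample.py) can rebuild the LayerDAG model from that checkpoint and call its sample method. A minimal sketch, assuming only what train.py and LayerDAG.sample above define; the checkpoint filename is illustrative and the actual sample.py may differ:

    import torch

    from src.model import DiscreteDiffusion, EdgeDiscreteDiffusion, LayerDAG

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Checkpoint produced by train.py after e.g.
    #   python train.py --config_file <path to a YAML config>
    # train.py names it model_{dataset}_{timestamp}.pth; the path below is illustrative.
    ckpt = torch.load('model_tpu_tile_example.pth', map_location='cpu')

    # Rebuild the diffusion processes and the LayerDAG model exactly as train.py does.
    node_diffusion = DiscreteDiffusion(**ckpt['node_diffusion_config'])
    edge_diffusion = EdgeDiscreteDiffusion(**ckpt['edge_diffusion_config'])
    model = LayerDAG(device=device,
                     node_diffusion=node_diffusion,
                     edge_diffusion=edge_diffusion,
                     **ckpt['model_config'])
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()

    # Unconditional generation: returns per-graph edge indices and node attributes,
    # with the dummy root node already stripped (see LayerDAG.sample above).
    # Passing y (a list of length batch_size) additionally returns y for each graph.
    edge_index_list, x_n_list = model.sample(device, batch_size=4)

Since LayerDAG.sample is wrapped in torch.no_grad(), no further gradient bookkeeping is needed; the returned lists can be fed directly to the evaluation utilities under src/eval.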