├── src ├── __init__.py ├── __pycache__ │ ├── egnn.cpython-39.pyc │ ├── logger.cpython-39.pyc │ ├── util.cpython-39.pyc │ ├── __init__.cpython-39.pyc │ ├── hypermodel.cpython-39.pyc │ └── hyperdataset.cpython-39.pyc ├── logger.py ├── util.py ├── hyperdataset.py ├── egnn.py └── hypermodel.py ├── figs └── fig.png ├── .gitignore ├── bin ├── test_cnn.sh ├── test_hgnn.sh ├── test_ehnn.sh └── test_macrorank.sh ├── LICENSE ├── README.md ├── README_DATA.md ├── env.yaml ├── test.py └── main.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /figs/fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/figs/fig.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data* 2 | save* 3 | *.png 4 | visual* 5 | nohup.out 6 | read* 7 | __pychache__ -------------------------------------------------------------------------------- /src/__pycache__/egnn.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/src/__pycache__/egnn.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/src/__pycache__/logger.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/src/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/src/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/hypermodel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/src/__pycache__/hypermodel.cpython-39.pyc -------------------------------------------------------------------------------- /src/__pycache__/hyperdataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-IDEA/MacroRank/HEAD/src/__pycache__/hyperdataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | class Logger(object): 5 | def __init__(self,path='./',filename='train.log'): 6 | super(Logger).__init__() 7 | self.terminal = sys.stdout 8 | self.log = open(os.path.join(path,filename),mode='a') 9 | def write(self,message): 10 | self.terminal.write(message) 11 | self.log.write(message) 12 | def flush(self): 13 | self.log.flush() -------------------------------------------------------------------------------- /bin/test_cnn.sh: -------------------------------------------------------------------------------- 1 | python test.py --model CNN --checkp checkp/CNN_[1]_0/last.pth --label 1 --group 0 2 | python test.py --model CNN --checkp checkp/CNN_[2]_0/last.pth --label 2 --group 0 3 | python test.py --model CNN --checkp checkp/CNN_[3]_0/last.pth --label 3 --group 0 4 | python test.py --model CNN --checkp checkp/CNN_[1]_1/last.pth --label 1 --group 1 5 | python test.py --model CNN --checkp checkp/CNN_[2]_1/last.pth --label 2 --group 1 6 | python test.py --model CNN --checkp checkp/CNN_[3]_1/last.pth --label 3 --group 1 -------------------------------------------------------------------------------- /bin/test_hgnn.sh: -------------------------------------------------------------------------------- 1 | python test.py --model HGNN --layers 3 --checkp checkp/HGNN_[1]_0/last.pth --label 1 --group 0 2 | python test.py --model HGNN --layers 3 --checkp checkp/HGNN_[2]_0/last.pth --label 2 --group 0 3 | python test.py --model HGNN --layers 3 --checkp checkp/HGNN_[3]_0/last.pth --label 3 --group 0 4 | python test.py --model HGNN --layers 3 --checkp checkp/HGNN_[1]_1/last.pth --label 1 --group 1 5 | python test.py --model HGNN --layers 3 --checkp checkp/HGNN_[2]_1/last.pth --label 2 --group 1 6 | python test.py --model HGNN --layers 3 --checkp checkp/HGNN_[3]_1/last.pth --label 3 --group 1 -------------------------------------------------------------------------------- /bin/test_ehnn.sh: -------------------------------------------------------------------------------- 1 | python test.py --model EHGNN --egnn_layers 3 --pos_encode 4 --checkp checkp/EHGNN_\[1\]_0/last.pth --base_model EGNN --label 1 --group 0 2 | python test.py --model EHGNN --egnn_layers 3 --pos_encode 4 --checkp checkp/EHGNN_\[2\]_0/last.pth --base_model EGNN --label 2 --group 0 3 | python test.py --model EHGNN --egnn_layers 3 --pos_encode 4 --checkp checkp/EHGNN_\[3\]_0/last.pth --base_model EGNN --label 3 --group 0 4 | python test.py --model EHGNN --egnn_layers 3 --pos_encode 4 --checkp checkp/EHGNN_\[1\]_1/last.pth --base_model EGNN --label 1 --group 1 5 | python test.py --model EHGNN --egnn_layers 3 --pos_encode 4 --checkp checkp/EHGNN_\[2\]_1/last.pth --base_model EGNN --label 2 --group 1 6 | python test.py --model EHGNN --egnn_layers 3 --pos_encode 4 --checkp checkp/EHGNN_\[3\]_1/last.pth --base_model EGNN --label 3 --group 1 -------------------------------------------------------------------------------- /bin/test_macrorank.sh: -------------------------------------------------------------------------------- 1 | python test.py --model GClassifier --dataset PlainClusterSet --checkp chekcp/GClassifier_[1]_0/last.pth --base_model EGNN_DENSE --label 1 --group 0 --egnn_layers 4 --pos_encode 4 2 | python test.py --model GClassifier --dataset PlainClusterSet --checkp chekcp/GClassifier_[2]_0/last.pth --base_model EGNN_DENSE --label 2 --group 0 --egnn_layers 4 --pos_encode 4 3 | python test.py --model GClassifier --dataset PlainClusterSet --checkp chekcp/GClassifier_[3]_0/last.pth --base_model EGNN_DENSE --label 3 --group 0 --egnn_layers 5 --pos_encode 4 4 | python test.py --model GClassifier --dataset PlainClusterSet --checkp chekcp/GClassifier_[1]_1/last.pth --base_model EGNN_DENSE --label 1 --group 1 --egnn_layers 4 --pos_encode 4 5 | python test.py --model GClassifier --dataset PlainClusterSet --checkp chekcp/GClassifier_[2]_1/last.pth --base_model EGNN_DENSE --label 2 --group 1 --egnn_layers 4 --pos_encode 4 6 | python test.py --model GClassifier --dataset PlainClusterSet --checkp chekcp/GClassifier_[3]_1/last.pth --base_model EGNN --label 3 --group 1 --egnn_layers 4 --pos_encode 0 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Constwelve 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MacroRank: Ranking Macro Placement Solutions Leveraging Translation Equivariancy 2 | 3 | ## Overview 4 | Official implementation of our [MacroRank](https://dl.acm.org/doi/abs/10.1145/3566097.3567899), which can accurately predict the relative order of the quality of macro placement solutions. 5 | 6 | ![](figs/fig.png) 7 | 8 | ## Download Data and Model 9 | [Data&Model](https://drive.google.com/drive/folders/1TKHLMwHAMXxGo2zsbVSbO51Qv1Hikc8O?usp=sharing) 10 | 11 | ## Requirements 12 | ``` 13 | conda env create -f env.yaml 14 | ``` 15 | ## Test 16 | ``` 17 | bash bin/test_cnn.sh 18 | bash bin/test_hgnn.sh 19 | bash bin/test_ehnn.sh 20 | bash bin/test_macrorank.sh 21 | ``` 22 | ## License 23 | 24 | This repository is released under the MIT License. 25 | 26 | ## Citation 27 | 28 | If you think our work is useful, please feel free to cite our [paper](https://dl.acm.org/doi/abs/10.1145/3566097.3567899). 29 | 30 | ``` 31 | @inproceedings{chen2023macrorank, 32 | author = {Chen, Yifan and Mai, Jing and Gao, Xiaohan and Zhang, Muhan and Lin, Yibo}, 33 | title = {MacroRank: Ranking Macro Placement Solutions Leveraging Translation Equivariancy}, 34 | booktitle = {IEEE/ACM Asia and South Pacific Design Automation Conference (ASPDAC)}, 35 | year = {2023}, 36 | isbn = {9781450397834}, 37 | publisher = {Association for Computing Machinery}, 38 | address = {New York, NY, USA}, 39 | url = {https://doi.org/10.1145/3566097.3567899}, 40 | doi = {10.1145/3566097.3567899}, 41 | pages = {258–263}, 42 | numpages = {6}, 43 | location = {Tokyo, Japan}, 44 | } 45 | ``` 46 | 47 | ## Contact 48 | 49 | For any questions, please do not hesitate to contact us. 50 | 51 | ``` 52 | Yifan Chen: chenyifan2019@pku.edu.cn 53 | ``` 54 | -------------------------------------------------------------------------------- /README_DATA.md: -------------------------------------------------------------------------------- 1 | ## Dataset 2 | 3 | The dataset consists of 12 designs, which are as follows: 4 | 5 | 1. mgc_edit_dist_a 6 | 2. mgc_fft_b 7 | 3. mgc_matrix_mult_b 8 | 4. mgc_pci_bridge32_b 9 | 5. mgc_superblue14 10 | 6. mgc_superblue19 11 | 7. mgc_des_perf_a 12 | 8. mgc_fft_a 13 | 9. mgc_matrix_mult_a 14 | 10. mgc_matrix_mult_c 15 | 11. mgc_superblue11_a 16 | 12. mgc_superblue16_a 17 | 18 | The names of all these designs are stored in a `.txt` file called `all.names`. The `train.names` file contains the names of designs used for training, while the `test.names` file contains the names of designs used for testing. The parameter `-groud` indicates whether to exchange the test set and the train set. 19 | 20 | In the directory for each design, you can find the following files: 21 | 22 | - `edge_weights.txt`: It contains the weights of each net after clustering. 23 | - `hpwl.txt`: Each line represents the estimated HPWL (Half-Perimeter Wirelength) for each placement. 24 | - `labels.txt`: Each line contains the rWL (routed Wirelength), #vias, #shorts, and the score (ICCAD19 global routing contest score, not used in this work). 25 | - `golden.txt`: It contains the HPWL of the manually placed result. 26 | - `macro_index.txt`: It contains the list of all macros in the clustered netlist. 27 | - `meta.txt`: It contains metadata about the original design, including the number of nodes, macros, I/O, nets, row height, site width, number of pins, number of movable pins, total movable node area, total fixed node area, and total space area. However, this information is not utilized in this work. 28 | - `names.txt`: It contains the names of all the used placements. 29 | - `node_size.txt`: It provides the size of each node in the clustered netlist. 30 | - `pins.txt`: It contains information about all the pins. Each line represents a pin and includes the connected node, connected net, x offset of the pin, and y offset of the pin. 31 | - `region.txt`: It specifies the placeable region of nodes. 32 | - `node_pos`: It includes all the placements. Note that some placements are not used, as indicated in the `names.txt` file. -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: MacroRank 2 | channels: 3 | - pytorch 4 | - pyg 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_gnu 10 | - absl-py=0.15.0=pyhd3eb1b0_0 11 | - aiohttp=3.8.1=py39h7f8727e_1 12 | - aiosignal=1.2.0=pyhd3eb1b0_0 13 | - asttokens=2.0.5=pyhd3eb1b0_0 14 | - async-timeout=4.0.1=pyhd3eb1b0_0 15 | - attrs=21.4.0=pyhd3eb1b0_0 16 | - backcall=0.2.0=pyhd3eb1b0_0 17 | - blas=1.0=mkl 18 | - blinker=1.4=py39h06a4308_0 19 | - brotli=1.0.9=he6710b0_2 20 | - brotlipy=0.7.0=py39h27cfd23_1003 21 | - bzip2=1.0.8=h7b6447c_0 22 | - c-ares=1.18.1=h7f8727e_0 23 | - ca-certificates=2022.3.29=h06a4308_0 24 | - cachetools=4.2.2=pyhd3eb1b0_0 25 | - certifi=2021.10.8=py39h06a4308_2 26 | - cffi=1.15.0=py39hd667e15_1 27 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 28 | - click=8.0.4=py39h06a4308_0 29 | - colorama=0.4.4=pyhd3eb1b0_0 30 | - cryptography=3.4.8=py39hd23ed53_0 31 | - cudatoolkit=10.2.89=hfd86e86_1 32 | - cvxopt=1.3.0=py39h751e006_0 33 | - cycler=0.11.0=pyhd3eb1b0_0 34 | - dataclasses=0.8=pyh6d0b6a4_7 35 | - dbus=1.13.18=hb2f20db_0 36 | - debugpy=1.5.1=py39h295c915_0 37 | - decorator=5.1.1=pyhd3eb1b0_0 38 | - dsdp=5.8=hfa32c7d_0 39 | - entrypoints=0.3=py39h06a4308_0 40 | - executing=0.8.3=pyhd3eb1b0_0 41 | - expat=2.4.4=h295c915_0 42 | - ffmpeg=4.2.2=h20bf706_0 43 | - fftw=3.3.9=h27cfd23_1 44 | - fontconfig=2.13.1=h6c09931_0 45 | - fonttools=4.25.0=pyhd3eb1b0_0 46 | - freetype=2.11.0=h70c0345_0 47 | - frozenlist=1.2.0=py39h7f8727e_0 48 | - giflib=5.2.1=h7b6447c_0 49 | - glib=2.69.1=h4ff587b_1 50 | - glpk=4.65=h3ceedfd_2 51 | - gmp=6.2.1=h2531618_2 52 | - gnutls=3.6.15=he1e5248_0 53 | - google-auth-oauthlib=0.4.1=py_2 54 | - googledrivedownloader=0.4=pyhd3deb0d_1 55 | - grpcio=1.42.0=py39hce63b2e_0 56 | - gsl=2.7.1=hd82f3ee_0 57 | - gst-plugins-base=1.14.0=h8213a91_2 58 | - gstreamer=1.14.0=h28cd5cc_2 59 | - icu=58.2=he6710b0_3 60 | - idna=3.3=pyhd3eb1b0_0 61 | - importlib-metadata=4.11.3=py39h06a4308_0 62 | - intel-openmp=2022.0.1=h06a4308_3633 63 | - ipykernel=6.9.2=py39hef51801_0 64 | - ipython=8.1.1=py39h06a4308_0 65 | - jedi=0.18.1=py39h06a4308_1 66 | - jinja2=3.0.3=pyhd3eb1b0_0 67 | - joblib=1.1.0=pyhd3eb1b0_0 68 | - jpeg=9b=0 69 | - jupyter_client=7.1.2=pyhd3eb1b0_0 70 | - jupyter_core=4.9.2=py39h06a4308_0 71 | - kiwisolver=1.3.2=py39h295c915_0 72 | - lame=3.100=h7b6447c_0 73 | - lcms2=2.12=h3be6417_0 74 | - ld_impl_linux-64=2.35.1=h7274673_9 75 | - libblas=3.9.0=1_h86c2bf4_netlib 76 | - libcblas=3.9.0=5_h92ddd45_netlib 77 | - libffi=3.3=he6710b0_2 78 | - libgcc-ng=11.2.0=h1d223b6_14 79 | - libgfortran-ng=11.2.0=h69a702a_14 80 | - libgfortran5=11.2.0=h5c6108e_14 81 | - libgomp=11.2.0=h1d223b6_14 82 | - libidn2=2.3.2=h7f8727e_0 83 | - liblapack=3.9.0=5_h92ddd45_netlib 84 | - libopus=1.3.1=h7b6447c_0 85 | - libpng=1.6.37=hbc83047_0 86 | - libprotobuf=3.19.1=h4ff587b_0 87 | - libsodium=1.0.18=h7b6447c_0 88 | - libstdcxx-ng=11.2.0=he4da1e4_14 89 | - libtasn1=4.16.0=h27cfd23_0 90 | - libtiff=4.2.0=h85742a9_0 91 | - libunistring=0.9.10=h27cfd23_0 92 | - libuuid=1.0.3=h7f8727e_2 93 | - libuv=1.40.0=h7b6447c_0 94 | - libvpx=1.7.0=h439df22_0 95 | - libwebp=1.2.0=h89dd481_0 96 | - libwebp-base=1.2.0=h27cfd23_0 97 | - libxcb=1.14=h7b6447c_0 98 | - libxml2=2.9.12=h03d6c58_0 99 | - lz4-c=1.9.3=h295c915_1 100 | - markdown=3.3.4=py39h06a4308_0 101 | - markupsafe=2.0.1=py39h27cfd23_0 102 | - matplotlib=3.5.1=py39h06a4308_0 103 | - matplotlib-base=3.5.1=py39ha18d171_1 104 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 105 | - metis=5.1.0=hf484d3e_4 106 | - mkl=2020.2=256 107 | - mpfr=4.0.2=hb69a4c5_1 108 | - multidict=5.2.0=py39h7f8727e_2 109 | - munkres=1.1.4=py_0 110 | - ncurses=6.3=h7f8727e_2 111 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 112 | - nettle=3.7.3=hbbd107a_1 113 | - networkx=2.7.1=pyhd3eb1b0_0 114 | - ninja=1.10.2=py39hd09550d_3 115 | - numpy=1.20.3=py39hdbf815f_1 116 | - oauthlib=3.2.0=pyhd3eb1b0_0 117 | - openh264=2.1.1=h4ff587b_0 118 | - openssl=1.1.1n=h7f8727e_0 119 | - packaging=21.3=pyhd3eb1b0_0 120 | - pandas=1.2.3=py39hde0f152_0 121 | - parso=0.8.3=pyhd3eb1b0_0 122 | - pcre=8.45=h295c915_0 123 | - pexpect=4.8.0=pyhd3eb1b0_3 124 | - pickleshare=0.7.5=pyhd3eb1b0_1003 125 | - pillow=9.0.1=py39h22f2fdc_0 126 | - pip=21.2.4=py39h06a4308_0 127 | - protobuf=3.19.1=py39h295c915_0 128 | - ptyprocess=0.7.0=pyhd3eb1b0_2 129 | - pure_eval=0.2.2=pyhd3eb1b0_0 130 | - pyasn1=0.4.8=pyhd3eb1b0_0 131 | - pyasn1-modules=0.2.8=py_0 132 | - pycparser=2.21=pyhd3eb1b0_0 133 | - pydot=1.2.4=py_0 134 | - pyg=2.0.3=py39_torch_1.9.0_cu102 135 | - pygments=2.11.2=pyhd3eb1b0_0 136 | - pyjwt=2.1.0=py39h06a4308_0 137 | - pyopenssl=21.0.0=pyhd3eb1b0_1 138 | - pyparsing=3.0.4=pyhd3eb1b0_0 139 | - pyqt=5.9.2=py39h2531618_6 140 | - pysocks=1.7.1=py39h06a4308_0 141 | - python=3.9.7=h12debd9_1 142 | - python-dateutil=2.8.2=pyhd3eb1b0_0 143 | - python-louvain=0.15=pyhd3eb1b0_0 144 | - python_abi=3.9=2_cp39 145 | - pytorch=1.9.0=py3.9_cuda10.2_cudnn7.6.5_0 146 | - pytorch-cluster=1.5.9=py39_torch_1.9.0_cu102 147 | - pytorch-scatter=2.0.9=py39_torch_1.9.0_cu102 148 | - pytorch-sparse=0.6.12=py39_torch_1.9.0_cu102 149 | - pytorch-spline-conv=1.2.1=py39_torch_1.9.0_cu102 150 | - pytz=2021.3=pyhd3eb1b0_0 151 | - pyyaml=6.0=py39h7f8727e_1 152 | - pyzmq=22.3.0=py39h295c915_2 153 | - qt=5.9.7=h5867ecd_1 154 | - readline=8.1.2=h7f8727e_1 155 | - requests=2.27.1=pyhd3eb1b0_0 156 | - requests-oauthlib=1.3.0=py_0 157 | - rsa=4.7.2=pyhd3eb1b0_1 158 | - scikit-learn=1.0.2=py39h51133e4_1 159 | - scipy=1.8.0=py39hee8e79c_1 160 | - setuptools=58.0.4=py39h06a4308_0 161 | - sip=4.19.13=py39h295c915_0 162 | - six=1.16.0=pyhd3eb1b0_1 163 | - sqlite=3.38.2=hc218d9a_0 164 | - stack_data=0.2.0=pyhd3eb1b0_0 165 | - suitesparse=5.10.1=hd8046ac_0 166 | - tbb=2020.2=h4bd325d_4 167 | - tensorboard=2.6.0=py_1 168 | - tensorboard-data-server=0.6.0=py39hca6d32c_0 169 | - tensorboard-plugin-wit=1.6.0=py_0 170 | - threadpoolctl=2.2.0=pyh0d69192_0 171 | - tk=8.6.11=h1ccaba5_0 172 | - torchaudio=0.9.0=py39 173 | - torchvision=0.10.0=py39_cu102 174 | - tornado=6.1=py39h27cfd23_0 175 | - tqdm=4.63.0=pyhd3eb1b0_0 176 | - traitlets=5.1.1=pyhd3eb1b0_0 177 | - typing-extensions=4.1.1=hd3eb1b0_0 178 | - typing_extensions=4.1.1=pyh06a4308_0 179 | - tzdata=2022a=hda174b7_0 180 | - urllib3=1.26.8=pyhd3eb1b0_0 181 | - wcwidth=0.2.5=pyhd3eb1b0_0 182 | - werkzeug=2.0.3=pyhd3eb1b0_0 183 | - wheel=0.37.1=pyhd3eb1b0_0 184 | - x264=1!157.20191217=h7b6447c_0 185 | - xz=5.2.5=h7b6447c_0 186 | - yacs=0.1.6=pyhd3eb1b0_1 187 | - yaml=0.2.5=h7b6447c_0 188 | - yarl=1.6.3=py39h27cfd23_0 189 | - zeromq=4.3.4=h2531618_0 190 | - zipp=3.7.0=pyhd3eb1b0_0 191 | - zlib=1.2.11=h7f8727e_4 192 | - zstd=1.4.9=haebb681_0 193 | - pip: 194 | - asgiref==3.5.0 195 | - configspace==0.4.21 196 | - cython==0.29.28 197 | - deepxde==1.1.3 198 | - dill==0.3.4 199 | - django==4.0.3 200 | - egnn-pytorch==0.2.6 201 | - einops==0.4.1 202 | - emcee==3.1.1 203 | - google-auth==1.35.0 204 | - h5py==3.6.0 205 | - line-profiler==3.4.0 206 | - llvmlite==0.38.0 207 | - numba==0.55.1 208 | - patsy==0.5.2 209 | - platypus-opt==1.0.4 210 | - prompt-toolkit==3.0.28 211 | - psutil==5.9.0 212 | - pyaml==21.10.1 213 | - pyeda==0.28.0 214 | - pytorch-ranger==0.1.1 215 | - scikit-optimize==0.9.0 216 | - sqlparse==0.4.2 217 | - statsmodels==0.13.2 218 | - terminaltables==3.1.10 219 | - torch-optimizer==0.3.0 220 | - torchsort==0.1.9 221 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | from sklearn.metrics import top_k_accuracy_score 3 | import torch 4 | import torch.nn.functional as F 5 | import argparse 6 | import os 7 | import time 8 | import random 9 | import numpy as np 10 | from tqdm import tqdm 11 | from torch.utils.data import Subset 12 | import pdb 13 | import sys 14 | #import swats 15 | import src.hyperdataset as hdatasets 16 | import src.hypermodel as hmodels 17 | from src.logger import Logger 18 | from torch_geometric.loader import DataLoader 19 | from src.util import InversePairs, kendall, mykendall 20 | #from torch.utils.tensorboard import writer 21 | import matplotlib.pyplot as plt 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument('--seed', type=int, default=777, help='seed') 25 | parser.add_argument('--device', type=str, default='cuda:0',help='device') 26 | parser.add_argument('--model', type=str, default='GClassifier',help='which mdoel to use') 27 | parser.add_argument('--batch_size', type=int, default=8,help='train batch size') 28 | parser.add_argument('--batch_step', type=int, default=1,help='how many batches per update') 29 | parser.add_argument('--test_batch_size', type=int, default=8,help='test batch size') 30 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 31 | parser.add_argument('--step_size', type=int, default=50, help='learning rate decay step') 32 | parser.add_argument('--lr_decay', type=float, default=1., help='learning rate decay ratio') 33 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay') 34 | parser.add_argument('--nhid', type=int, default=16, help='hidden size') 35 | parser.add_argument('--layers',type=int,default=2,help='conv layers') 36 | parser.add_argument('--egnn_layers',type=int,default=3,help='egnn layers') 37 | parser.add_argument('--egnn_nhid',type=int,default=16,help='egnn layers hidden dim') 38 | #parser.add_argument('--pooling_ratio', type=float, default=0.1,help='pooling ratio') 39 | parser.add_argument('--dropout_ratio', type=float, default=0.1,help='dropout ratio') 40 | parser.add_argument('--group', type=int, default=0, help='which data group to use') 41 | parser.add_argument('--tests', type=str, nargs='+', 42 | default=['mgc_des_perf_a', 'mgc_fft_a', 'mgc_matrix_mult_a', 'mgc_matrix_mult_c', 'mgc_superblue14', 'mgc_superblue19'],help='test data') 43 | parser.add_argument('--trains', type=str, nargs='+', 44 | default=['mgc_edit_dist_a', 'mgc_fft_b', 'mgc_matrix_mult_b', 'mgc_pci_bridge32_b', 'mgc_superblue11_a', 'mgc_superblue16_a'],help='train data') 45 | parser.add_argument('--dataset_path', type=str, default='data') 46 | parser.add_argument('--dataset', type=str, default='PlainClusterSet') 47 | parser.add_argument('--epochs', type=int, default=400,help='maximum number of epochs') 48 | parser.add_argument('--patience', type=int, default=400,help='patience for earlystopping') 49 | parser.add_argument('--save_dir', type=str, default='save') 50 | parser.add_argument('--goon', action='store_true',help='continue training') 51 | parser.add_argument('--con', action='store_true',help='continue training') 52 | parser.add_argument('--checkp', type=str, default='test.pth') 53 | parser.add_argument('--pos_encode', type=int, default=4, help='whether use pos encoding on position') 54 | parser.add_argument('--size_encode', type=int, default=0, help='whether use pos encoding on size') 55 | parser.add_argument('--offset_encode', type=int, default=0, help='whether use pos encoding on offset') 56 | parser.add_argument('--design', type=str, default='all',help='whitch design to train') 57 | parser.add_argument('--loss', type=str, default='MAE',help='loss func') 58 | parser.add_argument('--acc', type=str, default='rel',help='loss func') 59 | parser.add_argument('--skip_cnt', action='store_true', default=True ,help='use skip cnt ?') 60 | parser.add_argument('--regresion', action='store_true', help='regression') 61 | parser.add_argument('--classifier', action='store_true', help='classification') 62 | parser.add_argument('--base_model', type=str, default='EGNN',help='which base mdoel to use in classifier') 63 | parser.add_argument('--metric', type=str, default='lambdda',help='which metric to use as lambda, [lambdda (top1 prob), ndcg]') 64 | parser.add_argument('--label', type=list[int],default=[1],help='which label to use, [0~5] = [hpwl, rwl, via, short, score]') 65 | parser.add_argument('--train_ratio', type=float, default=0.8,help='train ratio') 66 | parser.add_argument('--optimizer',type=str,default='Adam') 67 | args = parser.parse_args() 68 | 69 | def set_seed(seed): 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | torch.manual_seed(seed) 73 | torch.cuda.manual_seed_all(seed) 74 | torch.cuda.manual_seed(seed) 75 | 76 | 77 | def build_test_loader(): 78 | MySet = getattr(hdatasets,args.dataset) 79 | 80 | dataset = MySet(args.dataset_path, mode=args.model, test_files=args.tests, train_files=args.trains, args=args) 81 | 82 | if args.model != 'CNN' and args.model != 'Classifier': 83 | args.num_node_features = dataset.num_node_features 84 | args.num_edge_features = dataset.num_edge_features 85 | args.num_pin_features = dataset.num_pin_features 86 | if args.model == 'EHGNN': 87 | args.num_pos_features = dataset.num_pos_features 88 | 89 | loader = {} 90 | for design in dataset.raw_file_names: 91 | design_set = Subset(dataset,range(dataset.ptr[design], 92 | dataset.ptr[design] + dataset.file_num[design])) 93 | loader[design] = DataLoader(design_set, batch_size= 1) 94 | return dataset, loader 95 | 96 | 97 | def build_model(): 98 | Model = getattr(hmodels,args.model) 99 | model = Model(args).to(args.device) 100 | #print(model) 101 | return model 102 | 103 | 104 | def build_log(): 105 | # make save dir 106 | st = time.strftime("%b:%d:%X",time.localtime()) 107 | args.save_dir = os.path.join(args.save_dir,'{}_{}_{}_{}'.format(args.model,args.label,args.group,st)) 108 | if not os.path.exists(args.save_dir): 109 | os.makedirs(args.save_dir) 110 | # rederict to save dir 111 | sys.stdout = Logger(path=args.save_dir) 112 | # print args 113 | print(args) 114 | # save paths 115 | best_model_path = os.path.join(args.save_dir,'best.pth'.format(st)) 116 | last_model_path = os.path.join(args.save_dir,'last.pth'.format(st)) 117 | return best_model_path, last_model_path 118 | 119 | 120 | # preparing 121 | torch.set_num_threads(16) 122 | # choose data group 123 | if args.group == 1: 124 | tmp = args.tests 125 | args.tests = args.trains 126 | args.trains = tmp 127 | 128 | label = [int(i) for i in args.label][0] 129 | set_seed(args.seed) 130 | # build up 131 | # print('loading dataset ...') 132 | dataset, loader = build_test_loader() 133 | model = build_model() 134 | 135 | # os.makedirs('log/{}'.format(args.checkp), exist_ok=True) 136 | # logger = writer.SummaryWriter('log/{}'.format(args.checkp)) 137 | 138 | checkp = torch.load(args.checkp, map_location='cuda') 139 | model.load_state_dict(checkp['model']) 140 | model = model.to(args.device) 141 | #print('load model from {}, loss = {}, err = {}'.format(args.checkp,checkp['val_loss'], checkp['rank_err'])) 142 | # golds = {} 143 | # for design in args.tests: 144 | # test_loader = loader[design] 145 | # preds = [] 146 | # reals = [] 147 | # origins = [] 148 | # label_p = 'data/raw/{}/labels.txt'.format(design) 149 | # idx_p = 'data/raw/{}/names.txt'.format(design) 150 | # golds[design] = np.loadtxt(label_p)[np.loadtxt(idx_p,dtype=int)] 151 | 152 | # meann = np.mean(labels, 0 ) 153 | # maxx = np.max(labels, 0 ) 154 | # minn = np.min(labels, 0 ) 155 | # pdb.set_trace() 156 | dataset.mode = 'CNN' 157 | mres = 0 158 | taut = 0 159 | score = 0 160 | print('model =', args.model, ", test group = ", args.group + 1) 161 | print("{:20}\t{:10}\t{:10}\t{}\t{}".format('design', 'mean_score', 'top30_score', 'mre', 'tau')) 162 | with torch.no_grad(): 163 | designs = [] 164 | embdds = [] 165 | model.eval() 166 | for design in args.tests: 167 | test_loader = loader[design] 168 | preds = [] 169 | reals = [] 170 | origins = [] 171 | #print(design, end='\t\t') 172 | label_p = 'data/raw/{}/labels.txt'.format(design) 173 | idx_p = 'data/raw/{}/names.txt'.format(design) 174 | 175 | labels_this = np.loadtxt(label_p)[np.loadtxt(idx_p,dtype=int)] 176 | #pdb.set_trace() 177 | for i, data in enumerate(test_loader): 178 | data = data.to(args.device) 179 | if args.model == 'HGNN' or args.model == 'EHGNN': 180 | out = model(data) 181 | else: 182 | out = model.predict(data) 183 | #out = out * (maxx[label - 1] - minn[label - 1]) + meann[label - 1] 184 | #data.y[:, 1 : ] = data.y[:, 1 : ].cuda() * (maxx - minn) + meann 185 | reals.append(data.y[:, label].view(-1).item()) 186 | origins.append(dataset.origin[design][i][:, label].item()) 187 | preds.append(out.view(-1).item()) 188 | reals = np.array(reals) 189 | preds = np.array(preds) 190 | origins = np.array(origins) 191 | mre = np.mean(np.abs(reals - preds)/np.abs(reals)) 192 | tau = mykendall(reals, preds) 193 | taut += tau/len(args.tests) 194 | mres += mre/len(args.tests) 195 | top30 = np.argsort(preds)[:30] 196 | top30_score = origins[top30].mean() 197 | mean_score = np.mean(origins) 198 | score += top30_score/mean_score/len(args.tests) 199 | 200 | print("{:20}\t{:>10.4f}\t{:>.4f}\t{:>.3f}\t{:>.3f}".format(design, mean_score, top30_score, mre, tau)) 201 | 202 | print("{:20}\t{:>10.4f}\t{:>10.4f}\t{:>.3f}\t{:>.3f}".format('average', 1., score, mres, taut)) 203 | #print('average mre = {:.3f}'.format(mres/len(args.tests)), ', tau = {:.3f}'.format(taut/len(args.tests))) 204 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter 3 | from tqdm import tqdm 4 | import pylab 5 | import os.path as osp 6 | import os 7 | import numpy as np 8 | import scipy.stats as stats 9 | 10 | def position_encoding(position : torch.Tensor, L : int=4) -> torch.Tensor: 11 | """ inputs : position [n, d] 12 | outputs : position [n, d * L * 2]""" 13 | n = position.shape[0] 14 | arr = torch.arange(0, L, 1) 15 | arr = torch.pow(2, arr) * np.pi 16 | arr = torch.stack((arr, arr + np.pi/2)).T.reshape(-1) 17 | encoding = position.view(n, -1, 1) * arr.view(1, 1, -1) 18 | encoding = torch.sin(encoding) 19 | encoding = encoding.view(n, -1) 20 | return encoding 21 | 22 | def draw_rect(coord, size, path = 'figs'): 23 | import matplotlib.pyplot as plt 24 | 25 | fig = plt.figure(dpi=500) 26 | ax = fig.add_subplot(111, aspect='equal') 27 | plt.axis('off') 28 | plt.xlim(xmax=1.2,xmin=-0.2) 29 | plt.ylim(ymax=1.2,ymin=-0.2) 30 | 31 | c , s = coord, size 32 | #patches = [matplotlib.patches.Rectangle((x, y),w, h, alpha=0.2,color='blue') for x,y,w,h in zip(coord[0], coord[1], size[0], size[1])] 33 | #ax.add_collection(PatchCollection(patches)) 34 | [ax.add_patch(plt.Rectangle((x, y),w, h, alpha=0.2,facecolor='blue')) for x,y,w,h in zip(c[0], c[1], s[0], s[1])] 35 | fig.savefig(os.path.join(path,"draw.png"),bbox_inches='tight') 36 | plt.close(fig) 37 | plt.cla() 38 | plt.clf() 39 | 40 | 41 | def get_ensity_map(macro_index, num_bins, bin_size, node_pos, cell_size, edge_index, pins, B): 42 | density = [] 43 | ox = macro_index.new_zeros(num_bins,num_bins).float() 44 | oy = macro_index.new_zeros(num_bins,num_bins).float() 45 | for idx in macro_index: 46 | pos = node_pos[idx] 47 | size = cell_size[idx] 48 | ox = torch.arange(0,1,bin_size,dtype=float).view(1,-1).repeat(num_bins,1) 49 | oy = torch.arange(0,1,bin_size,dtype=float).view(-1,1).repeat(1,num_bins) 50 | 51 | ox = torch.clamp((size[0]/2 + bin_size/2 - torch.abs(pos[0] - ox + size[0]/2 - bin_size/2)) / bin_size,0,1) 52 | oy = torch.clamp((size[1]/2 + bin_size/2 - torch.abs(pos[1] - oy + size[1]/2 - bin_size/2)) / bin_size,0,1) 53 | 54 | density.append((ox * oy).view(num_bins,num_bins,1)) 55 | 56 | density = torch.cat(density,dim = -1) 57 | density_map = density.sum(dim=-1) 58 | 59 | pin_density = torch.zeros_like(density_map).view(-1) 60 | cnt_density = torch.zeros_like(density_map).view(-1) 61 | 62 | 63 | all_pin_pos = ((torch.index_select(node_pos,dim=0,index=edge_index[0]) + pins) /bin_size).long().clamp(0, num_bins - 1) 64 | def dd2d(index): 65 | return index[:,1] * num_bins + index[:,0] 66 | pin_mask = torch.zeros(all_pin_pos.shape[0]).bool() 67 | for pidx in macro_index: 68 | pin_mask |= (edge_index[0]==pidx) 69 | pin_pos = all_pin_pos[pin_mask] 70 | indx = dd2d(pin_pos) 71 | 72 | pin_density = pin_density + scatter(all_pin_pos.new_ones(pin_pos.shape[0]), \ 73 | indx, dim=0, dim_size=num_bins * num_bins, reduce='sum') 74 | 75 | cnt_density = cnt_density + scatter(B[pin_mask], \ 76 | indx, dim=0, dim_size=num_bins * num_bins, reduce='sum') 77 | 78 | pin_density /= pin_density.max() 79 | cnt_density /= cnt_density.max() 80 | 81 | pin_density = pin_density.view(num_bins,num_bins) 82 | cnt_density = cnt_density.view(num_bins,num_bins) 83 | 84 | pic = torch.cat([density_map.view(1,1,num_bins,num_bins),pin_density.view(1,1,num_bins,num_bins),cnt_density.view(1,1,num_bins,num_bins)],dim=1) 85 | 86 | return pic 87 | 88 | def diameter(hyperedge_index): 89 | num_nodes = hyperedge_index[0].max().item() + 1 90 | num_edges = hyperedge_index[1].max().item() + 1 91 | 92 | maxx = 1000 93 | for i in tqdm(range(0,num_nodes)): 94 | 95 | vec = torch.zeros(num_nodes,dtype=torch.long).cuda() 96 | vec[i] = 1 97 | cnt = 0 98 | while vec.sum()/num_nodes <= 0.9: 99 | vec = vec.index_select(-1, hyperedge_index[0]) 100 | vec = scatter(vec,hyperedge_index[1], dim=0, dim_size=num_edges,reduce='max') 101 | vec = vec.index_select(-1,hyperedge_index[1]) 102 | vec = scatter(vec,hyperedge_index[1], dim=0, dim_size=num_edges,reduce='max') 103 | cnt += 1 104 | if cnt > maxx: 105 | break 106 | if cnt < maxx: 107 | maxx = cnt 108 | print(maxx) 109 | print(maxx) 110 | 111 | 112 | def k_shortest(hyperedge_index,macro_index): 113 | macro_num = int(len(macro_index)) 114 | node_num = int(hyperedge_index[0].max()+1) 115 | edge_num = int(hyperedge_index[1].max()+1) 116 | shortest_length = [] 117 | for i in range(0,macro_num): 118 | macro_id = macro_index[i] 119 | steps = torch.zeros(node_num,dtype=torch.long).to('cuda:0') 120 | visited = torch.zeros(node_num,dtype=torch.long).to('cuda:0') 121 | visited[macro_id] = 1 122 | cnt = 0 123 | # newly added nodes mask 124 | new_node = torch.zeros(node_num,dtype=torch.long).to('cuda:0') 125 | new_node[macro_id] = 1 126 | steps[macro_id] = 10 127 | while torch.sum(new_node) > 0 : 128 | cnt += 1 129 | tmp_vec = visited.index_select(-1, hyperedge_index[0]) 130 | tmp_vec = scatter(tmp_vec,hyperedge_index[1], dim=0, dim_size=edge_num,reduce='max') 131 | tmp_vec = tmp_vec.index_select(-1,hyperedge_index[1]) 132 | tmp_vec = scatter(tmp_vec,hyperedge_index[0], dim=0, dim_size=node_num,reduce='max') 133 | new_node = (tmp_vec - visited).long() 134 | steps[new_node.bool()] = cnt 135 | visited = tmp_vec 136 | #print(visited) 137 | steps = torch.where(visited == 0, node_num + edge_num,steps) 138 | shortest_length.append(steps.view(node_num,1)) 139 | shortest = torch.cat(shortest_length,dim=1) 140 | return shortest 141 | 142 | 143 | 144 | def standardization(x): 145 | mean = torch.mean(x) 146 | std = torch.std(x) 147 | return (x-mean)/std 148 | 149 | def normalization(x): 150 | minn = torch.min(x)[0] 151 | maxx = torch.max(x)[0] 152 | return (x-minn)/(maxx-minn) 153 | 154 | 155 | def build_cg_index(macro_pos,size): 156 | with torch.no_grad(): 157 | num = macro_pos.shape[0] 158 | edge_index_v = [] 159 | edge_index_h = [] 160 | for i in range(num): 161 | for j in range(i+1,num): 162 | pi = macro_pos[i] 163 | pj = macro_pos[j] 164 | si = size[i] 165 | sj = size[j] 166 | # add h edge 167 | if pi[0] + si[0] < pj[0]: 168 | edge_index_v.append([i,j]) 169 | if pj[0] + sj[0] < pi[0]: 170 | edge_index_v.append([j,i]) 171 | if pi[1] + si[1] < pj[1]: 172 | edge_index_h.append([i,j]) 173 | if pj[1] + sj[1] < pi[1]: 174 | edge_index_h.append([j,i]) 175 | return torch.from_numpy(np.array(edge_index_v).T), torch.from_numpy(np.array(edge_index_h).T) 176 | 177 | 178 | def MergeSort(data): 179 | n=len(data) 180 | #递归基 181 | if n==1:return data, 0 182 | #分两半来排序 183 | part1,part2=data[:n//2],data[n//2:] 184 | sorted_part1,s1=MergeSort(part1) 185 | sorted_part2,s2=MergeSort(part2) 186 | #排序后拼接这两半,拼接后先计数,然后将两个有序序列合并 187 | s,sorted_temp=0,sorted_part1+sorted_part2 188 | #用p、q两个指针指向两段,计算q中每个元素离插入点的index差 189 | p,q,len1,len_all=0,sorted_temp.index(sorted_part2[0]),len(sorted_part1),len(sorted_temp) 190 | while p torch.Tensor: 236 | target = torch.argsort(target) 237 | target = torch.argsort(target).float() 238 | 239 | pred = torch.argsort(pred) 240 | pred = torch.argsort(pred).float() 241 | 242 | return corrcoef(target, pred / pred.shape[-1]) 243 | 244 | 245 | def kendall(target, pred): 246 | if type(target) == torch.Tensor: 247 | target = target.detach().cpu().numpy() 248 | if type(pred) == torch.Tensor: 249 | pred = pred.detach().cpu().numpy() 250 | return stats.kendalltau(target, pred)[0] 251 | 252 | def mykendall(target, pred): 253 | if type(target) == torch.Tensor: 254 | target = target.detach().cpu().numpy() 255 | if type(pred) == torch.Tensor: 256 | pred = pred.detach().cpu().numpy() 257 | Rp = np.argsort(pred) 258 | Rr = np.argsort(np.array(target)[Rp]) 259 | return 1 - 2 * InversePairs(Rr.tolist()) / (len(target)**2 - len(target)) * 2 260 | 261 | 262 | def mle_loss(target : torch.Tensor, pred : torch.Tensor) -> torch.Tensor: 263 | perm = torch.argsort(target) 264 | pred = pred[perm] 265 | exp_pred = torch.exp(pred) 266 | sum_exp_pred = torch.cumsum(exp_pred, dim=-1) 267 | prob = exp_pred / sum_exp_pred 268 | log_prob = torch.log(prob) 269 | return - torch.sum(log_prob) 270 | 271 | def dcg_score(input, target): 272 | perm = input.argsort() 273 | out = target[perm] 274 | logi = np.arange(0, len(perm), 1) + 2 275 | logi = np.log2(logi) 276 | out = out / logi 277 | return out.mean() 278 | 279 | def idcg_score(input): 280 | return dcg_score(-input, input) 281 | 282 | def ndcg_score(input, target): 283 | return dcg_score(input, target) / idcg_score(target) 284 | 285 | 286 | def top_k_match(input, target, k=30): 287 | p_idx = np.argsort(input) 288 | r_idx = np.argsort(target) 289 | pk = p_idx[:k] 290 | rk = r_idx[:k] 291 | cross = np.intersect1d(pk, rk) 292 | return len(cross) / k 293 | 294 | def rank(input): 295 | perm = input.argsort() 296 | return perm.argsort() 297 | 298 | def mean_dist(data): 299 | n = data.shape[0] 300 | data= data.view(n, -1, 1) 301 | dist = torch.cdist(data, data, p=1) 302 | dist = dist.view(n, -1) 303 | return torch.mean(dist, dim=-1).view(-1, 1) 304 | 305 | if __name__ == '__main__': 306 | a = torch.randn(5, 10) 307 | meann = mean_dist(a) 308 | print(meann) 309 | 310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- /src/hyperdataset.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from itertools import combinations 3 | import os.path as osp 4 | import pandas as pd 5 | import torch 6 | import numpy as np 7 | import os 8 | from torch_scatter import scatter 9 | from torch_geometric.data import Dataset, Data 10 | from torchvision import transforms 11 | import pdb 12 | import matplotlib.pyplot as plt 13 | from tqdm import tqdm 14 | from src.util import mean_dist, position_encoding, draw_rect, get_ensity_map 15 | 16 | 17 | class BipartiteData(Data): 18 | def __init__(self, **kwargs): 19 | super().__init__(**kwargs) 20 | def __inc__(self, key, value, *args, **kwargs): 21 | if key == 'edge_index': 22 | return torch.tensor([[self.x.size(0)], [self.edge_weight.size(0)]]) 23 | else: 24 | return super().__inc__(key, value, *args, **kwargs) 25 | 26 | 27 | class PlainClusterSet(Dataset): 28 | def __init__(self, root, transform=None, pre_transform=None, mode ='graph', pos_encoding=True, test_files=['mgc_fft_a','mgc_matrix_mult_b'], train_files=['mgc_fft_b'], device='cpu', args=None): 29 | self.args = args 30 | self.tot_file_num = None # int 31 | self.file_num = None # dict, file nums for each design 32 | self.ptr = None 33 | self.num_bins = 224 34 | self.bin_size = 1./224 35 | self.train_file_names = train_files 36 | self.test_file_names = test_files 37 | self.device = device 38 | self.mode = mode 39 | self.pos_encoding = pos_encoding 40 | # info 41 | self.labels = ['hpwl', 'rwl','vias','short', 'score'] 42 | self.weight = {} 43 | # for label statistics 44 | self.stats = {} 45 | self.tot_labels = None 46 | self.tot_means = None 47 | self.tot_stds = None 48 | self.means = {} 49 | self.stds = {} 50 | # 51 | super(PlainClusterSet, self).__init__(root, transform, pre_transform) 52 | # data prefech 53 | self.netlist = {} 54 | self.data = [] 55 | self.y = torch.load(osp.join(self.processed_dir, 'labels.pt')) 56 | self.weight = torch.load(osp.join(self.processed_dir, 'weight.pt')) 57 | self.lambdda = torch.load(osp.join(self.processed_dir, 'lambda.pt')) 58 | self.dcg = torch.load(osp.join(self.processed_dir, 'dcg.pt')) 59 | self.origin = {} 60 | 61 | for design in self.raw_file_names: 62 | self.origin[design] = [] 63 | self.netlist[design] = torch.load(osp.join(self.processed_dir, '{}.pt'.format(design))).to(device) 64 | 65 | for i in range(len(self.processed_file_names)): 66 | self.data.append(self.pre_load_data(i).to(device)) 67 | 68 | @property 69 | def processed_dir(self) -> str: 70 | return osp.join(self.root, 'processed_plain') 71 | 72 | @property 73 | def raw_file_names(self): 74 | names_path = osp.join(self.root,'raw','all.names') 75 | names = np.loadtxt(names_path,dtype=str) 76 | if names.ndim == 0: 77 | return [str(names)] 78 | return names.tolist() 79 | 80 | @property 81 | def num_node_features(self): 82 | if self[0].x is not None: return self[0].x.size(1) 83 | return 1 84 | 85 | @property 86 | def num_pin_features(self): 87 | if self[0].pin_offset is not None:return self[0].pin_offset.size(1) 88 | return 1 89 | 90 | @property 91 | def num_edge_features(self): 92 | if self[0].edge_weight is not None:return self[0].edge_weight.size(1) 93 | return 1 94 | 95 | @property 96 | def num_pos_features(self): 97 | if self[0].macro_pos is not None:return self[0].macro_pos.size(1) 98 | return 1 99 | 100 | @property 101 | def processed_file_names(self): 102 | if self.tot_file_num is None: 103 | self.tot_file_num = 0 104 | self.file_num = {} 105 | self.ptr = {} 106 | for design in self.raw_file_names: 107 | path = osp.join(self.raw_dir,design) 108 | name_path = osp.join(path,'names.txt') 109 | names = np.array(pd.read_table(name_path,header=None)).reshape(-1) 110 | self.tot_file_num += names.shape[0] 111 | self.file_num[design] = names.shape[0] 112 | self.ptr[self.raw_file_names[0]] = 0 113 | for i in range(1,int(len(self.raw_file_names))): 114 | self.ptr[self.raw_file_names[i]] = self.ptr[self.raw_file_names[i-1]] + self.file_num[self.raw_file_names[i-1]] 115 | return ['data_%d.pt'%i for i in range(0, self.tot_file_num)] 116 | 117 | 118 | def process(self): 119 | 120 | self.tot_labels = [] 121 | i = 0 122 | for design in self.raw_file_names: 123 | # paths 124 | path = osp.join(self.raw_dir,design) 125 | size_path = osp.join(path,'node_size.txt') 126 | name_path = osp.join(path,'names.txt') 127 | pos_root = osp.join(path,'node_pos') 128 | pin_path = osp.join(path,'pins.txt') 129 | region_path = osp.join(path,'region.txt') 130 | macro_path = osp.join(path,'macro_index.txt') 131 | hpwl_path = osp.join(path,'hpwl.txt') 132 | meta_path = osp.join(path,'meta.txt') 133 | label_path = osp.join(path,'labels.txt') 134 | hedge_w_path = osp.join(path, 'edge_weights.txt') 135 | # loading ... 136 | pins = np.loadtxt(pin_path) 137 | size = np.loadtxt(size_path) 138 | hedge_w = torch.from_numpy(np.loadtxt(hedge_w_path)).float() 139 | 140 | incidence = pins[:,:2] 141 | pin_feature = pins[:,2:] 142 | xl,yl,xh,yh = np.loadtxt(region_path) 143 | 144 | macro_index = torch.tensor(np.loadtxt(macro_path),dtype=torch.long) 145 | names = np.loadtxt(name_path,dtype=int) 146 | 147 | hpwls = np.loadtxt(hpwl_path) 148 | meta_data = np.loadtxt(meta_path) 149 | labels = np.loadtxt(label_path) 150 | 151 | rWLs = labels[:,0] 152 | vias = labels[:,1] 153 | short = labels[:,2] 154 | score = labels[:,3] 155 | mask = (rWLs != 0) 156 | # labels statics 157 | self.stats[design] = np.stack([hpwls[mask], rWLs[mask], vias[mask], short[mask], score[mask]], axis=0) 158 | self.tot_labels.append(self.stats[design]) 159 | 160 | meta_data[5] = meta_data[5]/(yh-yl) 161 | meta_data[8] = meta_data[8]/(yh-yl)/(xh-xl) 162 | meta_data[9] = meta_data[9]/(yh-yl)/(xh-xl) 163 | meta_data[10] = meta_data[10]/(yh-yl)/(xh-xl) 164 | 165 | meta_data = torch.from_numpy(meta_data).float() 166 | size[:,0] = size[:,0]/(xh-xl) 167 | size[:,1] = size[:,1]/(yh-yl) 168 | pin_feature[:,0] = pin_feature[:,0]/(xh-xl) 169 | pin_feature[:,1] = pin_feature[:,1]/(yh-yl) 170 | # std 171 | rWLs = rWLs/(xh-xl+yh-yl)*2 172 | rWLs = rWLs 173 | hpwls = hpwls/(xh-xl+yh-yl)*2 174 | hpwls = hpwls 175 | 176 | cell_size = torch.tensor(size, dtype=torch.float) 177 | edge_index = torch.tensor(incidence.T, dtype=torch.long) 178 | pins = torch.tensor(pin_feature,dtype=torch.float) 179 | 180 | num_nodes = cell_size.shape[0] 181 | num_egdes = hedge_w.shape[0] 182 | num_pins = pins.shape[0] 183 | # node_degree 184 | D = scatter(torch.ones(num_pins), edge_index[0], dim=0, dim_size=num_nodes, reduce='sum') 185 | # add self loop to no edge block 186 | block_index = torch.where(D == 0)[0] # no edge connected 187 | if len(block_index) >0: 188 | self_loop_edge = torch.arange(num_egdes, num_egdes + block_index.shape[0], 1).long() 189 | self_loop_edge = torch.stack([block_index, self_loop_edge],dim=0) 190 | edge_index = torch.cat([edge_index, self_loop_edge],dim=-1) 191 | self_loop_pin = torch.zeros((block_index.shape[0], 2)).float() 192 | pins = torch.cat([pins, self_loop_pin], dim=0) 193 | self_loop_edge_w = torch.zeros(block_index.shape[0]).float() 194 | 195 | hedge_w = torch.cat([hedge_w, self_loop_edge_w], dim=-1) 196 | 197 | num_egdes += block_index.shape[0] 198 | num_pins += block_index.shape[0] 199 | # edge_degree 200 | B = scatter(torch.ones(num_pins), edge_index[1], dim=0, dim_size=num_egdes, reduce='sum') 201 | B = torch.index_select(B, dim=-1, index=edge_index[1]).clamp(0,50) 202 | 203 | 204 | macro_mask = torch.zeros(cell_size.shape[0]).float() 205 | macro_mask[macro_index] = 1 206 | 207 | node_attr = torch.cat((cell_size, D.view(-1, 1), macro_mask.view(-1,1)),dim=-1) 208 | # netlist is the same 209 | data = Data( 210 | # x = [size[2 or 16], degree[1], pins[1]] 211 | node_attr=node_attr, 212 | edge_index=edge_index, 213 | edge_weight=hedge_w.view(-1,1), 214 | pin_offset=pins, 215 | macro_index=macro_index) 216 | if osp.exists(osp.join(self.processed_dir, '{}.pt'.format(design))): continue 217 | torch.save(data, osp.join(self.processed_dir, '{}.pt'.format(design))) 218 | for name in tqdm(names): 219 | if osp.exists(osp.join(self.processed_dir, 'data_{}.pt'.format(i))):continue 220 | if hpwls[name] == 0: print('{}-{}'.format(design,name)) 221 | pos_path = osp.join(pos_root,'%d.txt'%name) 222 | node_pos = torch.tensor(np.loadtxt(pos_path),dtype=torch.float) 223 | # normalize 224 | node_pos[:,0] = (node_pos[:,0]-xl)/(xh-xl) 225 | node_pos[:,1] = (node_pos[:,1]-yl)/(yh-yl) 226 | # fill zero 227 | fake_pos = torch.zeros_like(node_pos) 228 | fake_pos[macro_index] = node_pos[macro_index] 229 | # density map 230 | pic = get_ensity_map(macro_index,self.num_bins,self.bin_size,node_pos, cell_size, edge_index, pins, B) 231 | 232 | data = Data(# position[ll][2 or 16] 233 | pos=fake_pos.float(), 234 | # label = [hpwl, rwl, vias, short, score] 235 | y=torch.tensor([hpwls[name], rWLs[name], vias[name], short[name], score[name]],dtype=torch.float).view(1, -1), 236 | # density_map 237 | pic = pic.float(), 238 | # design 239 | design = design) 240 | 241 | if self.pre_filter is not None and not self.pre_filter(data): 242 | continue 243 | 244 | if self.pre_transform is not None: 245 | data = self.pre_transform(data) 246 | 247 | torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i))) 248 | i += 1 249 | 250 | self.tot_labels = np.hstack(self.tot_labels) 251 | self.tot_means = torch.from_numpy(np.mean(self.tot_labels, axis=-1)) 252 | self.tot_maxs = torch.from_numpy(np.max(self.tot_labels, axis=-1)) 253 | self.tot_mins = torch.from_numpy(np.min(self.tot_labels, axis=-1)) 254 | self.tot_stds = torch.from_numpy(np.std(self.tot_labels, axis=-1)) 255 | 256 | self.means = {} 257 | self.stds = {} 258 | self.lambdda = {} 259 | self.dcg = {} 260 | ws = [] 261 | 262 | for design in self.raw_file_names: 263 | tmp = torch.from_numpy(self.stats[design]) 264 | tmp = (tmp - tmp.mean(dim=-1).view(-1,1))/tmp.std(dim=-1).view(-1,1) 265 | logi = torch.argsort(tmp, dim=-1) 266 | logi = torch.argsort(logi, dim=-1) 267 | logi = torch.log2(logi+2) 268 | self.lambdda[design] = torch.softmax(-tmp, dim=-1) * tmp.shape[-1] 269 | self.dcg[design] = tmp / logi 270 | meann = mean_dist(self.dcg[design]) 271 | self.dcg[design] = self.dcg[design] / meann 272 | print(design, self.dcg[design].max(), self.dcg[design].min(), self.dcg[design].std()) 273 | 274 | 275 | for design in self.raw_file_names: 276 | self.stats[design] = (torch.from_numpy(self.stats[design]) - self.tot_means.view(-1,1))/(self.tot_maxs.view(-1,1) - self.tot_mins.view(-1,1)) 277 | self.means[design] = torch.mean(self.stats[design], dim=-1) 278 | self.stds[design] = torch.std(self.stats[design], dim=-1) 279 | if design == 'mgc_pci_bridge32_b': 280 | self.stds[design] = self.stds['mgc_fft_a'] 281 | ws.append(self.stds[design].view(-1, 1)) 282 | 283 | ws = torch.cat(ws, dim=-1) 284 | ws = 1 / ws 285 | mws = torch.mean(ws, dim=-1) 286 | ws = ws / mws.view(-1, 1) 287 | 288 | labes = [] 289 | for i, design in enumerate(self.raw_file_names): 290 | self.weight[design] = ws[:, i].float() 291 | 292 | for i, design in enumerate(self.raw_file_names): 293 | labes.append(self.stats[design]) 294 | 295 | labes = torch.cat(labes, dim=-1).float() 296 | 297 | torch.save(labes, osp.join(self.processed_dir, 'labels.pt')) 298 | torch.save(self.weight, osp.join(self.processed_dir, 'weight.pt')) 299 | torch.save(self.lambdda, osp.join(self.processed_dir, 'lambda.pt')) 300 | torch.save(self.dcg, osp.join(self.processed_dir, 'dcg.pt')) 301 | 302 | 303 | def len(self): 304 | return len(self.processed_file_names) 305 | 306 | def pre_load_data(self, idx): 307 | data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx))).to(self.device) 308 | design = data.design 309 | netlist = self.netlist[design] 310 | # 311 | path = osp.join(self.raw_dir,design) 312 | region_path = osp.join(path,'region.txt') 313 | xl,yl,xh,yh = np.loadtxt(region_path) 314 | data.y[:, 1] *= (xh+yh-xl-yl)/2 315 | self.origin[design].append(data.y) 316 | # normalize 317 | y = self.y[:, idx].view(1, -1) 318 | if self.mode == 'HGNN': 319 | w = self.weight[design].view(1, -1) 320 | size = netlist.node_attr[:,2] 321 | pe = position_encoding(data.pos) 322 | x = torch.cat([pe, netlist.node_attr], dim=-1) 323 | bipartdata = BipartiteData(x=x, edge_index=netlist.edge_index, y=y, pic = data.pic, \ 324 | edge_weight=netlist.edge_weight, pin_offset=netlist.pin_offset, 325 | macro_index=netlist.macro_index, design = design, w=w) 326 | elif self.mode == 'EHGNN' or self.mode == 'CEHGNN': 327 | w = self.weight[design].view(1, -1) 328 | size = netlist.node_attr[netlist.macro_index, :2] 329 | pos = data.pos[netlist.macro_index] 330 | d4pos = torch.cat([pos, pos + size], dim=-1) 331 | x = netlist.node_attr 332 | offset = netlist.pin_offset 333 | bipartdata = BipartiteData(x=x, edge_index=netlist.edge_index, y=y, pic = data.pic, \ 334 | edge_weight=netlist.edge_weight, pin_offset=offset, 335 | macro_index=netlist.macro_index, design = design, w=w, macro_num=netlist.macro_index.shape[0], macro_pos = d4pos) 336 | elif self.mode == 'CNN': 337 | w = self.weight[design].view(1, -1) 338 | bipartdata = Data(y=y, density = data.pic, design = design, w=w) 339 | elif self.mode == 'Classifier' or self.mode == 'RClassifier': 340 | w = self.weight[design].view(1, -1) 341 | bipartdata = Data(y=y, density = data.pic, design = design, w=w) 342 | elif self.mode == 'GClassifier': 343 | w = self.weight[design].view(1, -1) 344 | size = netlist.node_attr[netlist.macro_index, :2] 345 | pos = data.pos[netlist.macro_index] 346 | d4pos = torch.cat([pos, pos + size], dim=-1) 347 | x = netlist.node_attr 348 | offset = netlist.pin_offset 349 | bipartdata = BipartiteData(x=x, edge_index=netlist.edge_index, y=y, pic = data.pic, \ 350 | edge_weight=netlist.edge_weight, pin_offset=offset, 351 | macro_index=netlist.macro_index, design = design, w=w, macro_num=netlist.macro_index.shape[0], macro_pos = d4pos) 352 | else: 353 | assert(False) 354 | return bipartdata 355 | 356 | def get(self, idx): 357 | if self.mode == 'Classifier' or self.mode == 'RClassifier': 358 | design = self.data[idx].design 359 | begin = self.ptr[design] 360 | lenth = self.file_num[design] 361 | select_pair = np.random.randint(begin, begin + lenth, 2) 362 | data1 = self.data[select_pair[0]] 363 | data2 = self.data[select_pair[1]] 364 | mask1, mask5, mask0 = (data1.y > data2.y), (data1.y == data2.y), (data1.y < data2.y) 365 | target = mask1 * 1 + 0.5 * mask5 366 | lambdd1 = self.lambdda[design][:, select_pair[0] - begin] 367 | lambdd2 = self.lambdda[design][:, select_pair[1] - begin] 368 | w = (lambdd1 - lambdd2).abs().view(1, -1) 369 | 370 | bidata = Data(y=target, density=torch.cat((data1.density, data2.density),dim=0), w=w, y1=data1.y, y2=data2.y, w1=data1.w, w2=data2.w) 371 | return bidata 372 | elif self.mode == 'GClassifier': 373 | # select data 374 | design = self.data[idx].design 375 | begin = self.ptr[design] 376 | lenth = self.file_num[design] 377 | select_pair = np.random.randint(begin, begin + lenth, 2) 378 | data1 = self.data[select_pair[0]] 379 | data2 = self.data[select_pair[1]] 380 | # get weight 381 | mask1, mask5, mask0 = (data1.y > data2.y), (data1.y == data2.y), (data1.y < data2.y) 382 | target = mask1 * 1 + 0.5 * mask5 383 | lambdd1 = self.lambdda[design][:, select_pair[0] - begin] 384 | lambdd2 = self.lambdda[design][:, select_pair[1] - begin] 385 | w = (lambdd1 - lambdd2).abs().view(1, -1) 386 | # 387 | netlist = self.netlist[design] 388 | bidata = BipartiteData(x=netlist.node_attr, edge_index=netlist.edge_index, y1=data1.y, y2=data2.y, 389 | y=target, pic1 = data1.pic, pic2 = data2.pic, edge_weight=netlist.edge_weight, 390 | pin_offset=netlist.pin_offset, macro_index=netlist.macro_index, design = design, w=w, 391 | macro_num=netlist.macro_index.shape[0], macro_pos1 = data1.macro_pos, macro_pos2 = data2.macro_pos) 392 | return bidata 393 | else :return self.data[idx] 394 | 395 | 396 | def pre_trans(): 397 | return 398 | 399 | if __name__=='__main__': 400 | #from torch.utils.tensorboard import writer 401 | #logger = writer.SummaryWriter('visual') 402 | Set = PlainClusterSet('data', mode='HGNN') 403 | target = 1 404 | for m, design in enumerate(Set.raw_file_names): 405 | labels = [] 406 | positions = [] 407 | for i in range(Set.ptr[design], Set.file_num[design] + Set.ptr[design]): 408 | data = Set[i] 409 | pos = data.x[data.macro_index,:2] 410 | positions.append(pos.view(-1)) 411 | labels.append(torch.load(osp.join(Set.processed_dir, 'data_{}.pt'.format(i))).y[:, target].view(-1).item()) 412 | a = torch.stack(positions, dim=0) 413 | print(design, np.std(labels)/np.mean(labels)) 414 | #logger.add_embedding(a, metadata=torch.tensor(labels), global_step=m) 415 | 416 | 417 | 418 | 419 | -------------------------------------------------------------------------------- /src/egnn.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | import torch 3 | from torch import nn, einsum, broadcast_tensors 4 | import torch.nn.functional as F 5 | import pdb 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helper functions 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | def safe_div(num, den, eps = 1e-8): 15 | res = num.div(den.clamp(min = eps)) 16 | res.masked_fill_(den == 0, 0.) 17 | return res 18 | 19 | def batched_index_select(values, indices, dim = 1): 20 | value_dims = values.shape[(dim + 1):] 21 | values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices)) 22 | indices = indices[(..., *((None,) * len(value_dims)))] 23 | indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims) 24 | value_expand_len = len(indices_shape) - (dim + 1) 25 | values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)] 26 | 27 | value_expand_shape = [-1] * len(values.shape) 28 | expand_slice = slice(dim, (dim + value_expand_len)) 29 | value_expand_shape[expand_slice] = indices.shape[expand_slice] 30 | values = values.expand(*value_expand_shape) 31 | 32 | dim += value_expand_len 33 | return values.gather(dim, indices) 34 | 35 | def fourier_encode_dist(x, num_encodings = 4, include_self = True): 36 | x = x.unsqueeze(-1) 37 | device, dtype, orig_x = x.device, x.dtype, x 38 | scales = 2 ** torch.arange(num_encodings, device = device, dtype = dtype) 39 | x = x / scales 40 | x = torch.cat([x.sin(), x.cos()], dim=-1) 41 | x = torch.cat((x, orig_x), dim = -1) if include_self else x 42 | return x 43 | 44 | def embedd_token(x, dims, layers): 45 | stop_concat = -len(dims) 46 | to_embedd = x[:, stop_concat:].long() 47 | for i,emb_layer in enumerate(layers): 48 | # the portion corresponding to `to_embedd` part gets dropped 49 | x = torch.cat([ x[:, :stop_concat], 50 | emb_layer( to_embedd[:, i] ) 51 | ], dim=-1) 52 | stop_concat = x.shape[-1] 53 | return x 54 | 55 | # swish activation fallback 56 | 57 | class Swish_(nn.Module): 58 | def forward(self, x): 59 | return x * x.sigmoid() 60 | 61 | SiLU = nn.SiLU if hasattr(nn, 'SiLU') else Swish_ 62 | 63 | # helper classes 64 | 65 | # this follows the same strategy for normalization as done in SE3 Transformers 66 | # https://github.com/lucidrains/se3-transformer-pytorch/blob/main/se3_transformer_pytorch/se3_transformer_pytorch.py#L95 67 | 68 | class CoorsNorm(nn.Module): 69 | def __init__(self, eps = 1e-8, scale_init = 1.): 70 | super().__init__() 71 | self.eps = eps 72 | scale = torch.zeros(1).fill_(scale_init) 73 | self.scale = nn.Parameter(scale) 74 | 75 | def forward(self, coors): 76 | norm = coors.norm(dim = -1, keepdim = True) 77 | normed_coors = coors / norm.clamp(min = self.eps) 78 | return normed_coors * self.scale 79 | 80 | # global linear attention 81 | 82 | class Attention(nn.Module): 83 | def __init__(self, dim, heads = 8, dim_head = 64): 84 | super().__init__() 85 | inner_dim = heads * dim_head 86 | self.heads = heads 87 | self.scale = dim_head ** -0.5 88 | 89 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 90 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 91 | self.to_out = nn.Linear(inner_dim, dim) 92 | 93 | def forward(self, x, context, mask = None): 94 | h = self.heads 95 | 96 | q = self.to_q(x) 97 | kv = self.to_kv(context).chunk(2, dim = -1) 98 | 99 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv)) 100 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 101 | 102 | if exists(mask): 103 | mask_value = -torch.finfo(dots.dtype).max 104 | mask = rearrange(mask, 'b n -> b () () n') 105 | dots.masked_fill_(~mask, mask_value) 106 | 107 | attn = dots.softmax(dim = -1) 108 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 109 | 110 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 111 | return self.to_out(out) 112 | 113 | class GlobalLinearAttention(nn.Module): 114 | def __init__( 115 | self, 116 | *, 117 | dim, 118 | heads = 8, 119 | dim_head = 64 120 | ): 121 | super().__init__() 122 | self.norm_seq = nn.LayerNorm(dim) 123 | self.norm_queries = nn.LayerNorm(dim) 124 | self.attn1 = Attention(dim, heads, dim_head) 125 | self.attn2 = Attention(dim, heads, dim_head) 126 | 127 | self.ff = nn.Sequential( 128 | nn.LayerNorm(dim), 129 | nn.Linear(dim, dim * 4), 130 | nn.GELU(), 131 | nn.Linear(dim * 4, dim) 132 | ) 133 | 134 | def forward(self, x, queries, mask = None): 135 | res_x, res_queries = x, queries 136 | x, queries = self.norm_seq(x), self.norm_queries(queries) 137 | 138 | induced = self.attn1(queries, x, mask = mask) 139 | out = self.attn2(x, induced) 140 | 141 | x = out + res_x 142 | queries = induced + res_queries 143 | 144 | x = self.ff(x) + x 145 | return x, queries 146 | 147 | # classes 148 | 149 | class EGNN(nn.Module): 150 | def __init__( 151 | self, 152 | dim, 153 | edge_dim = 0, 154 | m_dim = 16, 155 | fourier_features = 0, 156 | num_nearest_neighbors = 0, 157 | dropout = 0.0, 158 | init_eps = 1e-3, 159 | norm_feats = False, 160 | norm_coors = False, 161 | norm_coors_scale_init = 1e-2, 162 | update_feats = True, 163 | update_coors = True, 164 | only_sparse_neighbors = False, 165 | valid_radius = float('inf'), 166 | m_pool_method = 'sum', 167 | soft_edges = False, 168 | coor_weights_clamp_value = None, 169 | use_rel_coord = True, 170 | act = nn.SiLU, 171 | ): 172 | super().__init__() 173 | assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean' 174 | assert update_feats or update_coors, 'you must update either features, coordinates, or both' 175 | 176 | self.fourier_features = fourier_features 177 | self.use_rel_coord = use_rel_coord 178 | edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1 + 4 * use_rel_coord 179 | dropout = nn.Dropout(dropout, inplace=True) if dropout > 0 else nn.Identity() 180 | 181 | self.edge_mlp = nn.Sequential( 182 | nn.Linear(edge_input_dim, m_dim), 183 | dropout, 184 | act, 185 | ) 186 | 187 | self.edge_gate = nn.Sequential( 188 | nn.Linear(m_dim, 1), 189 | nn.Sigmoid() 190 | ) if soft_edges else None 191 | 192 | self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity() 193 | self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity() 194 | 195 | self.m_pool_method = m_pool_method 196 | 197 | self.node_mlp = nn.Sequential( 198 | nn.Linear(dim + m_dim, dim), 199 | dropout, 200 | act, 201 | ) if update_feats else None 202 | 203 | self.coors_mlp = nn.Sequential( 204 | nn.Linear(m_dim, 1), 205 | dropout, 206 | act, 207 | ) if update_coors else None 208 | 209 | self.num_nearest_neighbors = num_nearest_neighbors 210 | self.only_sparse_neighbors = only_sparse_neighbors 211 | self.valid_radius = valid_radius 212 | 213 | self.coor_weights_clamp_value = coor_weights_clamp_value 214 | 215 | self.init_eps = init_eps 216 | self.apply(self.init_) 217 | 218 | def init_(self, module): 219 | if type(module) in {nn.Linear}: 220 | # seems to be needed to keep the network from exploding to NaN with greater depths 221 | nn.init.normal_(module.weight, std = self.init_eps) 222 | 223 | def forward(self, feats, coors, edges = None, mask = None, adj_mat = None, num_nearest = 0): 224 | b, n, d, device, fourier_features, num_nearest, valid_radius, only_sparse_neighbors = *feats.shape, feats.device, self.fourier_features, num_nearest, self.valid_radius, self.only_sparse_neighbors 225 | 226 | if exists(mask): 227 | num_nodes = mask.sum(dim = -1) 228 | 229 | use_nearest = num_nearest > 0 or only_sparse_neighbors 230 | 231 | rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d') 232 | rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True) 233 | 234 | i = j = n 235 | 236 | if use_nearest: 237 | ranking = rel_dist[..., 0].clone() 238 | 239 | if exists(mask): 240 | rank_mask = mask[:, :, None] * mask[:, None, :] 241 | ranking.masked_fill_(~rank_mask, 1e5) 242 | 243 | if exists(adj_mat): 244 | if len(adj_mat.shape) == 2: 245 | adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b) 246 | 247 | if only_sparse_neighbors: 248 | num_nearest = int(adj_mat.float().sum(dim = -1).max().item()) 249 | valid_radius = 0 250 | 251 | self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j') 252 | 253 | adj_mat = adj_mat.masked_fill(self_mask, False) 254 | ranking.masked_fill_(self_mask, -1.) 255 | ranking.masked_fill_(adj_mat, 0.) 256 | 257 | nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False) 258 | 259 | nbhd_mask = nbhd_ranking <= valid_radius 260 | 261 | rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2) 262 | rel_dist = batched_index_select(rel_dist, nbhd_indices, dim = 2) 263 | 264 | if exists(edges): 265 | edges = batched_index_select(edges, nbhd_indices, dim = 2) 266 | 267 | j = num_nearest 268 | 269 | if fourier_features > 0: 270 | rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features) 271 | rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d') 272 | 273 | if use_nearest: 274 | feats_j = batched_index_select(feats, nbhd_indices, dim = 1) 275 | else: 276 | feats_j = rearrange(feats, 'b j d -> b () j d') 277 | 278 | feats_i = rearrange(feats, 'b i d -> b i () d') 279 | feats_i, feats_j = broadcast_tensors(feats_i, feats_j) 280 | if self.use_rel_coord: 281 | rel_dist = torch.cat((rel_dist, rel_coors), dim=-1) 282 | edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1) 283 | 284 | if exists(edges): 285 | edge_input = torch.cat((edge_input, edges), dim = -1) 286 | 287 | m_ij = self.edge_mlp(edge_input) 288 | 289 | if exists(self.edge_gate): 290 | m_ij = m_ij * self.edge_gate(m_ij) 291 | 292 | if exists(mask): 293 | mask_i = rearrange(mask, 'b i -> b i ()') 294 | 295 | if use_nearest: 296 | mask_j = batched_index_select(mask, nbhd_indices, dim = 1) 297 | mask = (mask_i * mask_j) & nbhd_mask 298 | else: 299 | mask_j = rearrange(mask, 'b j -> b () j') 300 | mask = mask_i * mask_j 301 | 302 | if exists(self.coors_mlp): 303 | coor_weights = self.coors_mlp(m_ij) 304 | coor_weights = rearrange(coor_weights, 'b i j () -> b i j') 305 | 306 | rel_coors = self.coors_norm(rel_coors) 307 | 308 | if exists(mask): 309 | coor_weights.masked_fill_(~mask, 0.) 310 | 311 | if exists(self.coor_weights_clamp_value): 312 | clamp_value = self.coor_weights_clamp_value 313 | coor_weights.clamp_(min = -clamp_value, max = clamp_value) 314 | 315 | coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors 316 | else: 317 | coors_out = coors 318 | 319 | if exists(self.node_mlp): 320 | if exists(mask): 321 | m_ij_mask = rearrange(mask, '... -> ... ()') 322 | m_ij = m_ij.masked_fill(~m_ij_mask, 0.) 323 | 324 | if self.m_pool_method == 'mean': 325 | if exists(mask): 326 | # masked mean 327 | mask_sum = m_ij_mask.sum(dim = -2) 328 | m_i = safe_div(m_ij.sum(dim = -2), mask_sum) 329 | else: 330 | m_i = m_ij.mean(dim = -2) 331 | 332 | elif self.m_pool_method == 'sum': 333 | m_i = m_ij.sum(dim = -2) 334 | 335 | normed_feats = self.node_norm(feats) 336 | node_mlp_input = torch.cat((normed_feats, m_i), dim = -1) 337 | node_out = self.node_mlp(node_mlp_input) + feats 338 | else: 339 | node_out = feats 340 | 341 | return node_out, coors_out 342 | 343 | class EGNN_DENSE(nn.Module): 344 | def __init__( 345 | self, 346 | dim, 347 | edge_dim = 0, 348 | m_dim = 16, 349 | fourier_features = 0, 350 | num_nearest_neighbors = 0, 351 | dropout = 0.0, 352 | init_eps = 1e-3, 353 | norm_feats = False, 354 | norm_coors = False, 355 | norm_coors_scale_init = 1e-2, 356 | update_feats = True, 357 | update_coors = True, 358 | only_sparse_neighbors = False, 359 | valid_radius = float('inf'), 360 | m_pool_method = 'sum', 361 | soft_edges = False, 362 | coor_weights_clamp_value = None, 363 | use_rel_coord = False, 364 | act = nn.SiLU, 365 | ): 366 | super().__init__() 367 | assert m_pool_method in {'sum', 'mean'}, 'pool method must be either sum or mean' 368 | assert update_feats or update_coors, 'you must update either features, coordinates, or both' 369 | 370 | self.fourier_features = fourier_features 371 | self.use_rel_coord = use_rel_coord 372 | edge_input_dim = (fourier_features * 2) + (dim * 2) + edge_dim + 1 + 4 * self.use_rel_coord 373 | dropout = nn.Dropout(dropout, inplace=True) if dropout > 0 else nn.Identity() 374 | 375 | self.edge_mlp = nn.Sequential( 376 | nn.Linear(edge_input_dim, edge_input_dim * 2), 377 | dropout, 378 | SiLU(), 379 | nn.Linear(edge_input_dim * 2, m_dim), 380 | SiLU() 381 | ) 382 | 383 | self.edge_gate = nn.Sequential( 384 | nn.Linear(m_dim, 1), 385 | nn.Sigmoid() 386 | ) if soft_edges else None 387 | 388 | self.node_norm = nn.LayerNorm(dim) if norm_feats else nn.Identity() 389 | self.coors_norm = CoorsNorm(scale_init = norm_coors_scale_init) if norm_coors else nn.Identity() 390 | 391 | self.m_pool_method = m_pool_method 392 | 393 | self.node_mlp = nn.Sequential( 394 | nn.Linear(dim + m_dim, dim * 2), 395 | dropout, 396 | SiLU(), 397 | nn.Linear(dim * 2, dim), 398 | ) if update_feats else None 399 | 400 | self.coors_mlp = nn.Sequential( 401 | nn.Linear(m_dim, m_dim * 4), 402 | dropout, 403 | SiLU(), 404 | nn.Linear(m_dim * 4, 1) 405 | ) if update_coors else None 406 | 407 | self.num_nearest_neighbors = num_nearest_neighbors 408 | self.only_sparse_neighbors = only_sparse_neighbors 409 | self.valid_radius = valid_radius 410 | 411 | self.coor_weights_clamp_value = coor_weights_clamp_value 412 | 413 | self.init_eps = init_eps 414 | self.apply(self.init_) 415 | 416 | def init_(self, module): 417 | if type(module) in {nn.Linear}: 418 | # seems to be needed to keep the network from exploding to NaN with greater depths 419 | nn.init.normal_(module.weight, std = self.init_eps) 420 | 421 | def forward(self, feats, coors, edges = None, mask = None, adj_mat = None, num_nearest = 0): 422 | b, n, d, device, fourier_features, num_nearest, valid_radius, only_sparse_neighbors = *feats.shape, feats.device, self.fourier_features, num_nearest, self.valid_radius, self.only_sparse_neighbors 423 | 424 | if exists(mask): 425 | num_nodes = mask.sum(dim = -1) 426 | 427 | use_nearest = num_nearest > 0 or only_sparse_neighbors 428 | 429 | rel_coors = rearrange(coors, 'b i d -> b i () d') - rearrange(coors, 'b j d -> b () j d') 430 | rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True) 431 | 432 | i = j = n 433 | 434 | if use_nearest: 435 | ranking = rel_dist[..., 0].clone() 436 | 437 | if exists(mask): 438 | rank_mask = mask[:, :, None] * mask[:, None, :] 439 | ranking.masked_fill_(~rank_mask, 1e5) 440 | 441 | if exists(adj_mat): 442 | if len(adj_mat.shape) == 2: 443 | adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b) 444 | 445 | if only_sparse_neighbors: 446 | num_nearest = int(adj_mat.float().sum(dim = -1).max().item()) 447 | valid_radius = 0 448 | 449 | self_mask = rearrange(torch.eye(n, device = device, dtype = torch.bool), 'i j -> () i j') 450 | 451 | adj_mat = adj_mat.masked_fill(self_mask, False) 452 | ranking.masked_fill_(self_mask, -1.) 453 | ranking.masked_fill_(adj_mat, 0.) 454 | 455 | nbhd_ranking, nbhd_indices = ranking.topk(num_nearest, dim = -1, largest = False) 456 | 457 | nbhd_mask = nbhd_ranking <= valid_radius 458 | 459 | rel_coors = batched_index_select(rel_coors, nbhd_indices, dim = 2) 460 | rel_dist = batched_index_select(rel_dist, nbhd_indices, dim = 2) 461 | 462 | if exists(edges): 463 | edges = batched_index_select(edges, nbhd_indices, dim = 2) 464 | 465 | j = num_nearest 466 | 467 | if fourier_features > 0: 468 | rel_dist = fourier_encode_dist(rel_dist, num_encodings = fourier_features) 469 | rel_dist = rearrange(rel_dist, 'b i j () d -> b i j d') 470 | 471 | if use_nearest: 472 | feats_j = batched_index_select(feats, nbhd_indices, dim = 1) 473 | else: 474 | feats_j = rearrange(feats, 'b j d -> b () j d') 475 | 476 | feats_i = rearrange(feats, 'b i d -> b i () d') 477 | feats_i, feats_j = broadcast_tensors(feats_i, feats_j) 478 | if self.use_rel_coord: 479 | rel_dist = torch.cat((rel_dist, rel_coors), dim=-1) 480 | edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1) 481 | 482 | if exists(edges): 483 | edge_input = torch.cat((edge_input, edges), dim = -1) 484 | 485 | m_ij = self.edge_mlp(edge_input) 486 | 487 | if exists(self.edge_gate): 488 | m_ij = m_ij * self.edge_gate(m_ij) 489 | 490 | if exists(mask): 491 | mask_i = rearrange(mask, 'b i -> b i ()') 492 | 493 | if use_nearest: 494 | mask_j = batched_index_select(mask, nbhd_indices, dim = 1) 495 | mask = (mask_i * mask_j) & nbhd_mask 496 | else: 497 | mask_j = rearrange(mask, 'b j -> b () j') 498 | mask = mask_i * mask_j 499 | 500 | if exists(self.coors_mlp): 501 | coor_weights = self.coors_mlp(m_ij) 502 | coor_weights = rearrange(coor_weights, 'b i j () -> b i j') 503 | 504 | rel_coors = self.coors_norm(rel_coors) 505 | 506 | if exists(mask): 507 | coor_weights.masked_fill_(~mask, 0.) 508 | 509 | if exists(self.coor_weights_clamp_value): 510 | clamp_value = self.coor_weights_clamp_value 511 | coor_weights.clamp_(min = -clamp_value, max = clamp_value) 512 | 513 | coors_out = einsum('b i j, b i j c -> b i c', coor_weights, rel_coors) + coors 514 | else: 515 | coors_out = coors 516 | 517 | if exists(self.node_mlp): 518 | if exists(mask): 519 | m_ij_mask = rearrange(mask, '... -> ... ()') 520 | m_ij = m_ij.masked_fill(~m_ij_mask, 0.) 521 | 522 | if self.m_pool_method == 'mean': 523 | if exists(mask): 524 | # masked mean 525 | mask_sum = m_ij_mask.sum(dim = -2) 526 | m_i = safe_div(m_ij.sum(dim = -2), mask_sum) 527 | else: 528 | m_i = m_ij.mean(dim = -2) 529 | 530 | elif self.m_pool_method == 'sum': 531 | m_i = m_ij.sum(dim = -2) 532 | 533 | normed_feats = self.node_norm(feats) 534 | node_mlp_input = torch.cat((normed_feats, m_i), dim = -1) 535 | node_out = self.node_mlp(node_mlp_input) + feats 536 | else: 537 | node_out = feats 538 | 539 | return node_out, coors_out -------------------------------------------------------------------------------- /src/hypermodel.py: -------------------------------------------------------------------------------- 1 | from genericpath import exists 2 | import torch 3 | import numpy as np 4 | from torch_geometric.nn import global_mean_pool as gap 5 | from torch_geometric.nn import global_max_pool as gmp 6 | from torch_geometric.nn import global_add_pool as gsp 7 | from torch_geometric.nn import SAGPooling as Pool 8 | from torch_geometric.nn import SAGEConv, GATv2Conv, GINConv, ResGatedGraphConv 9 | import torch.nn.functional as F 10 | import pdb 11 | from torch import Tensor, dropout 12 | from torch_scatter import scatter, scatter_mean 13 | import torch.nn as nn 14 | from torchvision import models 15 | import src.egnn as egnn_models 16 | from torch.nn.utils.rnn import pad_sequence 17 | 18 | def split_batch(x, ptr): 19 | if len(ptr) == 1: return [x] 20 | split_x = list(torch.split_with_sizes(x, ptr.cpu().numpy().tolist(), dim=0)) 21 | return split_x 22 | 23 | def unpad_sequence(padded_sequences, masks): 24 | unpadded_sequences = [] 25 | 26 | for seq, mask in zip(padded_sequences, masks): 27 | unpacked_seq = seq[mask] 28 | unpadded_sequences.append(unpacked_seq) 29 | 30 | return unpadded_sequences 31 | 32 | class CNN(torch.nn.Module): 33 | """ cnn baseline """ 34 | def __init__(self,args): 35 | super(CNN, self).__init__() 36 | self.args = args 37 | self.dropout = args.dropout_ratio 38 | self.num_classes = 1 39 | self.net = models.vgg11(pretrained=True) 40 | self.net.classifier = nn.Sequential( 41 | nn.Linear(512 * 7 * 7, 4096), 42 | nn.ReLU(True), 43 | nn.Dropout(self.dropout), 44 | nn.Linear(4096, 4096), 45 | nn.ReLU(True), 46 | nn.Dropout(self.dropout), 47 | nn.Linear(4096, 1000), 48 | nn.ReLU(True), 49 | nn.Dropout(self.dropout), 50 | nn.Linear(1000, 1), 51 | ) 52 | 53 | def forward(self, data): 54 | 55 | x = data.density 56 | x = self.net(x) 57 | return x.view(-1) 58 | 59 | def predict(self, data): 60 | x = data.density 61 | x = self.net(x) 62 | return x.view(-1) 63 | 64 | class RClassifier(torch.nn.Module): 65 | """ cnn baseline """ 66 | def __init__(self,args): 67 | super(RClassifier, self).__init__() 68 | self.args = args 69 | self.num_classes = 1 70 | self.dropout = args.dropout_ratio 71 | self.net = models.vgg11(pretrained=True) 72 | self.net.classifier = nn.Sequential( 73 | nn.Linear(512 * 7 * 7, 4096), 74 | nn.ReLU(True), 75 | nn.Dropout(self.dropout), 76 | nn.Linear(4096, 1), 77 | ) 78 | 79 | self.out = nn.Sigmoid() 80 | 81 | def forward(self, data): 82 | x = data.density 83 | x = self.net(x) 84 | return x.view(-1) 85 | 86 | def predict(self, data): 87 | x = data.density 88 | return self.net(x).view(-1) 89 | 90 | class Classifier(torch.nn.Module): 91 | """ cnn baseline """ 92 | def __init__(self,args): 93 | super(Classifier, self).__init__() 94 | self.args = args 95 | self.num_classes = 1 96 | self.dropout = args.dropout_ratio 97 | self.net = models.vgg11(pretrained=False) 98 | self.net.classifier = nn.Sequential( 99 | nn.Linear(512 * 7 * 7, 4096), 100 | nn.ReLU(True), 101 | nn.Dropout(self.dropout), 102 | nn.Linear(4096, 1), 103 | ) 104 | 105 | self.out = nn.Sigmoid() 106 | 107 | def forward(self, data): 108 | x = data.density 109 | index = torch.arange(0, x.shape[0], 2).to(x.device) 110 | x0 = torch.index_select(x,dim=0,index=index) 111 | x1 = torch.index_select(x,dim=0,index=(index+1)) 112 | s0 = self.net(x0) 113 | s1 = self.net(x1) 114 | 115 | x = self.out(s0 - s1) 116 | return x.view(-1) 117 | 118 | def predict(self, data): 119 | x = data.density 120 | return self.net(x).view(-1) 121 | 122 | class GClassifier(torch.nn.Module): 123 | """ cnn baseline """ 124 | def __init__(self, args): 125 | super(GClassifier, self).__init__() 126 | self.args = args 127 | self.net = EHGNN(args=args) 128 | self.out = nn.Sigmoid() 129 | 130 | def forward(self, data): 131 | s0 = self.net.predict(data, data.macro_pos1) 132 | s1 = self.net.predict(data, data.macro_pos2) 133 | x = self.out(s0 - s1) 134 | return x.view(-1) 135 | 136 | def predict(self, data): 137 | return self.net(data).view(-1) 138 | 139 | def test(self, data): 140 | return self.net(data).view(-1) 141 | 142 | 143 | class V2PLayer(torch.nn.Module): 144 | def __init__(self): 145 | super(V2PLayer, self).__init__() 146 | 147 | def forward(self, node_feat : Tensor, pin_feat : Tensor, edge_index : Tensor)->Tensor: 148 | # creare edges 149 | node_pin_feat = torch.index_select(node_feat, dim=0, index=edge_index[0]) 150 | pin_feat = torch.cat([node_pin_feat, pin_feat], dim=-1) 151 | return pin_feat 152 | 153 | 154 | class P2VLayer(torch.nn.Module): 155 | def __init__(self): 156 | super(P2VLayer, self).__init__() 157 | 158 | def forward(self, pin_feat : Tensor, edge_index : Tensor)->Tensor: 159 | # creare edges 160 | node_feat = scatter_mean(pin_feat, edge_index[0], dim=0) 161 | return node_feat 162 | 163 | 164 | class HyperGATConv(torch.nn.Module): 165 | def __init__(self,in_nch=6, in_pch=2, in_ech=1, nhid=16, out_ch=16, dropout=0, leaky_relu=0.1): 166 | super(HyperGATConv, self).__init__() 167 | self.in_node_ch = in_nch 168 | self.in_pin_ch = in_pch 169 | self.in_edge_ch = in_ech 170 | self.nhid = nhid 171 | self.out_ch = out_ch 172 | self.dropout = dropout 173 | self.leaky_relu = leaky_relu 174 | 175 | self.Dropout = nn.Dropout(p=self.dropout, inplace=True) 176 | self.act = nn.LeakyReLU(self.leaky_relu) 177 | self.v2p = V2PLayer() 178 | self.p2e = GATv2Conv(in_channels=(in_nch + in_pch, in_ech), out_channels=nhid, dropout=dropout) 179 | self.e2p = SAGEConv(in_channels=(nhid, in_nch + in_pch), out_channels=out_ch) 180 | self.p2v = P2VLayer() 181 | 182 | def forward(self, node_feat : Tensor, pin_feat : Tensor, edge_index : Tensor, edge_attr : Tensor): 183 | # creare edges 184 | pins_index = torch.arange(0, pin_feat.shape[0], 1).to(node_feat.device) 185 | if edge_attr.dim() == 1: edge_attr = edge_attr.view(-1, 1) 186 | pin_edge = torch.stack((pins_index, edge_index[1]),dim=0) 187 | edge_pin = torch.stack((edge_index[1], pins_index),dim=0) 188 | # forward, v2p 189 | pin_feat = self.v2p(node_feat = node_feat, pin_feat = pin_feat, edge_index = edge_index) 190 | # p2e 191 | pin_feat = self.Dropout(pin_feat) 192 | edge_feat = self.p2e(x=(pin_feat, edge_attr), edge_index=pin_edge) 193 | edge_feat = self.act(edge_feat) 194 | # e2p 195 | edge_feat = self.Dropout(edge_feat) 196 | pin_feat = self.e2p(x=(edge_feat, pin_feat), edge_index=edge_pin, size=(edge_attr.shape[0], pin_feat.shape[0])) 197 | pin_feat = self.act(pin_feat) 198 | # p2v 199 | node_feat = self.p2v(pin_feat=pin_feat, edge_index=edge_index) 200 | return node_feat, pin_feat 201 | 202 | 203 | class EGNNet(torch.nn.Module): 204 | """ plain gnn baseline """ 205 | def __init__(self, layers = 3, feat_dim=32, pos_dim=2, nhid=32, position_encoding = 0, num_nearest_neighbors=0, dropout = 0.,edge_dim=0, args=None): 206 | super(EGNNet, self).__init__() 207 | self.layers = layers 208 | self.feat_dim = feat_dim 209 | self.pos_dim = pos_dim 210 | self.nhid = nhid 211 | self.position_encoding = position_encoding 212 | self.num_nearest_neighbors = num_nearest_neighbors 213 | self.dropout = dropout 214 | self.edge_dim = edge_dim 215 | self.embedd = nn.Sequential( 216 | nn.Linear(feat_dim, nhid), 217 | nn.Dropout(p=dropout, inplace=True), 218 | nn.LeakyReLU(negative_slope=0.1), 219 | ) 220 | self.convs = nn.ModuleList([]) 221 | 222 | base_model = getattr(egnn_models, args.base_model) 223 | for i in range(layers): 224 | self.convs.append(base_model( 225 | dim = nhid, # input dimension 226 | edge_dim = edge_dim, # dimension of the edges, if exists, should be > 0 227 | m_dim = nhid, # hidden model dimension 228 | fourier_features = position_encoding, # number of fourier features for encoding of relative distance - defaults to none as in paper 229 | num_nearest_neighbors = num_nearest_neighbors, # cap the number of neighbors doing message passing by relative distance 230 | dropout = dropout, # dropout 231 | norm_feats = True if args.model == 'GClassifier' else False, # whether to layernorm the features 232 | norm_coors = True, # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper 233 | update_feats = True, # whether to update features - you can build a layer that only updates one or the other 234 | update_coors = True, # whether ot update coordinates 235 | only_sparse_neighbors = False, # using this would only allow message passing along adjacent neighbors, using the adjacency matrix passed in 236 | valid_radius = float('inf'), # the valid radius each node considers for message passing 237 | m_pool_method = 'sum', # whether to mean or sum pool for output node representation 238 | soft_edges = True, # extra GLU on the edges, purportedly helps stabilize the network in updated version of the paper 239 | coor_weights_clamp_value = None, # clamping of the coordinate updates, again, for stabilization purposes 240 | act=nn.LeakyReLU(negative_slope=0.1), 241 | use_rel_coord = True if (args.model == 'GClassifier' and args.label[0] == '3') else False 242 | )) 243 | 244 | 245 | def forward(self, feat_ : Tensor, coor_ : Tensor, batch : Tensor, edge_index : Tensor=None, edge_attr : Tensor=None): 246 | if self.embedd is not None : feat_ = self.embedd(feat_) 247 | # if batch.max() <= 200 : 248 | # feats = split_batch(feat_, batch) 249 | # feats = pad_sequence(feats, batch_first=True) 250 | # coors = split_batch(coor_, batch) 251 | # coors = pad_sequence(coors, batch_first=True) 252 | # masks = split_batch(feat_.new_ones(feat_.shape[0], dtype=bool), batch) 253 | # masks = pad_sequence(masks, batch_first=True) 254 | 255 | # for i, conv in enumerate(self.convs): 256 | # feats, coors = conv(feats, coors, mask=masks) 257 | 258 | # feats = unpad_sequence(feats, masks) 259 | # feats = torch.cat(feats, dim=0) 260 | # else: 261 | feats = split_batch(feat_, batch) 262 | coors = split_batch(coor_, batch) 263 | feats = [p.view(1, -1, self.feat_dim) for p in feats] 264 | coors = [p.view(1, -1, self.pos_dim) for p in coors] 265 | zipped = list(zip(feats, coors)) 266 | 267 | for i, conv in enumerate(self.convs): 268 | zipped = list(map(lambda p: conv(p[0], p[1], num_nearest=128 if p[0].shape[1] > 128 else 0), zipped)) 269 | 270 | feats = [p[0].view(-1, self.feat_dim) for p in zipped] 271 | feats = torch.cat(feats, dim=0) 272 | return feats 273 | 274 | 275 | class HGNN(torch.nn.Module): 276 | """ plain gnn baseline """ 277 | def __init__(self,args=None): 278 | super(HGNN, self).__init__() 279 | self.args = args 280 | self.out_ch = 1 281 | self.num_node_features = args.num_node_features 282 | self.num_pin_features = args.num_pin_features 283 | self.num_edge_features = args.num_edge_features 284 | self.nhid = args.nhid 285 | self.negative_slope = 0.1 286 | self.dropout_ratio = args.dropout_ratio 287 | self.conv_layers = args.layers 288 | self.skip_cnt = args.skip_cnt 289 | 290 | self.convs = nn.ModuleList([HyperGATConv(in_nch=self.num_node_features,in_pch=self.num_pin_features, 291 | in_ech=self.num_edge_features, nhid=self.nhid, 292 | out_ch=self.nhid, dropout=self.dropout_ratio)]) 293 | for i in range(self.conv_layers - 1): 294 | self.convs.append(HyperGATConv(in_nch=self.nhid,in_pch=self.nhid, 295 | in_ech=self.num_edge_features, nhid=self.nhid, 296 | out_ch=self.nhid, dropout=self.dropout_ratio)) 297 | self.mlp = nn.Sequential( 298 | nn.Linear(self.nhid * 2, self.nhid), 299 | nn.LeakyReLU(negative_slope=self.negative_slope,inplace=True), 300 | nn.Dropout(p=self.dropout_ratio), 301 | nn.Linear(self.nhid, self.nhid), 302 | nn.LeakyReLU(negative_slope=self.negative_slope,inplace=True), 303 | nn.Linear(self.nhid, self.out_ch)) 304 | 305 | def forward(self, data): 306 | #pdb.set_trace() 307 | x, edge_index = data.x, data.edge_index 308 | pin_feat, edge_weight = data.pin_offset, data.edge_weight 309 | batch, macro_index = data.batch, data.macro_index 310 | # add macro pos 311 | macro_batch = batch[macro_index] 312 | # model forward 313 | for i in range(self.conv_layers): 314 | last_x, last_pin_feat = x, pin_feat 315 | x, pin_feat = self.convs[i](x, pin_feat, edge_index, edge_weight) 316 | if self.skip_cnt and i > 0: x, pin_feat = x + last_x, pin_feat + last_pin_feat 317 | # 318 | macro_feature = x[macro_index] 319 | x = torch.cat([gap(macro_feature, macro_batch), gap(x, batch)], dim=-1) 320 | # mlp 321 | x = self.mlp(x) 322 | return x 323 | 324 | 325 | class EHGNN(torch.nn.Module): 326 | """ egnn + gnn """ 327 | def __init__(self,args=None): 328 | super(EHGNN, self).__init__() 329 | self.args = args 330 | self.out_ch = 1 331 | self.num_node_features = args.num_node_features 332 | self.num_pin_features = args.num_pin_features 333 | self.num_edge_features = args.num_edge_features 334 | self.nhid = args.nhid 335 | self.negative_slope = 0.1 336 | self.dropout_ratio = args.dropout_ratio 337 | self.conv_layers = args.layers 338 | self.skip_cnt = args.skip_cnt 339 | self.pos_encode = args.pos_encode 340 | self.pos_dim = 4 341 | self.num_egnn = args.egnn_layers 342 | self.egnn_dim = args.egnn_nhid 343 | 344 | self.convs = nn.ModuleList([HyperGATConv(in_nch=self.num_node_features,in_pch=self.num_pin_features, 345 | in_ech=self.num_edge_features, nhid=self.nhid, 346 | out_ch=self.nhid, dropout=self.dropout_ratio)]) 347 | for i in range(self.conv_layers - 1): 348 | self.convs.append(HyperGATConv(in_nch=self.nhid,in_pch=self.nhid, 349 | in_ech=self.num_edge_features, nhid=self.nhid, 350 | out_ch=self.nhid, dropout=self.dropout_ratio)) 351 | 352 | self.posnet = EGNNet(self.num_egnn, self.nhid, self.pos_dim, self.egnn_dim, position_encoding=self.pos_encode, dropout=self.dropout_ratio, args=args) 353 | 354 | self.mlp = nn.Sequential( 355 | nn.Linear(self.egnn_dim, self.nhid), 356 | nn.LeakyReLU(negative_slope=self.negative_slope), 357 | nn.Dropout(p=self.dropout_ratio, inplace=True), 358 | nn.Linear(self.nhid, self.nhid), 359 | nn.LeakyReLU(negative_slope=self.negative_slope), 360 | nn.Linear(self.nhid, self.out_ch)) 361 | 362 | def forward(self, data): 363 | 364 | x, edge_index = data.x, data.edge_index 365 | pin_feat, edge_weight = data.pin_offset, data.edge_weight 366 | batch, macro_index = data.batch, data.macro_index 367 | # add macro pos 368 | macro_batch = batch[macro_index] 369 | macro_pos = data.macro_pos 370 | # model forward 371 | for i, conv in enumerate(self.convs): 372 | last_x, last_pin_feat = x, pin_feat 373 | x, pin_feat = conv(x, pin_feat, edge_index, edge_weight) 374 | if self.skip_cnt and i > 0: x, pin_feat = x + last_x, pin_feat + last_pin_feat 375 | # 376 | macro_feature = x[macro_index] 377 | # EGNN for position feature 378 | feat = self.posnet(macro_feature, macro_pos, data.macro_num) 379 | # mlp 380 | #x = torch.cat([x, gap(feat, macro_batch)], dim=-1) 381 | x = gap(feat, macro_batch) 382 | x = self.mlp(x) 383 | return x 384 | 385 | def predict(self, data, macro_pos): 386 | """ eplicitly input macro_pos, since other info are all the same within a netlist """ 387 | x, edge_index = data.x, data.edge_index 388 | pin_feat, edge_weight = data.pin_offset, data.edge_weight 389 | batch, macro_index = data.batch, data.macro_index 390 | # add macro pos 391 | macro_batch = batch[macro_index] 392 | # model forward 393 | for i, conv in enumerate(self.convs): 394 | last_x, last_pin_feat = x, pin_feat 395 | x, pin_feat = conv(x, pin_feat, edge_index, edge_weight) 396 | if self.skip_cnt and i > 0: x, pin_feat = x + last_x, pin_feat + last_pin_feat 397 | # macro feature 398 | macro_feature = x[macro_index] 399 | # EGNN for position feature 400 | feat = self.posnet(macro_feature, macro_pos, data.macro_num) 401 | # mlp 402 | x = gap(feat, macro_batch) 403 | x = self.mlp(x) 404 | return x 405 | 406 | 407 | class CEHGNN(torch.nn.Module): 408 | """ plain gnn baseline """ 409 | def __init__(self,args=None): 410 | super(CEHGNN, self).__init__() 411 | self.args = args 412 | self.out_ch = 1 413 | self.num_node_features = args.num_node_features 414 | self.num_pin_features = args.num_pin_features 415 | self.num_edge_features = args.num_edge_features 416 | self.nhid = args.nhid 417 | self.negative_slope = 0.1 418 | self.dropout_ratio = args.dropout_ratio 419 | self.conv_layers = args.layers 420 | self.skip_cnt = args.skip_cnt 421 | self.pos_encode = args.pos_encode 422 | self.pos_dim = 4 423 | self.num_egnn = args.egnn_layers 424 | self.egnn_dim = args.egnn_nhid 425 | 426 | self.convs = nn.ModuleList([HyperGATConv(in_nch=self.num_node_features,in_pch=self.num_pin_features, 427 | in_ech=self.num_edge_features, nhid=self.nhid, 428 | out_ch=self.nhid, dropout=self.dropout_ratio)]) 429 | for i in range(self.conv_layers - 1): 430 | self.convs.append(HyperGATConv(in_nch=self.nhid,in_pch=self.nhid, 431 | in_ech=self.num_edge_features, nhid=self.nhid, 432 | out_ch=self.nhid, dropout=self.dropout_ratio)) 433 | 434 | self.posnet = EGNNet(self.num_egnn, self.nhid, self.pos_dim, self.egnn_dim, position_encoding=self.pos_encode, dropout=self.dropout_ratio, args=args) 435 | 436 | self.net = models.vgg11(pretrained=True) 437 | self.net.classifier = nn.Sequential( 438 | nn.Dropout(self.dropout_ratio), 439 | nn.Linear(512 * 7 * 7, self.egnn_dim), 440 | nn.ReLU(True), 441 | ) 442 | 443 | self.mlp = nn.Sequential( 444 | nn.Linear(self.egnn_dim * 2, self.nhid), 445 | nn.LeakyReLU(negative_slope=self.negative_slope), 446 | nn.Dropout(p=self.dropout_ratio), 447 | nn.Linear(self.nhid, self.nhid), 448 | nn.LeakyReLU(negative_slope=self.negative_slope), 449 | nn.Linear(self.nhid, self.out_ch)) 450 | 451 | def forward(self, data): 452 | 453 | x, edge_index = data.x, data.edge_index 454 | pin_feat, edge_weight = data.pin_offset, data.edge_weight 455 | batch, macro_index = data.batch, data.macro_index 456 | density = data.pic 457 | # add macro pos 458 | macro_batch = batch[macro_index] 459 | macro_pos = data.macro_pos 460 | # model forward 461 | for i, conv in enumerate(self.convs): 462 | last_x, last_pin_feat = x, pin_feat 463 | x, pin_feat = conv(x, pin_feat, edge_index, edge_weight) 464 | if self.skip_cnt and i > 0: x, pin_feat = x + last_x, pin_feat + last_pin_feat 465 | # 466 | macro_feature = x[macro_index] 467 | x = torch.cat([gap(macro_feature, macro_batch), gap(x, batch)], dim=-1) 468 | # EGNN for position feature 469 | feat = self.posnet(macro_feature, macro_pos, data.macro_num) 470 | # mlp 471 | #x = torch.cat([x, gap(feat, macro_batch)], dim=-1) 472 | feat = gap(feat, macro_batch) 473 | # density feature 474 | density_feat = self.net(density) 475 | x = torch.cat([feat, density_feat], dim=-1) 476 | x = self.mlp(x) 477 | return x -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import sys 3 | import argparse 4 | import os 5 | import time 6 | import random 7 | import numpy as np 8 | from random import random 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch_geometric.loader import DataLoader 13 | from torch_optimizer import swats 14 | import torch_optimizer as optim 15 | from torch.utils.data import random_split, Subset, ConcatDataset 16 | from torch.utils.tensorboard import writer 17 | 18 | 19 | import src.hyperdataset as hdatasets 20 | import src.hypermodel as hmodels 21 | from src.logger import Logger 22 | from src.util import InversePairs, mle_loss, spearman, dcg_score 23 | 24 | #from src.meta import META 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--seed', type=int, default=777, help='seed') 28 | parser.add_argument('--device', type=str, default='cuda:0',help='device') 29 | parser.add_argument('--model', type=str, default='GClassifier',help='which mdoel to use') 30 | parser.add_argument('--batch_size', type=int, default=8,help='train batch size') 31 | parser.add_argument('--batch_step', type=int, default=1,help='how many batches per update') 32 | parser.add_argument('--test_batch_size', type=int, default=8,help='test batch size') 33 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 34 | parser.add_argument('--step_size', type=int, default=50, help='learning rate decay step') 35 | parser.add_argument('--lr_decay', type=float, default=1., help='learning rate decay ratio') 36 | parser.add_argument('--weight_decay', type=float, default=0.0001, help='weight decay') 37 | parser.add_argument('--nhid', type=int, default=16, help='hidden size') 38 | parser.add_argument('--layers',type=int,default=2,help='conv layers') 39 | parser.add_argument('--egnn_layers',type=int,default=3,help='egnn layers') 40 | parser.add_argument('--egnn_nhid',type=int,default=16,help='egnn layers hidden dim') 41 | #parser.add_argument('--pooling_ratio', type=float, default=0.1,help='pooling ratio') 42 | parser.add_argument('--dropout_ratio', type=float, default=0.1,help='dropout ratio') 43 | parser.add_argument('--group', type=int, default=0, help='which data group to use') 44 | parser.add_argument('--tests', type=str, nargs='+', 45 | default=['mgc_des_perf_a', 'mgc_fft_a', 'mgc_matrix_mult_a', 'mgc_matrix_mult_c', 'mgc_superblue14', 'mgc_superblue19'],help='test data') 46 | parser.add_argument('--trains', type=str, nargs='+', 47 | default=['mgc_edit_dist_a', 'mgc_fft_b', 'mgc_matrix_mult_b', 'mgc_pci_bridge32_b', 'mgc_superblue11_a', 'mgc_superblue16_a'],help='train data') 48 | parser.add_argument('--dataset_path', type=str, default='data') 49 | parser.add_argument('--dataset', type=str, default='PlainClusterSet') 50 | parser.add_argument('--epochs', type=int, default=400,help='maximum number of epochs') 51 | parser.add_argument('--patience', type=int, default=400,help='patience for earlystopping') 52 | parser.add_argument('--save_dir', type=str, default='save') 53 | parser.add_argument('--goon', action='store_true',help='continue training') 54 | parser.add_argument('--con', action='store_true',help='continue training') 55 | parser.add_argument('--checkp', type=str, default='test.pth') 56 | parser.add_argument('--pos_encode', type=int, default=4, help='whether use pos encoding on position') 57 | parser.add_argument('--size_encode', type=int, default=0, help='whether use pos encoding on size') 58 | parser.add_argument('--offset_encode', type=int, default=0, help='whether use pos encoding on offset') 59 | parser.add_argument('--design', type=str, default='all',help='whitch design to train') 60 | parser.add_argument('--loss', type=str, default='MAE',help='loss func') 61 | parser.add_argument('--acc', type=str, default='rel',help='loss func') 62 | parser.add_argument('--skip_cnt', action='store_true', default=True ,help='use skip cnt ?') 63 | parser.add_argument('--regresion', action='store_true', help='regression') 64 | parser.add_argument('--classifier', action='store_true', help='classification') 65 | parser.add_argument('--base_model', type=str, default='EGNN',help='which base mdoel to use in classifier') 66 | parser.add_argument('--metric', type=str, default='lambdda',help='which metric to use as lambda, [lambdda (top1 prob), ndcg]') 67 | parser.add_argument('--label', type=list[int],default=[1],help='which label to use, [0~5] = [hpwl, rwl, via, short, score]') 68 | parser.add_argument('--train_ratio', type=float, default=0.8,help='train ratio') 69 | parser.add_argument('--optimizer',type=str,default='Adam') 70 | args = parser.parse_args() 71 | args.betas = [0.005, ] 72 | 73 | def set_seed(seed): 74 | random.seed(seed) 75 | np.random.seed(seed) 76 | torch.manual_seed(seed) 77 | torch.cuda.manual_seed_all(seed) 78 | torch.cuda.manual_seed(seed) 79 | 80 | def build_loss(args): 81 | def MAELoss(out,data): 82 | if len(args.label) > 1: 83 | label = torch.tensor(args.label[0]).long().to(data.device) 84 | y = data.y[:, label] 85 | w = data.w[:, label] 86 | return F.l1_loss(out[0].view(-1) * w,y.view(-1) * w) 87 | else: 88 | label = args.label[0] 89 | y = data.y[:, label] 90 | w = data.w[:, label] 91 | return F.l1_loss(out[0].view(-1) * w, y.view(-1) * w) 92 | def MSELoss(out,data): 93 | if len(args.label) > 1: 94 | label = torch.tensor(args.label[0]).long().to(data.device) 95 | y = data.y[:, label] 96 | w = data.w[:, label] 97 | return F.mse_loss(out[0].view(-1) * w, y.view(-1) * w) 98 | else: 99 | label = args.label[0] 100 | y = data.y[:, label] 101 | w = data.w[:, label] 102 | return F.mse_loss(out[0].view(-1) * w,y.view(-1) * w) 103 | def BCELoss(out,data): 104 | if len(args.label) > 1: 105 | label = torch.tensor(args.label[0]).long().to(data.device) 106 | y = data.y[:, label] 107 | w = data.w[:, label] 108 | return F.mse_loss(out[0].view(-1) * w, y.view(-1) * w) 109 | else: 110 | label = args.label[0] 111 | y = data.y[:, label] 112 | w = data.w[:, label] 113 | return F.binary_cross_entropy(target=y.view(-1), input=out[0].view(-1), weight=w) 114 | def MLELoss(out,data): 115 | if len(args.label) > 1: 116 | label = torch.tensor(args.label[0]).long().to(data.device) 117 | y = data.y[:, label] 118 | w = data.w[:, label] 119 | return F.mse_loss(out[0].view(-1) * w, y.view(-1) * w) 120 | else: 121 | label = args.label[0] 122 | y = data.y[:, label] 123 | return mle_loss(y.view(-1), out[0].view(-1)) 124 | def SMAELoss(out,data): 125 | if len(args.label) > 1: 126 | label = torch.tensor(args.label[0]).long().to(data.device) 127 | y = data.y[:, label] 128 | w = data.w[:, label] 129 | return F.mse_loss(out[0].view(-1) * w, y.view(-1) * w) 130 | else: 131 | label = args.label[0] 132 | y = data.y[:, label] 133 | w = data.w[:, label] 134 | torch.nn.HuberLoss 135 | return F.smooth_l1_loss(out[0].view(-1) * w, y.view(-1) * w, beta=0.005) 136 | def CMSELoss(out,data): 137 | y = getattr(data,args.label) 138 | return F.l1_loss(out[0],y,reduction='sum') 139 | def CrossEntropyLoss(out,data): 140 | label = args.label[0] 141 | y = data.y[:, label] 142 | index = torch.arange(0, y.shape[0], 2).to(y.device) 143 | y0 = torch.index_select(y, dim=0, index=index) 144 | y1 = torch.index_select(y, dim=0, index=(index+1)) 145 | target = ((y0 - y1) >= 0).long() 146 | return F.cross_entropy(out[0], target.view(-1)) 147 | def COMBLoss(out,data): 148 | # get label 149 | label = args.label[0] 150 | y = data.y[:, label] 151 | y1 = data.y1[:, label] 152 | y2 = data.y2[:, label] 153 | w = data.w[:, label] 154 | w1 = data.w1[:, label] 155 | w2 = data.w2[:, label] 156 | # get out 157 | index = torch.arange(0, out[0].shape[0], 2).to(y.device) 158 | out1 = torch.index_select(out[0], dim=0, index=index) 159 | out2 = torch.index_select(out[0], dim=0, index=(index+1)) 160 | p = torch.sigmoid(out1 - out2) 161 | # 162 | bce_loss = F.binary_cross_entropy(input=p, target=y, weight=w) 163 | mae_loss = F.l1_loss(input=out1 * w1, target=y1 * w1) + F.l1_loss(input=out2 * w2, target=y2 * w2) 164 | return bce_loss + mae_loss 165 | if args.loss == 'MSE': 166 | return MSELoss 167 | elif args.loss == 'CMSE': 168 | return CMSELoss 169 | elif args.loss == 'BCE': 170 | return BCELoss 171 | elif args.loss == 'CROSS': 172 | return CrossEntropyLoss 173 | elif args.loss == 'MLE': 174 | return MLELoss 175 | elif args.loss == 'MAE': 176 | return MAELoss 177 | elif args.loss == 'SMAE': 178 | return SMAELoss 179 | elif args.loss == 'COMB': 180 | return COMBLoss 181 | else: 182 | print('Invalid loss function!') 183 | 184 | 185 | def build_acc(args): 186 | def RelAcc(out, data): 187 | label = args.label[0] 188 | y = data.y[:, label] 189 | return torch.mean(1-torch.abs((y.view(-1)-out[0].view(-1))/(y.view(-1)))) 190 | def CRelAcc(out, data): 191 | y = getattr(data,args.label) 192 | return torch.mean(1-torch.abs((y.view(-1)-out[0].view(-1))/(y.view(-1)+0.00001))) 193 | def SROCC(out, data): 194 | label = args.label[0] 195 | y = data.y[:, label] 196 | return spearman(y.view(-1), out[0].view(-1)) 197 | def EqAcc(out, data): 198 | label = args.label[0] 199 | y = data.y[:, label] 200 | index = torch.arange(0, y.shape[0], 2).to(y.device) 201 | y0 = torch.index_select(y, dim=0, index=index) 202 | y1 = torch.index_select(y, dim=0, index=(index+1)) 203 | target = ((y0 - y1) >= 0).long() 204 | return torch.eq(torch.argmax(out[0],dim=1).view(-1), target.view(-1)).float().mean() 205 | def BEQAcc(out, data): 206 | label = args.label[0] 207 | y = data.y[:, label] 208 | mask1, mask5, mask0 = (out[0] > 0.5), (out[0] == 0.5), (out[0] < 0.5) 209 | mask = 1. * mask1 + 0.5 * mask5 210 | return torch.eq(mask.view(-1), y.view(-1)).float().mean() 211 | def COMBAcc(out, data): 212 | label = args.label[0] 213 | y1 = data.y1[:, label] 214 | y2 = data.y2[:, label] 215 | y = torch.cat((y1,y2)) 216 | index = torch.arange(0, out[0].shape[0], 2).to(y.device) 217 | out1 = torch.index_select(out[0], dim=0, index=index) 218 | out2 = torch.index_select(out[0], dim=0, index=(index+1)) 219 | out = torch.cat((out1, out2)) 220 | return torch.mean(1-torch.abs((y.view(-1)-out.view(-1))/(y.view(-1)))) 221 | if args.acc == 'rel': 222 | return RelAcc 223 | elif args.acc == 'SROCC': 224 | return SROCC 225 | elif args.acc == 'Crel': 226 | return CRelAcc 227 | elif args.acc == 'eq': 228 | return EqAcc 229 | elif args.acc == 'BEQ': 230 | return BEQAcc 231 | elif args.acc == 'COMB': 232 | return COMBAcc 233 | else: 234 | print('Invalid acc function!') 235 | assert(False) 236 | 237 | 238 | def build_loader(design,train_ratio=0.8): 239 | MySet = getattr(hdatasets,args.dataset) 240 | 241 | dataset = MySet(args.dataset_path, mode=args.model, test_files=args.tests, train_files=args.trains, args=args) 242 | 243 | if args.model != 'CNN' and args.model != 'Classifier' and args.model != 'RClassifier': 244 | args.num_node_features = dataset.num_node_features 245 | args.num_edge_features = dataset.num_edge_features 246 | args.num_pin_features = dataset.num_pin_features 247 | if args.model == 'EHGNN': 248 | args.num_pos_features = dataset.num_pos_features 249 | 250 | 251 | if design == 'all': 252 | print(dataset.train_file_names) 253 | print(dataset.test_file_names) 254 | train_designs = dataset.train_file_names 255 | test_designs = dataset.test_file_names 256 | train_sets = [] 257 | test_sets = [] 258 | test_loader = {} 259 | num_training = 0 260 | num_testing = 0 261 | for design in train_designs: 262 | train_sets.append(Subset(dataset,range(dataset.ptr[design], 263 | dataset.ptr[design] + dataset.file_num[design]))) 264 | num_training += dataset.file_num[design] 265 | for design in test_designs: 266 | test_set = Subset(dataset,range(dataset.ptr[design], 267 | dataset.ptr[design] + dataset.file_num[design])) 268 | num_testing += dataset.file_num[design] 269 | test_loader[design] = DataLoader(test_set,batch_size=args.test_batch_size,shuffle=True) 270 | train_set = ConcatDataset(train_sets) 271 | print("Total %d training data, %d testing data."%(num_training,num_testing),flush=True) 272 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) 273 | else: 274 | num_training = int(dataset.file_num[design] * train_ratio) 275 | num_testing = dataset.file_num[design] - num_training 276 | test_loader = {} 277 | design_set = Subset(dataset,range(dataset.ptr[design], 278 | dataset.ptr[design] + dataset.file_num[design])) 279 | train_set, test_set = random_split(design_set,[num_training,num_testing]) 280 | print("Total %d training data, %d testing data."%(num_training,num_testing),flush=True) 281 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) 282 | test_loader[design] = DataLoader(test_set,batch_size=args.test_batch_size,shuffle=False) 283 | return dataset, train_loader, test_loader 284 | 285 | 286 | def build_model(): 287 | Model = getattr(hmodels,args.model) 288 | model = Model(args).to(args.device) 289 | print(model) 290 | if args.optimizer == 'RAdam': 291 | optimizer = optim.RAdam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 292 | elif args.optimizer == 'SWATS': 293 | optimizer = swats.SWATS(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 294 | elif args.optimizer == 'Ranger': 295 | optimizer = optim.Ranger(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 296 | elif args.optimizer == 'SGD': 297 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9) 298 | elif args.optimizer == 'Nesterov': 299 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, nesterov=True, momentum=0.9) 300 | else: 301 | optimizer = getattr(torch.optim,args.optimizer)(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 302 | schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.lr_decay) 303 | #optimizer.zero_grad() 304 | return model, optimizer, schedule 305 | 306 | 307 | def build_log(): 308 | # make save dir 309 | st = time.strftime("%b:%d:%X",time.localtime()) 310 | args.save_dir = os.path.join(args.save_dir,'{}_{}_{}_{}'.format(args.model,args.label,args.group,st)) 311 | if not os.path.exists(args.save_dir): 312 | os.makedirs(args.save_dir) 313 | # rederict to save dir 314 | sys.stdout = Logger(path=args.save_dir) 315 | # print args 316 | print(args) 317 | # save paths 318 | best_model_path = os.path.join(args.save_dir,'best.pth'.format(st)) 319 | last_model_path = os.path.join(args.save_dir,'last.pth'.format(st)) 320 | # tensor board logger 321 | logger = writer.SummaryWriter(args.save_dir) 322 | return best_model_path, last_model_path, logger 323 | 324 | 325 | # preparing 326 | torch.set_num_threads(16) 327 | 328 | # choose data group 329 | if args.group == 1: 330 | tmp = args.tests 331 | args.tests = args.trains 332 | args.trains = tmp 333 | if args.model == 'Classifier' or args.model == 'GClassifier': 334 | args.loss = 'BCE' 335 | args.acc = 'BEQ' 336 | if args.model == 'RClassifier': 337 | args.loss = 'COMB' 338 | args.acc = 'COMB' 339 | 340 | args.label = [int(i) for i in args.label] 341 | set_seed(args.seed) 342 | # build up 343 | best_model_path, last_model_path, logger = build_log() 344 | print('loading dataset ...') 345 | dataset, train_loader, test_loader = build_loader(args.design, args.train_ratio) 346 | model, optimizer, schedule = build_model() 347 | criterion = build_loss(args) 348 | accuracy = build_acc(args) 349 | test_designs = dataset.test_file_names 350 | 351 | start = 0 352 | min_loss = 1e10 353 | min_train_loss = 1e10 354 | min_err = 1e10 355 | patience = 0 356 | 357 | def test(model,loader): 358 | with torch.no_grad(): 359 | model.eval() 360 | lenth = len(loader)# if epoch % 5 == 0 else int(len(loader)/5) 361 | maes = [] 362 | accs = [] 363 | ipes = [] 364 | for i,label in enumerate(args.label): 365 | correct = 0. 366 | loss = 0. 367 | reals = [] 368 | preds = [] 369 | for i, data in enumerate(loader): 370 | if i >= lenth : break 371 | data = data.to(args.device) 372 | out = model(data).view(-1) 373 | y = data.y[:, label].view(-1) 374 | 375 | preds.extend(out.detach().cpu().numpy().tolist()) 376 | reals.extend(data.y[:, label].cpu().numpy().tolist()) 377 | 378 | correct += torch.mean(torch.abs((y-out)/y)).item() 379 | loss += F.l1_loss(out, y).item() 380 | # rank loss 381 | Rp = np.argsort(preds) 382 | Rr = np.argsort(np.array(reals)[Rp]) 383 | rankacc = InversePairs(Rr.tolist()) / (len(reals)**2 - len(reals)) * 2 384 | #print('[{}]MAE=\t{:4f}\tMRE={:4f}\tIPE={:4f}'.format(label,loss/len(loader),correct/len(loader),rankacc),end='\t') 385 | maes.append(loss/lenth) 386 | accs.append(correct/lenth) 387 | ipes.append(rankacc) 388 | return np.mean(maes), np.mean(accs), np.mean(ipes) 389 | 390 | 391 | def test_class(model,loader): 392 | tmp_mode = dataset.mode 393 | dataset.mode = 'CNN' 394 | with torch.no_grad(): 395 | model.eval() 396 | lenth = len(loader)# if epoch % 5 == 0 else int(len(loader)/5) 397 | for i,label in enumerate(args.label): 398 | reals = [] 399 | preds = [] 400 | for i, data in enumerate(loader): 401 | if i >= lenth : break 402 | data = data.to(args.device) 403 | out = model.predict(data).view(-1) 404 | preds.extend(out.view(-1).detach().cpu().numpy().tolist()) 405 | reals.extend(data.y[:, label].cpu().numpy().tolist()) 406 | 407 | # rank loss 408 | reals = np.array(reals) 409 | preds = np.array(preds) 410 | Rp = np.argsort(preds) 411 | Rr = np.argsort(np.array(reals)[Rp]) 412 | rankacc = InversePairs(Rr.tolist()) / (len(reals)**2 - len(reals)) * 2 413 | dcg_s = dcg_score(input=preds, target=reals) 414 | dataset.mode = tmp_mode 415 | return 0, dcg_s, rankacc 416 | 417 | 418 | def test_design(model,design, test_loader): 419 | if(args.model == 'Classifier' or args.model == 'GClassifier' or args.model=='RClassifier'): 420 | return test_class(model, test_loader[design]) 421 | return test(model, test_loader[design]) 422 | 423 | if args.goon: 424 | checkp = torch.load(args.checkp) 425 | model.load_state_dict(checkp['model']) 426 | print('load model from {}, saved at epoch {}'.format(args.checkp,start - 1)) 427 | if args.con: 428 | optimizer.load_state_dict(checkp['optimizer']) 429 | 430 | minn_loss = 10000 431 | minn_errr = 10000 432 | 433 | for epoch in range(start, args.epochs): 434 | model.train() 435 | tt = time.time() 436 | Ave_loss = 0. 437 | Ave_cor = 0. 438 | 439 | for i, data in enumerate(train_loader): 440 | data = data.to(args.device) 441 | 442 | out = [model(data)] 443 | loss = criterion(out, data) / args.batch_step 444 | loss.backward() 445 | if (i+1) % args.batch_step == 0: 446 | optimizer.step() 447 | optimizer.zero_grad() 448 | with torch.no_grad(): 449 | Ave_loss += loss.mean().item() 450 | Ave_cor += accuracy(out,data).item() 451 | 452 | if optimizer.param_groups[0]['lr'] > args.lr / 100: 453 | schedule.step() 454 | val_losses = [] 455 | rank_errs = [] 456 | print("[Epoch\t{}]\tTrain loss:\t{:.4f}\tTrain acc:\t{:.4f}".format( 457 | epoch, Ave_loss / len(train_loader) * args.batch_step, 458 | Ave_cor / len(train_loader)), flush=True,end='\t') 459 | 460 | for design in test_designs: 461 | _, val_loss, rank_err = test_design(model, design, test_loader) 462 | val_losses.append(val_loss) 463 | rank_errs.append(rank_err) 464 | 465 | mean_val_loss = np.mean(val_losses) 466 | mean_rank_err = np.mean(rank_errs) 467 | 468 | print("{} mre:\t{:.4f}\t{} ipe:\t{:.4f}\tTime:{:.2f}\tlr:{:.5f}".format( 469 | 'Test', 470 | mean_val_loss, 471 | 'Test', 472 | mean_rank_err, 473 | time.time() - tt, 474 | optimizer.param_groups[0]['lr'])) 475 | 476 | logger.add_scalar('train loss', Ave_loss / len(train_loader), i) 477 | logger.add_scalar('train acc', Ave_cor / len(train_loader), i) 478 | logger.add_scalar('test mre', mean_val_loss, i) 479 | logger.add_scalar('test rank err', mean_rank_err, i) 480 | 481 | 482 | if mean_val_loss < minn_loss: 483 | minn_loss = mean_val_loss 484 | state = {'model': model.state_dict(), 'epoch': epoch , 'val_loss' : mean_val_loss, 'rank_err' : mean_rank_err} 485 | print('model saved {} {}'.format(mean_val_loss, mean_rank_err)) 486 | torch.save(state, best_model_path + '.loss') 487 | 488 | 489 | if mean_rank_err < minn_errr: 490 | minn_errr = mean_rank_err 491 | state = {'model': model.state_dict(), 'epoch': epoch , 'val_loss' : mean_val_loss, 'rank_err' : mean_rank_err} 492 | print('model saved {} {}'.format(mean_val_loss, mean_rank_err)) 493 | torch.save(state, best_model_path + '.err') 494 | 495 | state = {'model': model.state_dict(), 'val_loss' : mean_val_loss, 'rank_err' : mean_rank_err} 496 | torch.save(state, last_model_path) 497 | 498 | --------------------------------------------------------------------------------