├── .gitignore
├── README.md
├── build.sh
├── data
│   ├── bcb.xyz
│   ├── ptcda.xyz
│   ├── water.xyz
│   └── water_with_surface.xyz
├── environment.yml
├── environment_exact.yml
├── model_schem.png
├── pretrained_weights
│   ├── model_random.pth
│   ├── model_y.pth
│   └── model_z.pth
├── scripts
│   ├── generate_data.py
│   ├── predict_examples.py
│   ├── predict_random.py
│   ├── test.py
│   ├── train.py
│   └── train_distributed.py
└── src
    ├── analysis.py
    ├── c
    │   ├── bindings.py
    │   └── matching.cpp
    ├── cuda
    │   ├── bindings.py
    │   ├── ccl_cuda.cu
    │   ├── defs.h
    │   ├── matching_cuda.cu
    │   └── peak_find.cpp
    ├── data_loading.py
    ├── layers.py
    ├── models.py
    ├── preprocessing.py
    ├── utils.py
    └── visualization.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.png
3 | *.so
4 | *.o
5 | *.hdf5
6 | # ignore compiled python 3 files
7 | __pycache__
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph-AFM
2 | Machine learning molecule graphs from atomic force microscopy images.
3 |
4 | Paper: https://link.springer.com/article/10.1557/s43577-022-00324-3
5 |
6 | Abstract:
7 | _Despite the success of non-contact atomic force microscopy (AFM) in providing atomic-scale insight into the structure and properties of matter on surfaces, the wider applicability of the technique faces challenges in the difficulty of interpreting the measurement data. We tackle this problem by proposing a machine learning model for extracting molecule graphs of samples from AFM images. The predicted graphs contain not only atoms and their bond connections, but also their coordinates within the image and elemental identification.
The model is shown to be effective on simulated AFM images, but we also highlight some issues with robustness that need to be addressed before generalization to real AFM images._
8 |
9 | ![Model schematic](model_schem.png)
10 |
11 | ## Setup and usage
12 |
13 | This repository contains all the code and data used to achieve the results in the above paper. To run the code, first clone the repository:
14 | ```sh
15 | git clone https://github.com/SINGROUP/Graph-AFM.git
16 | cd Graph-AFM
17 | ```
18 | Then, the Python environment required for running the code can be installed with Anaconda by using the provided `environment.yml` file:
19 | ```sh
20 | conda env create -f environment.yml
21 | ```
22 | and the installed environment can be activated with
23 | ```sh
24 | conda activate graph-afm
25 | ```
26 | The other environment file, `environment_exact.yml`, contains a dump of the exact list of packages used when training the model presented in the paper. Finally, run the script `build.sh` in the root of the repository to compile the C extensions:
27 | ```sh
28 | ./build.sh
29 | ```
30 | \
31 | The model is trained on a dataset of simulated AFM images. The whole database is inconvenient to share directly due to its size, so we instead provide a database of molecule structures that can be used to generate the simulated AFM images. The AFM simulations are done using the [ProbeParticleModel](https://github.com/ProkopHapala/ProbeParticleModel) simulation code from Prokop Hapala. Navigate to the root of this repository and clone the ProbeParticleModel repository there with
32 | ```sh
33 | git clone https://github.com/ProkopHapala/ProbeParticleModel.git
34 | cd ProbeParticleModel
35 | git checkout 99c152328808989f7a1f6206159b0d28cb03c17a
36 | ```
37 | The database of AFM simulations can then be generated by using the script `generate_data.py` in the `scripts` directory:
38 | ```sh
39 | cd scripts
40 | python generate_data.py
41 | ```
42 | The script will automatically download the molecule database and generate the simulations into an HDF5 archive file. The complete database takes ~150 GB of disk space. The molecule database can also be downloaded directly from https://www.dropbox.com/s/z4113upq82puzht/Molecules_rebias_210611.tar.gz?dl=0.
43 |
44 | The model can be trained using the script `train.py`, or `train_distributed.py` for multi-GPU training. Expect this to take a while: training the model for 50 epochs took roughly 2 days on a system with 4 x Nvidia Tesla V100 32GB GPUs. After the model has been trained, statistics on test set predictions can be run using the script `test.py`.
45 |
46 | Pre-trained weights for the model are provided in the directory `pretrained_weights`. See the scripts `predict_examples.py` and `predict_random.py` for examples of how to run predictions with the pre-trained model.
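The molecule structures in `data/` use a five-column variant of the XYZ format: an atom count line, a blank comment line, then one row per atom. The first column is the element (given as an atomic number in some files, e.g. `bcb.xyz`, and as a chemical symbol in others, e.g. `ptcda.xyz`), followed by x/y/z coordinates and a fifth column that appears to be a per-atom partial charge. A minimal sketch of a parser for this layout (the `Atom` class and `parse_xyz` helper are illustrative, not part of this repository; the sample is the first three atoms of `data/bcb.xyz` with the count line adjusted to match):

```python
from dataclasses import dataclass

@dataclass
class Atom:
    element: str   # atomic number or chemical symbol, exactly as found in the file
    x: float
    y: float
    z: float
    charge: float  # fifth column; appears to be a partial charge

def parse_xyz(text):
    """Parse the 5-column XYZ variant used in data/: a count line,
    a comment line, then `element x y z charge` rows."""
    lines = text.splitlines()
    n_atoms = int(lines[0].split()[0])
    atoms = []
    for line in lines[2:2 + n_atoms]:
        element, x, y, z, charge = line.split()
        atoms.append(Atom(element, float(x), float(y), float(z), float(charge)))
    return atoms

# First three atoms of data/bcb.xyz (35 = Br, 17 = Cl), count line adjusted.
sample = """3

35 -13.30722774 -5.89694824 0.43700722 0.00916000
17 -7.73986087 -5.73247149 0.43661849 -0.03614000
17 -10.52789782 -1.06960205 0.43694792 -0.03633000
"""
atoms = parse_xyz(sample)
print(len(atoms), atoms[0].element, atoms[0].charge)  # → 3 35 0.00916
```

Keeping the element column as a string sidesteps the number-vs-symbol inconsistency between the files; callers can map symbols to atomic numbers afterwards if needed.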
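Since `generate_data.py` writes the simulations into an HDF5 archive, the result can be inspected with `h5py` (which is included in the environment). The sketch below is self-contained, so it first writes a tiny toy archive; the file name `toy_afm.hdf5` and the `train/X` dataset layout are illustrative assumptions, not the actual layout produced by `generate_data.py` — walk your own archive to see its real structure:

```python
import h5py
import numpy as np

# Illustrative only: create a tiny toy archive so the snippet runs as-is.
# The real archive is written by generate_data.py and its group/dataset
# names may differ.
with h5py.File("toy_afm.hdf5", "w") as f:
    f.create_dataset("train/X", data=np.zeros((2, 128, 128, 10), dtype=np.float32))

def summarize(path):
    """List every group and dataset in an HDF5 file, with dataset shapes."""
    with h5py.File(path, "r") as f:
        items = []
        f.visititems(lambda name, obj: items.append(
            (name, obj.shape if isinstance(obj, h5py.Dataset) else None)))
        return items

print(summarize("toy_afm.hdf5"))
```

The same `summarize` call works on the full archive, which is the quickest way to discover what the training scripts expect to read from it.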
47 | -------------------------------------------------------------------------------- /build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | script_path=$(realpath $0) 4 | script_dir=$(dirname $script_path) 5 | cd $script_dir/src/c 6 | 7 | g++ -fPIC -O3 -c matching.cpp && 8 | g++ -shared matching.o -o lib.so 9 | -------------------------------------------------------------------------------- /data/bcb.xyz: -------------------------------------------------------------------------------- 1 | 12 2 | 3 | 35 -13.30722774 -5.89694824 0.43700722 0.00916000 4 | 17 -7.73986087 -5.73247149 0.43661849 -0.03614000 5 | 17 -10.52789782 -1.06960205 0.43694792 -0.03633000 6 | 6 -11.66667925 -4.91574735 0.43582599 -0.22658000 7 | 6 -11.71260448 -3.52016612 0.43586879 0.17622000 8 | 6 -10.45895940 -5.61651093 0.43605073 0.17611000 9 | 6 -9.27253958 -4.87909854 0.43615513 -0.18382000 10 | 6 -10.50159215 -2.82418772 0.43610966 -0.18389000 11 | 6 -9.27066105 -3.48289012 0.43609409 0.17770000 12 | 1 -12.66085156 -2.98622496 0.43606867 0.04193000 13 | 1 -10.43733509 -6.70464366 0.43635407 0.04196000 14 | 1 -8.33663829 -2.92427188 0.43629273 0.04369000 15 | -------------------------------------------------------------------------------- /data/ptcda.xyz: -------------------------------------------------------------------------------- 1 | 38 2 | 3 | O 3.310869 7.523879 4.999999 -0.121029 4 | O 14.689131 7.523878 5.000000 -0.121029 5 | O 3.271500 9.753426 4.999997 -0.277762 6 | O 14.728501 9.753425 5.000000 -0.277762 7 | O 3.271923 5.294312 5.000001 -0.277788 8 | O 14.728076 5.294311 5.000000 -0.277788 9 | C 7.572632 7.524225 5.000000 0.013173 10 | C 10.427368 7.524225 5.000000 0.013173 11 | C 8.268392 8.771895 5.000000 0.016655 12 | C 9.731608 8.771895 5.000001 0.016655 13 | C 8.268424 6.276615 5.000000 0.016660 14 | C 9.731576 6.276615 5.000000 0.016660 15 | C 6.142708 7.524142 5.000000 -0.005312 16 | C 11.857293 7.524142 
5.000000 -0.005312 17 | C 5.432203 8.747486 4.999999 0.020539 18 | C 12.567798 8.747485 5.000000 0.020539 19 | C 5.432362 6.300709 5.000001 0.020543 20 | C 12.567638 6.300708 5.000000 0.020543 21 | C 7.521832 9.955535 5.000001 -0.038392 22 | C 10.478169 9.955534 5.000000 -0.038392 23 | C 7.522020 5.092887 5.000001 -0.038383 24 | C 10.477980 5.092886 4.999999 -0.038383 25 | C 6.127076 9.948493 5.000000 -0.033293 26 | C 11.872925 9.948493 5.000000 -0.033293 27 | C 6.127321 5.099757 5.000001 -0.033285 28 | C 11.872678 5.099756 4.999999 -0.033285 29 | C 3.957478 8.760245 4.999999 0.228104 30 | C 14.042523 8.760244 5.000000 0.228104 31 | C 3.957669 6.287658 5.000000 0.228098 32 | C 14.042331 6.287657 5.000000 0.228098 33 | H 8.025766 10.919262 5.000002 0.055438 34 | H 9.974234 10.919262 4.999999 0.055438 35 | H 8.026122 4.129262 5.000001 0.055426 36 | H 9.973878 4.129261 4.999998 0.055426 37 | H 5.563915 10.880728 5.000000 0.085307 38 | H 12.436086 10.880728 5.000000 0.085307 39 | H 5.564211 4.167493 5.000002 0.085302 40 | H 12.435789 4.167492 4.999999 0.085302 41 | -------------------------------------------------------------------------------- /data/water.xyz: -------------------------------------------------------------------------------- 1 | 21 2 | 3 | 1 11.84933835 6.78724151 10.86545000 0.19766600 4 | 1 10.42127327 6.68056668 11.48587400 0.28434600 5 | 1 7.41094859 5.91436133 11.50939900 0.28505500 6 | 1 6.14135038 5.26370359 10.87702600 0.20005700 7 | 1 11.71143556 3.12062274 10.93600000 0.28932300 8 | 1 11.18477111 4.66762954 10.94707500 0.28959900 9 | 1 7.90953806 1.48854887 10.77337500 0.30523100 10 | 1 6.81444533 0.34632464 10.88302500 0.23166300 11 | 1 9.36505161 3.32510671 10.77409900 0.28620500 12 | 1 7.83365773 3.85024239 10.92595000 0.27956800 13 | 1 8.44960430 8.11295359 11.94927500 0.26273500 14 | 1 8.56035702 7.61728588 10.47255000 0.21091000 15 | 1 12.64680489 1.04906433 11.15539900 0.23832300 16 | 1 13.71430137 2.09848150 10.80090000 0.21339600 17 | 
8 11.30484987 6.24008158 11.46655000 -0.46484700 18 | 8 6.88227172 5.08004745 11.48937500 -0.46822600 19 | 8 11.00236797 3.73877761 10.59585000 -0.42917900 20 | 8 7.71256498 0.53354884 10.55340000 -0.38710200 21 | 8 8.41229805 3.10460465 10.56312500 -0.42990100 22 | 8 8.64524802 7.31987677 11.41637500 -0.49025200 23 | 8 12.92073877 1.96364260 11.36115000 -0.43929400 24 | -------------------------------------------------------------------------------- /data/water_with_surface.xyz: -------------------------------------------------------------------------------- 1 | 341 2 | 3 | 1 11.84933835 6.78724151 10.86545000 0.19766600 4 | 1 10.42127327 6.68056668 11.48587400 0.28434600 5 | 1 7.41094859 5.91436133 11.50939900 0.28505500 6 | 1 6.14135038 5.26370359 10.87702600 0.20005700 7 | 1 11.71143556 3.12062274 10.93600000 0.28932300 8 | 1 11.18477111 4.66762954 10.94707500 0.28959900 9 | 1 7.90953806 1.48854887 10.77337500 0.30523100 10 | 1 6.81444533 0.34632464 10.88302500 0.23166300 11 | 1 9.36505161 3.32510671 10.77409900 0.28620500 12 | 1 7.83365773 3.85024239 10.92595000 0.27956800 13 | 1 8.44960430 8.11295359 11.94927500 0.26273500 14 | 1 8.56035702 7.61728588 10.47255000 0.21091000 15 | 1 12.64680489 1.04906433 11.15539900 0.23832300 16 | 1 13.71430137 2.09848150 10.80090000 0.21339600 17 | 8 11.30484987 6.24008158 11.46655000 -0.46484700 18 | 8 6.88227172 5.08004745 11.48937500 -0.46822600 19 | 8 11.00236797 3.73877761 10.59585000 -0.42917900 20 | 8 7.71256498 0.53354884 10.55340000 -0.38710200 21 | 8 8.41229805 3.10460465 10.56312500 -0.42990100 22 | 8 8.64524802 7.31987677 11.41637500 -0.49025200 23 | 8 12.92073877 1.96364260 11.36115000 -0.43929400 24 | 29 0.00000000 0.00000000 0.00000000 -0.02765900 25 | 29 1.04211983 1.04210569 2.08422500 0.03078200 26 | 29 1.42236438 -0.38238991 4.16210000 -0.00385500 27 | 29 19.72348006 5.28452788 6.24047500 0.03192200 28 | 29 1.03278249 1.04176416 8.32342500 -0.02982600 29 | 29 0.66067532 2.46567569 0.00000000 -0.02923000 30 
| 29 1.80499966 -1.80499966 0.00000000 -0.03033000 31 | 29 2.46567498 0.66067603 0.00000000 -0.03094500 32 | 29 1.70279515 3.50778138 2.08422500 0.03099000 33 | 29 2.84711949 -0.76289397 2.08422500 0.03081700 34 | 29 3.50779481 1.70278172 2.08422500 0.03114700 35 | 29 2.08174075 2.08410956 4.16190000 -0.00387500 36 | 29 3.22611034 -2.18906188 4.15842500 -0.00393400 37 | 29 3.88841555 0.27818571 4.15752500 -0.00353000 38 | 29 15.09881574 -11.97366540 6.24177500 0.03204800 39 | 29 4.44622309 8.06373330 6.23915000 0.02903700 40 | 29 2.46085888 0.65868977 6.23810000 0.03299700 41 | 29 1.69429644 3.51152834 8.32357500 -0.03580500 42 | 29 2.83063754 -0.76864628 8.32217500 -0.03677900 43 | 29 3.48423967 1.70267707 8.32205000 -0.04779600 44 | 29 1.32135064 4.93135138 0.00000000 -0.02962900 45 | 29 3.60999932 -3.60999932 0.00000000 -0.03064700 46 | 29 4.93134996 1.32135206 0.00000000 -0.03002000 47 | 29 2.36347047 5.97345707 2.08422500 0.03100500 48 | 29 4.65211915 -2.56789363 2.08422500 0.03095000 49 | 29 5.97346980 2.36345775 2.08422500 0.03107700 50 | 29 2.74204272 4.55078368 4.16155000 -0.00385200 51 | 29 5.03081514 -3.99201588 4.15900000 -0.00369600 52 | 29 6.36356505 0.94559269 4.16840000 -0.00332200 53 | 29 15.75983189 -9.50630397 6.24235000 0.03216900 54 | 29 6.24716608 6.26594118 6.21237500 0.01438000 55 | 29 4.92470387 1.31834473 6.22082500 0.02397100 56 | 29 2.35765735 5.97984578 8.32525000 -0.03691300 57 | 29 4.64167024 -2.57470236 8.31880000 -0.04081900 58 | 29 5.90788988 2.37836639 8.27452500 -0.04263300 59 | 29 1.98202596 7.39702707 0.00000000 -0.02895400 60 | 29 4.27067464 -1.14432363 0.00000000 -0.02896900 61 | 29 5.59202529 3.78702775 0.00000000 -0.02978000 62 | 29 3.12635030 3.12635172 0.00000000 -0.03012700 63 | 29 5.41499898 -5.41499898 0.00000000 -0.02997700 64 | 29 6.73634962 -0.48364760 0.00000000 -0.02981200 65 | 29 3.78702563 5.59202741 0.00000000 -0.03044000 66 | 29 6.07567430 -2.94932329 0.00000000 -0.02981900 67 | 29 7.39702495 1.98202809 
0.00000000 -0.02829900 68 | 29 3.02414580 8.43913276 2.08422500 0.03113400 69 | 29 5.31279447 -0.10221794 2.08422500 0.03104300 70 | 29 6.63414512 4.82913344 2.08422500 0.03147500 71 | 29 4.16847014 4.16845741 2.08422500 0.03086100 72 | 29 6.45711881 -4.37289329 2.08422500 0.03091100 73 | 29 7.77846946 0.55845809 2.08422500 0.03237300 74 | 29 4.82914546 6.63413310 2.08422500 0.03099400 75 | 29 7.11779414 -1.90721760 2.08422500 0.03088700 76 | 29 8.43914478 3.02413378 2.08422500 0.03229500 77 | 29 3.40446106 7.01756741 4.16222500 -0.00367800 78 | 29 5.69862729 -1.52007169 4.16377500 -0.00377000 79 | 29 7.01904031 3.39643610 4.16667500 -0.00381900 80 | 29 4.54841841 2.74435849 4.15807500 -0.00408800 81 | 29 6.83715547 -5.79645551 4.16217500 -0.00383600 82 | 29 8.15904988 -0.86212085 4.16135000 -0.00416900 83 | 29 5.20773962 5.21031207 4.15560000 -0.00361100 84 | 29 7.49851881 -3.32991786 4.16110000 -0.00382300 85 | 29 8.82415806 1.60807537 4.17990000 -0.00378800 86 | 29 16.42134301 -7.04031432 6.24165000 0.03210600 87 | 29 4.26107496 -1.15370694 6.22505000 0.02825200 88 | 29 5.58591023 3.78996648 6.21950000 0.01901700 89 | 29 3.12164239 3.12825808 6.23872500 0.03160300 90 | 29 10.69627559 14.30805682 6.23962500 0.03173300 91 | 29 6.74071954 -0.47553709 6.25517500 0.01921400 92 | 29 3.78347524 5.59667805 6.23715000 0.02475300 93 | 29 6.06629948 -2.95540582 6.22790000 0.02980500 94 | 29 7.41062827 1.97943018 6.29980000 0.01484200 95 | 29 3.02126858 8.44438585 8.32512600 -0.03924300 96 | 29 6.59300989 4.86406734 8.26917600 -0.03799200 97 | 29 4.14939522 4.18295593 8.32582500 -0.02930200 98 | 29 6.44921336 -4.38832236 8.32109900 -0.03778700 99 | 29 7.75881047 0.51611512 8.40420000 0.05440600 100 | 29 4.81930536 6.64845059 8.32305000 -0.03373900 101 | 29 7.10202636 -1.95082558 8.28717500 -0.01830300 102 | 29 8.40852988 3.06410228 8.39732600 0.03561200 103 | 29 5.26806290 -0.11213441 8.27870000 -0.04419700 104 | 29 2.64270129 9.86270276 0.00000000 -0.02757700 105 | 29 
7.21999864 -7.21999864 0.00000000 -0.02803600 106 | 29 9.86269993 2.64270411 0.00000000 -0.02726200 107 | 29 3.68482112 10.90480845 2.08422500 0.03130700 108 | 29 8.26211777 -6.17789225 2.08422500 0.03103200 109 | 29 10.90481905 3.68481051 2.08422500 0.03206200 110 | 29 4.06506072 9.48391838 4.16237500 -0.00362500 111 | 29 8.64287426 -7.60102455 4.16317500 -0.00376700 112 | 29 11.27462105 2.26678069 4.17002500 -0.00387600 113 | 29 17.08179489 -4.57526513 6.24150000 0.03207800 114 | 29 12.50344606 12.50548819 6.24095000 0.03236600 115 | 29 9.85801252 2.66477362 6.30835000 0.01455600 116 | 29 3.68418472 10.90667097 8.32590000 -0.03293800 117 | 29 8.26043061 -6.18323232 8.32372600 -0.02848700 118 | 29 10.94760821 3.71002523 8.41575000 0.03551200 119 | 29 3.30337661 12.32837845 0.00000000 -0.02784800 120 | 29 7.88067396 -4.75432295 0.00000000 -0.02966500 121 | 29 10.52337525 5.10837980 0.00000000 -0.02939000 122 | 29 4.44770095 8.05770310 0.00000000 -0.03034700 123 | 29 9.02499830 -9.02499830 0.00000000 -0.02878300 124 | 29 11.66769959 0.83770445 0.00000000 -0.02933300 125 | 29 5.10837627 10.52337879 0.00000000 -0.02899100 126 | 29 9.68567362 -6.55932261 0.00000000 -0.02901100 127 | 29 12.32837491 3.30338014 0.00000000 -0.02960200 128 | 29 4.34549715 13.37048484 2.08422500 0.03143000 129 | 29 8.92279309 -3.71221656 2.08422500 0.03086400 130 | 29 11.56549508 6.15048691 2.08422500 0.03121300 131 | 29 5.48982078 9.09980879 2.08422500 0.03109800 132 | 29 10.06711813 -7.98289261 2.08422500 0.03124900 133 | 29 12.70981871 1.87981085 2.08422500 0.03149600 134 | 29 6.15049681 11.56548518 2.08422500 0.03129900 135 | 29 10.72779346 -5.51721692 2.08422500 0.03095100 136 | 29 13.37049474 4.34548725 2.08422500 0.03097500 137 | 29 4.72548997 11.94874130 4.16260000 -0.00353100 138 | 29 9.30378646 -5.13658559 4.16235000 -0.00379200 139 | 29 11.94550204 4.72660366 4.16027500 -0.00421400 140 | 29 5.86873809 7.68074306 4.15427500 -0.00340400 141 | 29 10.44777281 -9.40619745 4.16332500 
-0.00381100 142 | 29 13.08908180 0.45657107 4.15477500 -0.00331100 143 | 29 6.52967504 10.14808186 4.15892500 -0.00378200 144 | 29 11.10864188 -6.94041569 4.16292500 -0.00371900 145 | 29 13.74362721 2.92437586 4.16142500 -0.00358500 146 | 29 17.74305145 -2.11104608 6.24190000 0.03191200 147 | 29 7.87784836 -4.75687419 6.23895000 0.03285400 148 | 29 10.52935172 5.09795281 6.25560000 0.01049300 149 | 29 7.08599837 17.91715740 6.23985000 0.03168600 150 | 29 11.66854600 0.83608306 6.21682500 0.01663600 151 | 29 14.30881201 10.70117159 6.24215000 0.03208600 152 | 29 5.10600958 10.52787103 6.24147500 0.03196400 153 | 29 9.68482085 -6.55957010 6.24145000 0.03162000 154 | 29 12.31617449 3.30922792 6.25435000 0.01491100 155 | 29 4.34490388 13.36997785 8.32595000 -0.03047200 156 | 29 8.92228397 -3.72696398 8.32060000 -0.04051100 157 | 29 11.58086546 6.19329162 8.27430000 -0.03547700 158 | 29 5.48777300 9.10658428 8.32505000 -0.03980000 159 | 29 10.06656447 -7.98386488 8.32455000 -0.03118400 160 | 29 12.73820481 1.85762608 8.27170000 -0.03850100 161 | 29 6.14966313 11.56626795 8.32665000 -0.04190800 162 | 29 10.72886614 -5.52216596 8.32355000 -0.02895000 163 | 29 13.41336450 4.36416689 8.28522500 -0.03557000 164 | 29 3.96405264 14.79405484 0.00000000 -0.02938000 165 | 29 8.54134929 -2.28864726 0.00000000 -0.03045700 166 | 29 11.18405128 7.57405620 0.00000000 -0.02977600 167 | 29 6.25270061 6.25270344 0.00000000 -0.03034300 168 | 29 10.82999726 -10.82999726 0.00000000 -0.02981900 169 | 29 13.47269854 -0.96729450 0.00000000 -0.03102900 170 | 29 7.57405196 11.18405552 0.00000000 -0.02950600 171 | 29 12.15134790 -5.89864588 0.00000000 -0.02925800 172 | 29 14.79405060 3.96405688 0.00000000 -0.03010500 173 | 29 5.00617176 15.83616124 2.08422500 0.03120800 174 | 29 9.58346841 -1.24654087 2.08422500 0.03096200 175 | 29 12.22617040 8.61616260 2.08422500 0.03114400 176 | 29 7.29482044 7.29480913 2.08422500 0.03127500 177 | 29 11.87211709 -9.78789157 2.08422500 0.03118500 178 | 29 
14.51481837 0.07481119 2.08422500 0.03101200 179 | 29 8.61617179 12.22616121 2.08422500 0.03129800 180 | 29 13.19346773 -4.85654019 2.08422500 0.03104000 181 | 29 15.83616902 5.00616398 2.08422500 0.03107000 182 | 29 5.38677906 14.41240173 4.16205000 -0.00365900 183 | 29 9.96347537 -2.67089788 4.16072500 -0.00411600 184 | 29 12.61028990 7.19389438 4.15707500 -0.00401200 185 | 29 7.67670548 5.86147045 4.15630000 -0.00341400 186 | 29 12.25303339 -11.21118367 4.16362500 -0.00383400 187 | 29 14.89714748 -1.34874255 4.16125000 -0.00385200 188 | 29 8.99490384 10.80825372 4.15927500 -0.00382200 189 | 29 13.57476304 -6.27950985 4.16295000 -0.00365600 190 | 29 16.21944069 3.58141736 4.15865000 -0.00334200 191 | 29 3.96423295 14.79274740 6.24205000 0.03188400 192 | 29 8.53940757 -2.29193248 6.23395000 0.02795000 193 | 29 11.18200067 7.58505878 6.21370000 0.01449900 194 | 29 8.88871761 16.11013613 6.23880000 0.03154500 195 | 29 13.47718372 -0.96777745 6.23410000 0.02620300 196 | 29 16.11418997 8.89621931 6.24307500 0.03217900 197 | 29 7.57084806 11.18890839 6.24155000 0.02863200 198 | 29 12.15149427 -5.89871730 6.24117500 0.03163100 199 | 29 14.80130764 3.96127583 6.22467500 0.02355600 200 | 29 5.00389841 15.83450732 8.32530000 -0.03163700 201 | 29 9.59907214 -1.28185307 8.30125000 -0.01486100 202 | 29 12.22728339 8.62810280 8.32300000 -0.03488100 203 | 29 7.28589958 7.31638578 8.29017500 -0.05837700 204 | 29 11.87077641 -9.78640099 8.32557500 -0.03067100 205 | 29 14.52495121 0.07053178 8.32112500 -0.04380300 206 | 29 8.61267302 12.22513166 8.32625000 -0.04225200 207 | 29 13.19559329 -4.85906597 8.32372600 -0.03100300 208 | 29 15.84689441 5.00921868 8.32482500 -0.04039300 209 | 29 4.62472725 17.25972982 0.00000000 -0.02836800 210 | 29 9.20202461 0.17702843 0.00000000 -0.02869300 211 | 29 11.84472589 10.03973118 0.00000000 -0.03044700 212 | 29 6.91337593 8.71837912 0.00000000 -0.03027000 213 | 29 11.49067258 -8.36432157 0.00000000 -0.02925000 214 | 29 14.13337457 1.49838048 
0.00000000 -0.03069600 215 | 29 8.23472657 13.64973050 0.00000000 -0.02849900 216 | 29 12.81202393 -3.43297090 0.00000000 -0.03038800 217 | 29 15.45472521 6.42973186 0.00000000 -0.02986600 218 | 29 5.76905300 12.98905447 0.00000000 -0.02782600 219 | 29 10.34634895 -4.09364692 0.00000000 -0.02973200 220 | 29 12.98905165 5.76905583 0.00000000 -0.02952700 221 | 29 8.05770027 4.44770377 0.00000000 -0.02866400 222 | 29 12.63499833 -12.63499833 0.00000000 -0.02869900 223 | 29 15.27769891 -2.77229487 0.00000000 -0.02897500 224 | 29 9.37905162 9.37905586 0.00000000 -0.03047200 225 | 29 13.95634827 -7.70364625 0.00000000 -0.02895200 226 | 29 16.59905097 2.15905651 0.00000000 -0.02924000 227 | 29 6.42972762 15.45472946 0.00000000 -0.02868000 228 | 29 11.00702497 -1.62797194 0.00000000 -0.03013500 229 | 29 13.64972555 8.23473152 0.00000000 -0.02982600 230 | 29 8.71837559 6.91337946 0.00000000 -0.03040700 231 | 29 13.29567365 -10.16932264 0.00000000 -0.02839300 232 | 29 15.93837494 -0.30661989 0.00000000 -0.03039400 233 | 29 10.03972623 11.84473084 0.00000000 -0.02765100 234 | 29 14.61702430 -5.23797126 0.00000000 -0.02802300 235 | 29 17.25972558 4.62473149 0.00000000 -0.02980000 236 | 29 5.66684850 18.30183693 2.08422500 0.03099800 237 | 29 10.24414373 1.21913482 2.08422500 0.03135200 238 | 29 12.88684643 11.08183899 2.08422500 0.03130400 239 | 29 7.95549647 9.76048552 2.08422500 0.03124400 240 | 29 12.53279241 -7.32221588 2.08422500 0.03129900 241 | 29 15.17549440 2.54048759 2.08422500 0.03109900 242 | 29 9.27684782 14.69183761 2.08422500 0.03110100 243 | 29 13.85414305 -2.39086450 2.08422500 0.03086800 244 | 29 16.49684505 7.47184038 2.08422500 0.03123600 245 | 29 6.81117213 14.03116087 2.08422500 0.03111000 246 | 29 11.38846878 -3.05154123 2.08422500 0.03074900 247 | 29 14.03117006 6.81116294 2.08422500 0.03092000 248 | 29 9.09982010 5.48980946 2.08422500 0.03137800 249 | 29 13.67711746 -11.59289193 2.08422500 0.03141000 250 | 29 16.31981874 -1.73018918 2.08422500 
0.03098300 251 | 29 10.42117145 10.42116155 2.08422500 0.03100000 252 | 29 14.99846810 -6.66154056 2.08422500 0.03126500 253 | 29 17.64117009 3.20116291 2.08422500 0.03111400 254 | 29 7.47184816 16.49683727 2.08422500 0.03103200 255 | 29 12.04914339 -0.58586484 2.08422500 0.03126100 256 | 29 14.69184609 9.27683933 2.08422500 0.03119400 257 | 29 9.76049613 7.95548586 2.08422500 0.03138200 258 | 29 14.33779207 -9.12721554 2.08422500 0.03123800 259 | 29 16.98049477 0.73548722 2.08422500 0.03100100 260 | 29 11.08184748 12.88683795 2.08422500 0.03094600 261 | 29 15.65914342 -4.19586487 2.08422500 0.03125300 262 | 29 18.30184612 5.66683930 2.08422500 0.03098800 263 | 29 6.04750529 16.87745093 4.16257500 -0.00368200 264 | 29 10.61476668 -0.20109198 4.16605000 -0.00368800 265 | 29 13.26778183 9.66012586 4.16235000 -0.00367600 266 | 29 8.33493279 8.34097290 4.15082500 -0.00309500 267 | 29 12.91375114 -8.74555041 4.16302500 -0.00367800 268 | 29 15.55917125 1.11643463 4.15762500 -0.00349500 269 | 29 9.65667588 13.26995759 4.16282500 -0.00365400 270 | 29 14.23556989 -3.81361638 4.16277500 -0.00377100 271 | 29 16.87950295 6.04835523 4.16242500 -0.00382000 272 | 29 7.19045855 12.61019727 4.16217500 -0.00377300 273 | 29 11.76991470 -4.47503769 4.16262500 -0.00402000 274 | 29 14.41367382 5.38833257 4.15905000 -0.00401900 275 | 29 9.48185009 4.06275201 4.16852500 -0.00445800 276 | 29 14.05820771 -13.01648244 4.16315000 -0.00367000 277 | 29 16.70091606 -3.15293681 4.16295000 -0.00364500 278 | 29 10.80179925 9.00205976 4.15412500 -0.00321600 279 | 29 15.37977684 -8.08424930 4.16317500 -0.00373000 280 | 29 18.02462137 1.77661003 4.16130000 -0.00408300 281 | 29 7.85095427 15.07240248 4.16170000 -0.00384500 282 | 29 12.43011980 -2.00961586 4.15967500 -0.00361700 283 | 29 15.07458744 7.85414473 4.16222500 -0.00362200 284 | 29 10.14464179 6.52313642 4.15852500 -0.00356900 285 | 29 14.71879252 -10.55039097 4.16362500 -0.00390200 286 | 29 17.36276235 -0.68815491 4.16200000 -0.00390300 287 | 
29 11.46173918 11.46519411 4.16232500 -0.00346300 288 | 29 16.04032277 -5.61844421 4.16292500 -0.00365200 289 | 29 18.68315275 4.24255301 4.16302500 -0.00372300 290 | 29 4.62438784 17.25811903 6.24155000 0.03189400 291 | 29 9.18023794 0.17880680 6.26080000 0.01943000 292 | 29 11.84169948 10.04305741 6.23917500 0.02885500 293 | 29 6.90868003 8.73235368 6.22480000 0.01876000 294 | 29 11.48985728 -8.36388104 6.24180000 0.03162300 295 | 29 14.14232018 1.49718901 6.21882500 0.01747300 296 | 29 8.23100790 13.64897178 6.24172500 0.03199100 297 | 29 12.81348127 -3.43245259 6.24015000 0.03192100 298 | 29 15.45657995 6.43050331 6.24060000 0.03097800 299 | 29 5.76694088 12.98971562 6.24202500 0.03203400 300 | 29 10.34690120 -4.09534893 6.23977500 0.03278700 301 | 29 12.99398301 5.77277663 6.22587500 0.01807000 302 | 29 8.05499205 4.43725557 6.24925000 0.01057600 303 | 29 17.91972562 7.09095943 6.24175000 0.03200600 304 | 29 15.27881826 -2.77173837 6.24135000 0.03216200 305 | 29 9.37453957 9.39166994 6.22437500 0.01857900 306 | 29 13.95550045 -7.70257356 6.24157500 0.03164500 307 | 29 16.60531310 2.15687155 6.23620000 0.02543500 308 | 29 6.42656049 15.45234297 6.24130000 0.03169500 309 | 29 11.00917387 -1.62944485 6.23085000 0.03000900 310 | 29 13.65017315 8.23858737 6.23812500 0.02447500 311 | 29 8.71404880 6.92360776 6.20062500 0.01597500 312 | 29 13.29437187 -10.16817076 6.24207500 0.03148300 313 | 29 15.94208301 -0.30740195 6.23877500 0.02997200 314 | 29 10.03667648 11.84735633 6.24102500 0.03196900 315 | 29 14.61705541 -5.23710152 6.24105000 0.03159000 316 | 29 17.26068371 4.62380023 6.24052500 0.03226400 317 | 29 5.66110679 18.29974813 8.32407500 -0.02962500 318 | 29 10.27593596 1.18704278 8.27100000 -0.03830800 319 | 29 12.88528797 11.08319522 8.32480000 -0.03957900 320 | 29 9.27137835 14.68762537 8.32402500 -0.03021200 321 | 29 13.85890966 -2.39248024 8.32370000 -0.03633000 322 | 29 16.49721274 7.47457123 8.32510000 -0.03657300 323 | 29 6.80800924 14.02947159 
8.32585000 -0.03075600 324 | 29 11.39612533 -3.06259897 8.32215000 -0.03987900 325 | 29 14.03578676 6.82237694 8.32582500 -0.03078900 326 | 29 9.09043326 5.54989304 8.25950000 -0.04325700 327 | 29 13.67322059 -11.59069495 8.32440000 -0.02912500 328 | 29 16.32238271 -1.73080295 8.32522500 -0.03494300 329 | 29 10.41802977 10.42515317 8.32490100 -0.03949400 330 | 29 14.99667063 -6.65984350 8.32477500 -0.03089300 331 | 29 17.64148193 3.19980031 8.32460000 -0.03706300 332 | 29 7.46557612 16.49290434 8.32402500 -0.03158800 333 | 29 12.06095844 -0.59910541 8.31640000 -0.04919700 334 | 29 14.69132354 9.28001071 8.32667500 -0.03659800 335 | 29 9.75757578 7.97916404 8.28892500 -0.05945100 336 | 29 14.33401612 -9.12491461 8.32505000 -0.03048500 337 | 29 16.98139633 0.73523478 8.32602500 -0.03554700 338 | 29 11.07787212 12.88300684 8.32472500 -0.03331000 339 | 29 15.65948354 -4.19527867 8.32467600 -0.02985900 340 | 29 18.30113689 5.66782006 8.32427500 -0.03098100 341 | 29 7.95388214 9.76279988 8.32640000 -0.03309400 342 | 29 12.53213551 -7.32268469 8.32510000 -0.03058200 343 | 29 15.18286882 2.53903872 8.31472500 -0.03858900 344 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: graph-afm 2 | channels: 3 | - pytorch 4 | - nvidia/label/cuda-11.3.1 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.8 9 | - pip 10 | - pytorch=1.10 11 | - cuda 12 | - ninja 13 | - torchvision 14 | - matplotlib>=3.3.0 15 | - numpy==1.21.4 16 | - scipy 17 | - pyopencl 18 | - ocl-icd-system 19 | - scikit-image 20 | - gxx=8.5.0 21 | - h5py 22 | -------------------------------------------------------------------------------- /environment_exact.yml: -------------------------------------------------------------------------------- 1 | name: ml_pt110 2 | channels: 3 | - pytorch 4 | - nvidia/label/cuda-11.3.1 5 | - conda-forge 6 | - defaults 7 | 
dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=1_llvm 10 | - alsa-lib=1.2.3=h516909a_0 11 | - appdirs=1.4.4=pyh9f0ad1d_0 12 | - attrs=21.2.0=pyhd8ed1ab_0 13 | - binutils_impl_linux-64=2.36.1=h193b22a_2 14 | - blas=2.112=mkl 15 | - blas-devel=3.9.0=12_linux64_mkl 16 | - blosc=1.21.0=h9c3ff4c_0 17 | - brotli=1.0.9=h7f98852_6 18 | - brotli-bin=1.0.9=h7f98852_6 19 | - brotlipy=0.7.0=py38h497a2fe_1003 20 | - brunsli=0.1=h9c3ff4c_0 21 | - bzip2=1.0.8=h7f98852_4 22 | - c-ares=1.18.1=h7f98852_0 23 | - c-blosc2=2.0.4=h5f21a17_1 24 | - ca-certificates=2021.10.8=ha878542_0 25 | - certifi=2021.10.8=py38h578d9bd_1 26 | - cffi=1.15.0=py38h3931269_0 27 | - cfitsio=4.0.0=h9a35b8e_0 28 | - chardet=4.0.0=py38h578d9bd_2 29 | - charls=2.2.0=h9c3ff4c_0 30 | - cloudpickle=2.0.0=pyhd8ed1ab_0 31 | - cryptography=35.0.0=py38h3e25421_2 32 | - cuda=11.3.1=h712c49d_0 33 | - cuda-command-line-tools=11.3.1=h712c49d_0 34 | - cuda-compiler=11.3.1=h712c49d_0 35 | - cuda-cudart=11.3.109=hfb95d0c_0 36 | - cuda-cuobjdump=11.3.122=hbf6ec6b_0 37 | - cuda-cupti=11.3.111=h12ad217_0 38 | - cuda-cuxxfilt=11.3.122=h4dc11a3_0 39 | - cuda-gdb=11.3.109=h33b7820_0 40 | - cuda-libraries=11.3.1=h712c49d_0 41 | - cuda-libraries-dev=11.3.1=h712c49d_0 42 | - cuda-memcheck=11.3.109=hf5cb439_0 43 | - cuda-nvcc=11.3.122=h4814707_0 44 | - cuda-nvdisasm=11.3.122=ha26faa6_0 45 | - cuda-nvml-dev=11.3.58=hc25e488_0 46 | - cuda-nvprof=11.3.111=h95a27d4_0 47 | - cuda-nvprune=11.3.122=hb3346b8_0 48 | - cuda-nvrtc=11.3.122=h1aa17d8_0 49 | - cuda-nvtx=11.3.109=h4ec7630_0 50 | - cuda-nvvp=11.3.111=h4c4416a_0 51 | - cuda-runtime=11.3.1=h712c49d_0 52 | - cuda-samples=11.3.58=h6d5b628_0 53 | - cuda-sanitizer-api=11.3.111=h2446cfc_0 54 | - cuda-thrust=11.3.109=he8b717c_0 55 | - cuda-toolkit=11.3.1=h712c49d_0 56 | - cuda-tools=11.3.1=h712c49d_0 57 | - cuda-visual-tools=11.3.1=h712c49d_0 58 | - cudatoolkit=11.3.1=ha36c431_9 59 | - cycler=0.11.0=pyhd8ed1ab_0 60 | - cytoolz=0.11.2=py38h497a2fe_1 61 | - 
dask-core=2021.11.1=pyhd8ed1ab_0 62 | - dbus=1.13.6=h48d8840_2 63 | - expat=2.4.1=h9c3ff4c_0 64 | - ffmpeg=4.3=hf484d3e_0 65 | - flake8=4.0.1=pyhd8ed1ab_1 66 | - fontconfig=2.13.1=hba837de_1005 67 | - freetype=2.10.4=h0708190_1 68 | - fsspec=2021.11.0=pyhd8ed1ab_0 69 | - gcc=8.5.0=h143be6b_1 70 | - gcc_impl_linux-64=8.5.0=hb55b52c_11 71 | - gettext=0.19.8.1=h73d1719_1008 72 | - giflib=5.2.1=h36c2ea0_2 73 | - glib=2.70.0=h780b84a_1 74 | - glib-tools=2.70.0=h780b84a_1 75 | - gmp=6.2.1=h58526e2_0 76 | - gnutls=3.6.13=h85f3911_1 77 | - gst-plugins-base=1.18.5=hf529b03_1 78 | - gstreamer=1.18.5=h9f60fe5_1 79 | - gxx=8.5.0=h143be6b_1 80 | - gxx_impl_linux-64=8.5.0=hb55b52c_11 81 | - icu=68.2=h9c3ff4c_0 82 | - idna=2.10=pyh9f0ad1d_0 83 | - imagecodecs=2021.8.26=py38h678ac2f_2 84 | - imageio=2.9.0=py_0 85 | - importlib-metadata=4.2.0=py38h578d9bd_0 86 | - iniconfig=1.1.1=pyh9f0ad1d_0 87 | - jbig=2.1=h7f98852_2003 88 | - jpeg=9d=h36c2ea0_0 89 | - jxrlib=1.1=h7f98852_2 90 | - kernel-headers_linux-64=2.6.32=he073ed8_15 91 | - kiwisolver=1.3.2=py38h1fd1430_1 92 | - krb5=1.19.2=hcc1bbae_3 93 | - lame=3.100=h7f98852_1001 94 | - lcms2=2.12=hddcbb42_0 95 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 96 | - lerc=3.0=h9c3ff4c_0 97 | - libaec=1.0.6=h9c3ff4c_0 98 | - libblas=3.9.0=12_linux64_mkl 99 | - libbrotlicommon=1.0.9=h7f98852_6 100 | - libbrotlidec=1.0.9=h7f98852_6 101 | - libbrotlienc=1.0.9=h7f98852_6 102 | - libcblas=3.9.0=12_linux64_mkl 103 | - libcublas=11.5.1.109=h0fd73e7_0 104 | - libcufft=10.4.2.109=h2344711_0 105 | - libcurand=10.2.4.109=h0189693_0 106 | - libcurl=7.79.1=h2574ce0_1 107 | - libcusolver=11.1.2.109=h1e009e5_0 108 | - libcusparse=11.6.0.109=hf5bfba9_0 109 | - libdeflate=1.8=h7f98852_0 110 | - libedit=3.1.20191231=he28a2e2_2 111 | - libev=4.33=h516909a_1 112 | - libevent=2.1.10=h9b69904_4 113 | - libffi=3.4.2=h9c3ff4c_4 114 | - libgcc-devel_linux-64=8.5.0=h82e8279_11 115 | - libgcc-ng=11.2.0=h1d223b6_11 116 | - libgfortran-ng=11.2.0=h69a702a_11 117 | - 
libgfortran5=11.2.0=h5c6108e_11 118 | - libglib=2.70.0=h174f98d_1 119 | - libgomp=11.2.0=h1d223b6_11 120 | - libiconv=1.16=h516909a_0 121 | - liblapack=3.9.0=12_linux64_mkl 122 | - liblapacke=3.9.0=12_linux64_mkl 123 | - libllvm11=11.1.0=hf817b99_2 124 | - libnghttp2=1.43.0=h812cca2_1 125 | - libnpp=11.3.3.95=h122bb27_0 126 | - libnsl=2.0.0=h7f98852_0 127 | - libnvjpeg=11.5.0.109=h159916b_0 128 | - libogg=1.3.4=h7f98852_1 129 | - libopus=1.3.1=h7f98852_1 130 | - libpng=1.6.37=h21135ba_2 131 | - libpq=13.3=hd57d9b9_3 132 | - libsanitizer=8.5.0=h70fd0c9_11 133 | - libssh2=1.10.0=ha56f1ee_2 134 | - libstdcxx-devel_linux-64=8.5.0=h82e8279_11 135 | - libstdcxx-ng=11.2.0=he4da1e4_11 136 | - libtiff=4.3.0=h6f004c6_2 137 | - libuuid=2.32.1=h7f98852_1000 138 | - libuv=1.42.0=h7f98852_0 139 | - libvorbis=1.3.7=h9c3ff4c_0 140 | - libwebp-base=1.2.1=h7f98852_0 141 | - libxcb=1.13=h7f98852_1003 142 | - libxkbcommon=1.0.3=he3ba5ed_0 143 | - libxml2=2.9.12=h72842e0_0 144 | - libzlib=1.2.11=h36c2ea0_1013 145 | - libzopfli=1.0.3=h9c3ff4c_0 146 | - llvm-openmp=12.0.1=h4bd325d_1 147 | - locket=0.2.0=py_2 148 | - lz4-c=1.9.3=h9c3ff4c_1 149 | - mako=1.1.5=pyhd8ed1ab_0 150 | - markupsafe=2.0.1=py38h497a2fe_1 151 | - matplotlib=3.4.3=py38h578d9bd_1 152 | - matplotlib-base=3.4.3=py38hf4fb855_1 153 | - mccabe=0.6.1=py_1 154 | - mkl=2021.4.0=h8d4b97c_729 155 | - mkl-devel=2021.4.0=ha770c72_730 156 | - mkl-include=2021.4.0=h8d4b97c_729 157 | - more-itertools=8.10.0=pyhd8ed1ab_0 158 | - mysql-common=8.0.27=ha770c72_1 159 | - mysql-libs=8.0.27=hfa10184_1 160 | - ncurses=6.2=h58526e2_4 161 | - nettle=3.6=he412f7d_0 162 | - networkx=2.6.3=pyhd8ed1ab_1 163 | - ninja=1.10.2=h4bd325d_1 164 | - nspr=4.32=h9c3ff4c_1 165 | - nss=3.71=hb5efdd6_0 166 | - numpy=1.21.4=py38he2449b9_0 167 | - ocl-icd=2.3.1=h7f98852_0 168 | - ocl-icd-system=1.0.0=1 169 | - olefile=0.46=pyh9f0ad1d_1 170 | - openh264=2.1.1=h780b84a_0 171 | - openjpeg=2.4.0=hb52868f_1 172 | - openssl=1.1.1l=h7f98852_0 173 | - 
packaging=21.0=pyhd8ed1ab_0 174 | - pandas=1.3.4=py38h43a58ef_1 175 | - partd=1.2.0=pyhd8ed1ab_0 176 | - pcre=8.45=h9c3ff4c_0 177 | - pillow=8.3.2=py38h8e6f84c_0 178 | - pip=21.3.1=pyhd8ed1ab_0 179 | - pluggy=1.0.0=py38h578d9bd_2 180 | - pooch=1.5.2=pyhd8ed1ab_0 181 | - pthread-stubs=0.4=h36c2ea0_1001 182 | - py=1.11.0=pyh6c4a22f_0 183 | - pycodestyle=2.8.0=pyhd8ed1ab_0 184 | - pycparser=2.21=pyhd8ed1ab_0 185 | - pyflakes=2.4.0=pyhd8ed1ab_0 186 | - pyopencl=2021.2.9=py38h2b96118_1 187 | - pyopenssl=21.0.0=pyhd8ed1ab_0 188 | - pyparsing=3.0.5=pyhd8ed1ab_0 189 | - pyqt=5.12.3=py38h578d9bd_7 190 | - pyqt-impl=5.12.3=py38h7400c14_7 191 | - pyqt5-sip=4.19.18=py38h709712a_7 192 | - pyqtchart=5.12=py38h7400c14_7 193 | - pyqtwebengine=5.12.1=py38h7400c14_7 194 | - pysocks=1.7.1=py38h578d9bd_4 195 | - pytest=6.2.5=py38h578d9bd_1 196 | - python=3.8.12=hb7a2778_2_cpython 197 | - python-dateutil=2.8.2=pyhd8ed1ab_0 198 | - python_abi=3.8=2_cp38 199 | - pytools=2021.2.9=pyhd8ed1ab_0 200 | - pytorch=1.10.0=py3.8_cuda11.3_cudnn8.2.0_0 201 | - pytorch-mutex=1.0=cuda 202 | - pytz=2021.3=pyhd8ed1ab_0 203 | - pywavelets=1.1.1=py38h6c62de6_4 204 | - pyyaml=6.0=py38h497a2fe_3 205 | - qt=5.12.9=hda022c4_4 206 | - readline=8.1=h46c0cb4_0 207 | - requests=2.25.1=pyhd3deb0d_0 208 | - scikit-image=0.18.3=py38h43a58ef_0 209 | - scipy=1.7.2=py38h56a6a73_0 210 | - setuptools=58.5.3=py38h578d9bd_0 211 | - six=1.16.0=pyh6c4a22f_0 212 | - snappy=1.1.8=he1b5a44_3 213 | - sqlite=3.36.0=h9cd32fc_2 214 | - sysroot_linux-64=2.12=he073ed8_15 215 | - tbb=2021.4.0=h4bd325d_1 216 | - tifffile=2021.11.2=pyhd8ed1ab_0 217 | - tk=8.6.11=h27826a3_1 218 | - toml=0.10.2=pyhd8ed1ab_0 219 | - toolz=0.11.2=pyhd8ed1ab_0 220 | - torchvision=0.11.1=py38_cu113 221 | - tornado=6.1=py38h497a2fe_2 222 | - typing_extensions=3.10.0.2=pyha770c72_0 223 | - urllib3=1.26.7=pyhd8ed1ab_0 224 | - wheel=0.37.0=pyhd8ed1ab_1 225 | - xorg-libxau=1.0.9=h7f98852_0 226 | - xorg-libxdmcp=1.1.3=h7f98852_0 227 | - xz=5.2.5=h516909a_1 228 | - 
yaml=0.2.5=h516909a_0 229 | - zfp=0.5.5=h9c3ff4c_7 230 | - zipp=3.6.0=pyhd8ed1ab_0 231 | - zlib=1.2.11=h36c2ea0_1013 232 | - zstd=1.5.0=ha95c52a_0 233 | - pip: 234 | - absl-py==0.15.0 235 | - astunparse==1.6.3 236 | - cachetools==4.2.4 237 | - flatbuffers==2.0 238 | - gast==0.4.0 239 | - google-auth==2.3.3 240 | - google-auth-oauthlib==0.4.6 241 | - google-pasta==0.2.0 242 | - grpcio==1.41.1 243 | - h5py==3.5.0 244 | - keras==2.7.0 245 | - keras-preprocessing==1.1.2 246 | - libclang==12.0.0 247 | - markdown==3.3.4 248 | - oauthlib==3.1.1 249 | - opt-einsum==3.3.0 250 | - protobuf==3.19.1 251 | - pyasn1==0.4.8 252 | - pyasn1-modules==0.2.8 253 | - requests-oauthlib==1.3.0 254 | - rsa==4.7.2 255 | - tensorboard==2.7.0 256 | - tensorboard-data-server==0.6.1 257 | - tensorboard-plugin-wit==1.8.0 258 | - tensorflow==2.7.0 259 | - tensorflow-estimator==2.7.0 260 | - tensorflow-io-gcs-filesystem==0.21.0 261 | - termcolor==1.1.0 262 | - werkzeug==2.0.2 263 | - wrapt==1.13.3 264 | -------------------------------------------------------------------------------- /model_schem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SINGROUP/Graph-AFM/37f4c68b2f3d06c68a04e0194e13034c415f3ab4/model_schem.png -------------------------------------------------------------------------------- /pretrained_weights/model_random.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SINGROUP/Graph-AFM/37f4c68b2f3d06c68a04e0194e13034c415f3ab4/pretrained_weights/model_random.pth -------------------------------------------------------------------------------- /pretrained_weights/model_y.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SINGROUP/Graph-AFM/37f4c68b2f3d06c68a04e0194e13034c415f3ab4/pretrained_weights/model_y.pth 
-------------------------------------------------------------------------------- /pretrained_weights/model_z.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SINGROUP/Graph-AFM/37f4c68b2f3d06c68a04e0194e13034c415f3ab4/pretrained_weights/model_z.pth -------------------------------------------------------------------------------- /scripts/generate_data.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os 4 | import sys 5 | import time 6 | import glob 7 | import h5py 8 | import random 9 | import numpy as np 10 | 11 | sys.path.append('../ProbeParticleModel') # Make sure ProbeParticleModel is on PATH 12 | from pyProbeParticle import oclUtils as oclu 13 | from pyProbeParticle import fieldOCL as FFcl 14 | from pyProbeParticle import RelaxOpenCL as oclr 15 | from pyProbeParticle import AuxMap as aux 16 | from pyProbeParticle.AFMulatorOCL_Simple import AFMulator 17 | from pyProbeParticle.GeneratorOCL_Simple2 import InverseAFMtrainer 18 | 19 | sys.path.append('../src') 20 | from utils import download_molecules 21 | 22 | # Set random seeds for reproducibility 23 | random.seed(0) 24 | np.random.seed(0) 25 | 26 | def pad_xyzs(xyzs, max_len): 27 | xyzs_padded = [np.pad(xyz, ((0, max_len - len(xyz)), (0, 0))) for xyz in xyzs] 28 | xyzs = np.stack(xyzs_padded, axis=0) 29 | return xyzs 30 | 31 | class Trainer(InverseAFMtrainer): 32 | 33 | # Override to randomize tip distance for each tip independently 34 | def handle_distance(self): 35 | self.randomize_distance(delta=0.25) 36 | self.randomize_tip(max_tilt=0.5) 37 | super().handle_distance() 38 | 39 | # Options 40 | molecules_dir = './Molecules/' # Where to save molecule database 41 | save_path = './graph_dataset.hdf5' # Where to save training data 42 | 43 | # Initialize OpenCL environment on GPU 44 | env = oclu.OCLEnvironment( i_platform = 0 ) 45 | FFcl.init(env) 46 | oclr.init(env) 47 | 48 | afmulator_args = 
{ 49 | 'pixPerAngstrome' : 20, 50 | 'lvec' : np.array([ 51 | [ 0.0, 0.0, 0.0], 52 | [20.0, 0.0, 0.0], 53 | [ 0.0, 20.0, 0.0], 54 | [ 0.0, 0.0, 5.0] 55 | ]), 56 | 'scan_dim' : (128, 128, 20), 57 | 'scan_window' : ((2.0, 2.0, 6.0), (18.0, 18.0, 8.0)), 58 | 'amplitude' : 1.0, 59 | 'df_steps' : 10, 60 | 'initFF' : True 61 | } 62 | 63 | generator_kwargs = { 64 | 'batch_size' : 30, 65 | 'distAbove' : 5.3, 66 | 'iZPPs' : [8], 67 | 'Qs' : [[ -10, 20, -10, 0 ]], 68 | 'QZs' : [[ 0.1, 0, -0.1, 0 ]] 69 | } 70 | 71 | # Define AFMulator 72 | afmulator = AFMulator(**afmulator_args) 73 | afmulator.npbc = (0,0,0) 74 | 75 | # Define AuxMaps 76 | aux_maps = [] 77 | 78 | # Download molecules if not already there 79 | download_molecules(molecules_dir, verbose=1) 80 | 81 | # Paths to molecule xyz files 82 | train_paths = glob.glob(os.path.join(molecules_dir, 'train/*.xyz')) 83 | val_paths = glob.glob(os.path.join(molecules_dir, 'validation/*.xyz')) 84 | test_paths = glob.glob(os.path.join(molecules_dir, 'test/*.xyz')) 85 | 86 | with h5py.File(save_path, 'w') as f: 87 | 88 | start_time = time.time() 89 | counter = 1 90 | total_len = np.floor((len(train_paths)+len(val_paths)+len(test_paths))/generator_kwargs['batch_size']) 91 | for mode, paths in zip(['train', 'val', 'test'], [train_paths, val_paths, test_paths]): 92 | 93 | # Define generator 94 | trainer = Trainer(afmulator, aux_maps, paths, **generator_kwargs) 95 | 96 | # Shuffle 97 | trainer.shuffle_molecules() 98 | 99 | # Calculate dataset shapes 100 | n_mol = len(trainer.molecules) 101 | max_mol_len = max([len(m) for m in trainer.molecules]) 102 | X_shape = ( 103 | n_mol, # Number of samples 104 | len(trainer.iZPPs), # Number of tips 105 | afmulator.scan_dim[0], # x size 106 | afmulator.scan_dim[1], # y size 107 | afmulator.scan_dim[2] - afmulator.scanner.nDimConvOut # z size 108 | ) 109 | Y_shape = (n_mol, len(aux_maps)) + X_shape[2:4] 110 | xyz_shape = (n_mol, max_mol_len, 5) 111 | 112 | # Create new group in HDF5 file and add 
datasets to the group 113 | g = f.create_group(mode) 114 | X_h5 = g.create_dataset('X', shape=X_shape, chunks=(1,)+X_shape[1:], dtype='f') 115 | if len(aux_maps) > 0: 116 | Y_h5 = g.create_dataset('Y', shape=Y_shape, chunks=(1,)+Y_shape[1:], dtype='f') 117 | xyz_h5 = g.create_dataset('xyz', shape=xyz_shape, chunks=(1,)+xyz_shape[1:], dtype='f') 118 | 119 | # Generate data 120 | ind = 0 121 | for i, (X, Y, xyz) in enumerate(trainer): 122 | 123 | # Write batch to the HDF5 file 124 | n_batch = len(xyz) 125 | X_h5[ind:ind+n_batch] = np.stack(X, axis=1) 126 | if len(aux_maps) > 0: 127 | Y_h5[ind:ind+n_batch] = np.stack(Y, axis=1) 128 | xyz_h5[ind:ind+n_batch] = pad_xyzs(xyz, max_mol_len) 129 | ind += n_batch 130 | 131 | # Print progress info 132 | eta = (time.time() - start_time)/counter * (total_len - counter) 133 | print(f'Generated {mode} batch {i+1}/{len(trainer)} - ETA: {eta:.1f}s') 134 | counter += 1 135 | 136 | print(f'Total time taken: {time.time() - start_time:.1f}s') 137 | -------------------------------------------------------------------------------- /scripts/predict_examples.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import string 6 | import random 7 | import imageio 8 | import numpy as np 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | 12 | import torch 13 | from torch import nn, optim 14 | 15 | sys.path.append('../ProbeParticleModel') # Make sure ProbeParticleModel is on PATH 16 | from pyProbeParticle import oclUtils as oclu 17 | from pyProbeParticle import fieldOCL as FFcl 18 | from pyProbeParticle import RelaxOpenCL as oclr 19 | from pyProbeParticle.AFMulatorOCL_Simple import AFMulator 20 | from pyProbeParticle.GeneratorOCL_Simple2 import InverseAFMtrainer 21 | 22 | sys.path.append('../src') # Add source code directory to Python PATH 23 | import utils 24 | import preprocessing as pp 25 | from models import load_pretrained_model 26 | 27 | # # 
Set matplotlib font rendering to use LaTeX 28 | # plt.rcParams.update({ 29 | # "text.usetex": True, 30 | # "font.family": "serif", 31 | # "font.serif": ["Computer Modern Roman"] 32 | # }) 33 | 34 | # Set random seeds for reproducibility 35 | random.seed(0) 36 | np.random.seed(0) 37 | torch.manual_seed(0) 38 | 39 | classes = [[1], [6, 14], [7, 15], [8, 16], [9, 17, 35]] # List of elements in each class 40 | box_res = (0.125, 0.125, 0.1) # Real-space voxel size for position distribution 41 | zmin = -0.8 # Maximum depth used for thresholding 42 | peak_std = 0.25 # Standard deviation of atom position peaks in angstroms 43 | sequence_order = None # Order for graph construction 44 | device = 'cuda' # Device to run inference on 45 | base_dir = '../data' # Base directory for molecule data 46 | class_colors = ['w', 'dimgray', 'b', 'r', 'yellowgreen'] # Colors for classes 47 | afm_slices = [0, 5, 9] 48 | marker_size = 25 49 | z_min_marker = -1.0 50 | z_max_marker = 0.0 51 | 52 | def predict(model, molecules, box_borders): 53 | 54 | scan_dim = ( 55 | int((box_borders[1][0] - box_borders[0][0]) / box_res[0]), 56 | int((box_borders[1][1] - box_borders[0][1]) / box_res[1]), 57 | 20 58 | ) 59 | print(scan_dim) 60 | afmulator_args = { 61 | 'pixPerAngstrome' : 20, 62 | 'lvec' : np.array([ 63 | [ 0.0, 0.0, 0.0], 64 | [box_borders[1][0]+2.0, 0.0, 0.0], 65 | [ 0.0, box_borders[1][1]+2.0, 0.0], 66 | [ 0.0, 0.0, 6.0] 67 | ]), 68 | 'scan_dim' : scan_dim, 69 | 'scan_window' : (box_borders[0][:2] + (7.0,), box_borders[1][:2] + (9.0,)), 70 | 'amplitude' : 1.0, 71 | 'df_steps' : 10, 72 | 'initFF' : True 73 | } 74 | 75 | generator_kwargs = { 76 | 'batch_size' : 30, 77 | 'distAbove' : 5.3, 78 | 'iZPPs' : [8], 79 | 'Qs' : [[ -10, 20, -10, 0 ]], 80 | 'QZs' : [[ 0.1, 0, -0.1, 0 ]] 81 | } 82 | 83 | # Define AFMulator 84 | afmulator = AFMulator(**afmulator_args) 85 | afmulator.npbc = (0,0,0) 86 | 87 | # Define generator 88 | trainer = InverseAFMtrainer(afmulator, [], molecules, **generator_kwargs)
89 | 90 | # Generate batch 91 | batch = next(iter(trainer)) 92 | X, ref_graphs, ref_dist, box_borders = apply_preprocessing(batch, box_borders) 93 | 94 | with torch.no_grad(): 95 | X_gpu = torch.from_numpy(X).unsqueeze(1).to(device) 96 | pred_graphs, pred_dist, pred_sequence, completed = model.predict_sequence(X_gpu, box_borders, order=sequence_order) 97 | pred_dist = pred_dist.cpu().numpy() 98 | 99 | return X, pred_graphs, pred_dist, ref_graphs, ref_dist, pred_sequence, box_borders 100 | 101 | def apply_preprocessing(batch, box_borders): 102 | 103 | X, Y, atoms = batch 104 | 105 | pp.add_norm(X) 106 | pp.add_noise(X, c=0.1, randomize_amplitude=False) 107 | X = X[0] 108 | 109 | atoms = pp.top_atom_to_zero(atoms) 110 | bonds = utils.find_bonds(atoms) 111 | mols = [utils.MoleculeGraph(a, b, classes=classes) for a, b in zip(atoms, bonds)] 112 | mols = utils.threshold_atoms_bonds(mols, zmin) 113 | 114 | # shift_xy = [2, 0] 115 | # box_borders = ( 116 | # (box_borders[0][0] + shift_xy[0], box_borders[0][1] + shift_xy[1], box_borders[0][2]), 117 | # (box_borders[1][0] + shift_xy[0], box_borders[1][1] + shift_xy[1], box_borders[1][2]) 118 | # ) 119 | # for m in mols: 120 | # for a in m.atoms: 121 | # a.xyz[0] += shift_xy[0] 122 | # a.xyz[1] += shift_xy[1] 123 | 124 | ref_dist = utils.make_position_distribution([m.atoms for m in mols], box_borders, 125 | box_res=box_res, std=peak_std) 126 | 127 | return X, mols, ref_dist, box_borders 128 | 129 | def get_marker_size(z, max_size=marker_size): 130 | return max_size * (z - z_min_marker) / (z_max_marker - z_min_marker) 131 | 132 | def plot_xy(ax, mol, box_borders): 133 | 134 | if len(mol) > 0: 135 | 136 | mol_pos = mol.array(xyz=True) 137 | 138 | s = get_marker_size(mol_pos[:,2]) 139 | if (s < 0).any(): 140 | raise ValueError('Encountered atom z position(s) below box borders.') 141 | 142 | c = [class_colors[atom.class_index] for atom in mol.atoms] 143 | 144 | ax.scatter(mol_pos[:,0], mol_pos[:,1], c=c, s=s, edgecolors='k', 
zorder=2, linewidth=0.5) 145 | for b in mol.bonds: 146 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 147 | ax.plot(pos[:,0], pos[:,1], 'k', linewidth=1, zorder=1) 148 | 149 | ax.set_xlim(box_borders[0][0], box_borders[1][0]) 150 | ax.set_ylim(box_borders[0][1], box_borders[1][1]) 151 | ax.set_aspect('equal', 'box') 152 | 153 | def plot_xz(ax, mol, box_borders): 154 | 155 | if len(mol) > 0: 156 | 157 | order = list(np.argsort(mol.array(xyz=True)[:, 1])[::-1]) 158 | mol = mol.permute(order) 159 | mol_pos = mol.array(xyz=True) 160 | 161 | s = get_marker_size(mol_pos[:,2]) 162 | if (s < 0).any(): 163 | raise ValueError('Encountered atom z position(s) below box borders.') 164 | 165 | c = [class_colors[atom.class_index] for atom in mol.atoms] 166 | 167 | for b in mol.bonds: 168 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 169 | ax.plot(pos[:,0], pos[:,2], 'k', linewidth=1, zorder=1) 170 | ax.scatter(mol_pos[:,0], mol_pos[:,2], c=c, s=s, edgecolors='k', zorder=2, linewidth=0.5) 171 | 172 | ax.set_xlim(box_borders[0][0], box_borders[1][0]) 173 | ax.set_ylim(box_borders[0][2], box_borders[1][2]) 174 | ax.set_aspect('equal', 'box') 175 | 176 | # Initialize OpenCL environment on GPU 177 | env = oclu.OCLEnvironment( i_platform = 0 ) 178 | FFcl.init(env) 179 | oclr.init(env) 180 | 181 | # Load model 182 | model = load_pretrained_model('random', device=device) 183 | 184 | # Make predictions 185 | data = [ 186 | predict( 187 | model, 188 | molecules=[os.path.join(base_dir, 'bcb.xyz')], 189 | box_borders=((2,2,-1.5),(18,18,0.5)) 190 | ), 191 | predict( 192 | model, 193 | molecules=[os.path.join(base_dir, 'water.xyz')], 194 | box_borders=((2,2,-1.5),(18,18,0.5)) 195 | ), 196 | predict( 197 | model, 198 | molecules=[os.path.join(base_dir, 'ptcda.xyz')], 199 | box_borders=((2,2,-1.5),(22,18,0.5)) 200 | ) 201 | ] 202 | img_paths = [ 203 | os.path.join(base_dir, 'bcb.png'), 204 | os.path.join(base_dir, 'water.png'), 205 | os.path.join(base_dir, 'ptcda.png') 206 | ] 207 | 208 | 
# Initialize figure 209 | width = 16.0 / 2.54 210 | ns = len(afm_slices) 211 | a = 1.2 212 | h1, h2, h3 = 0.03 / 2.54, 0.2 / 2.54, 0.2 / 2.54 213 | w1, w2, w3, w4, w5 = 0.1 / 2.54, 0.7 / 2.54, 0.75 / 2.54, 0.75 / 2.54, 0.2 / 2.54 214 | y = (width - (w1 + w2 + w3 + w4 + w5) + (2*(1+a)*h1/ns + h2/2 + 32*h3/19)) / ((1+a)/ns + 1/2 + 32/19) 215 | widths = [ 216 | a*(y-2*h1)/ns, # Molecule geometry 217 | w1, # Padding 218 | (y-2*h1)/ns, # AFM 219 | w2, # Padding 220 | (y-h2)/2, # Position grid 221 | w3, # Padding 222 | 16*(y-h3)/19, # Predicted graph 223 | w4, # Padding 224 | 16*(y-h3)/19, # Reference graph 225 | w5 # Padding 226 | ] 227 | assert np.allclose(sum(widths), width) 228 | heights = [y * d[0].shape[2] / d[0].shape[1] for d in data] 229 | between_pad, bottom_pad, top_pad = 0.8 / 2.54, 0.7 / 2.54, 0.4 / 2.54 230 | height_legend = 0.5 / 2.54 231 | height = sum(heights) + (len(data) - 1) * between_pad + bottom_pad + top_pad + height_legend 232 | fig = plt.figure(figsize=(width, height)) 233 | 234 | y0 = height - top_pad 235 | for i, (X, pred_graphs, pred_dist, ref_graphs, ref_dist, pred_sequence, box_borders) in enumerate(data): 236 | 237 | extent = [box_borders[0][0], box_borders[1][0], box_borders[0][1], box_borders[1][1]] 238 | xticks = np.linspace(box_borders[0][0], box_borders[1][0], 5).astype(int) 239 | yticks = np.linspace(box_borders[0][1], box_borders[1][1], 5).astype(int) 240 | 241 | # Set subfigure reference letters 242 | fig.text(0.05/width, y0/height, string.ascii_uppercase[i], fontsize=10) 243 | 244 | # Create axes 245 | x = 0 246 | height_afm = (heights[i]-(ns-1)*h1)/ns 247 | height_grid = (heights[i]-h2)/2 248 | dy_graph = (heights[i]-h3)/19 249 | ax_img = fig.add_axes([x/width, (y0-heights[i])/height, widths[0]/width, heights[i]/height]) 250 | x += widths[0] + widths[1] 251 | axes_afm = [fig.add_axes([x/width, (y0-height_afm*(j+1)-j*h1)/height, widths[2]/width, 252 | height_afm/height]) for j in range(ns)] 253 | x += widths[2] + widths[3] 254 | 
axes_grid2d = [ 255 | fig.add_axes([x/width, (y0-height_grid)/height, widths[4]/width, height_grid/height]), 256 | fig.add_axes([x/width, (y0-heights[i])/height, widths[4]/width, height_grid/height]) 257 | ] 258 | x += widths[4] + widths[5] 259 | axes_pred = [ 260 | fig.add_axes([x/width, (y0-16*dy_graph)/height, widths[6]/width, 16*dy_graph/height]), 261 | fig.add_axes([x/width, (y0-heights[i])/height, widths[6]/width, 3*dy_graph/height]) 262 | ] 263 | x += widths[6] + widths[7] 264 | axes_ref = [ 265 | fig.add_axes([x/width, (y0-16*dy_graph)/height, widths[8]/width, 16*dy_graph/height]), 266 | fig.add_axes([x/width, (y0-heights[i])/height, widths[8]/width, 3*dy_graph/height]) 267 | ] 268 | y0 -= heights[i] + between_pad 269 | 270 | 271 | # Plot molecule geometry 272 | xyz_img = np.flipud(imageio.imread(img_paths[i])) 273 | ax_img.imshow(xyz_img, origin='lower') 274 | ax_img.axis('off') 275 | 276 | # Plot AFM 277 | for s, ax in zip(afm_slices, axes_afm): 278 | ax.imshow(X[0][:, :, s].T, origin='lower', cmap='afmhot') 279 | ax.axis('off') 280 | 281 | # Plot grid in 2D 282 | p_mean, r_mean = pred_dist.mean(axis=-1), ref_dist.mean(axis=-1) 283 | vmin = min(r_mean.min(), p_mean.min()) 284 | vmax = max(r_mean.max(), p_mean.max()) 285 | for ax, d in zip(axes_grid2d, [p_mean, r_mean]): 286 | ax.imshow(d.T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 287 | ax.tick_params('both', length=1, width=0.5, pad=1.5, labelsize=6) 288 | ax.spines[:].set_linewidth(0.3) 289 | ax.set_yticks(yticks) 290 | axes_grid2d[0].set_ylabel('Prediction, $y$(Å)', fontsize=6, labelpad=0) 291 | axes_grid2d[0].tick_params('x', bottom=False, labelbottom=False) 292 | axes_grid2d[1].set_xlabel('$x$(Å)', fontsize=6, labelpad=0) 293 | axes_grid2d[1].set_ylabel('Reference, $y$(Å)', fontsize=6, labelpad=0) 294 | axes_grid2d[1].set_xticks(xticks) 295 | 296 | # Plot graphs 297 | for axes, d in zip([axes_pred, axes_ref], [pred_graphs, ref_graphs]): 298 | plot_xy(axes[0], d[0], box_borders) 299 | 
plot_xz(axes[1], d[0], box_borders) 300 | axes[0].set_ylabel('$y$(Å)', fontsize=6, labelpad=0) 301 | axes[0].tick_params('both', length=1, width=0.5, pad=1, labelsize=6) 302 | axes[0].set_yticks(yticks) 303 | axes[0].tick_params('x', bottom=False, labelbottom=False) 304 | axes[0].spines[:].set_linewidth(0.3) 305 | axes[1].set_xlabel('$x$(Å)', fontsize=6, labelpad=0) 306 | axes[1].set_ylabel('$z$(Å)', fontsize=6, labelpad=0) 307 | axes[1].set_ylim(-2, 1) 308 | axes[1].tick_params('both', length=1, width=0.5, pad=1, labelsize=6) 309 | axes[1].set_xticks(xticks) 310 | axes[1].set_yticks([-2, -1, 0, 1]) 311 | axes[1].spines[:].set_linewidth(0.3) 312 | 313 | if i == 0: 314 | axes_afm[0].set_title('AFM input', fontsize=9, pad=4) 315 | axes_grid2d[0].set_title('Position grid (2D)', fontsize=9, pad=4) 316 | axes_pred[0].set_title('Predicted graph', fontsize=9, pad=4) 317 | axes_ref[0].set_title('Reference graph', fontsize=9, pad=4) 318 | 319 | # Add legend of classes and marker sizes 320 | y0 += -between_pad + bottom_pad 321 | ax_legend = fig.add_axes([0, 0, 1, height_legend/height]) 322 | ax_legend.axis('off') 323 | ax_legend.set_xlim([0, 1]) 324 | ax_legend.set_ylim([0, 1]) 325 | 326 | y = 0.5 327 | dx = [(len(c) + 1) * 0.022 for c in classes] 328 | dx += [0.04, 0.092, 0.097, 0.097] 329 | x = 0.5 - sum(dx)/2 330 | 331 | # Class colors 332 | for i, c in enumerate(classes): 333 | ax_legend.scatter(x, y, s=marker_size, c=class_colors[i], edgecolors='k', linewidth=0.5) 334 | t = ax_legend.text(x+0.01, y, ', '.join([utils.elements[e-1] for e in c]), fontsize=7, 335 | ha='left', va='center_baseline') 336 | x += dx[i] 337 | 338 | # Marker sizes 339 | x += dx[len(classes)] 340 | marker_zs = np.array([z_max_marker, (z_min_marker + z_max_marker + 0.2) / 2, z_min_marker + 0.2]) 341 | ss = get_marker_size(marker_zs) 342 | for i, (s, z) in enumerate(zip(ss, marker_zs)): 343 | ax_legend.scatter(x, y, s=s, c='w', edgecolors='k', linewidth=0.5) 344 | ax_legend.text(x + 0.01, y, f'z = 
{z}Å', fontsize=7, ha='left', va='center_baseline') 345 | x += dx[i+len(classes)+1] 346 | 347 | plt.savefig('./predictions.pdf', dpi=300) 348 | plt.close() 349 | 350 | # 3D Distribution grids 351 | fig = plt.figure(figsize=(16.0 / 2.54, 20 / 2.54)) 352 | fig_grid = fig.add_gridspec(len(data), 1, hspace=0.1, left=0.04, right=0.99, bottom=0.02, top=0.97, 353 | height_ratios=heights) 354 | for i, (_, _, pred_dist, _, ref_dist, _, box_borders) in enumerate(data): 355 | p, r = pred_dist[0], ref_dist[0] 356 | z_start = box_borders[0][2] 357 | z_res = (box_borders[1][2] - box_borders[0][2]) / pred_dist.shape[-1] 358 | extent = [box_borders[0][0], box_borders[1][0], box_borders[0][1], box_borders[1][1]] 359 | vmin = min(r.min(), p.min()) 360 | vmax = max(r.max(), p.max()) 361 | nrows, ncols = 2, 10 362 | sample_grid = fig_grid[i].subgridspec(nrows, ncols, wspace=0.02, hspace=0.2/heights[i]) 363 | for iz in range(p.shape[-1]): 364 | ix = iz % ncols 365 | iy = iz // ncols 366 | ax1, ax2 = sample_grid[iy, ix].subgridspec(2, 1, hspace=0.01).subplots() 367 | ax1.imshow(p[:,:,iz].T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 368 | ax2.imshow(r[:,:,iz].T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 369 | ax1.axis('off') 370 | ax2.axis('off') 371 | ax1.set_title(f'z = {z_start + (iz + 0.5) * z_res:.2f}Å', fontsize=6, pad=2) 372 | if ix == 0: 373 | ax1.text(-0.1, 0.5, 'Prediction', ha='center', va='center', 374 | transform=ax1.transAxes, rotation='vertical', fontsize=6) 375 | ax2.text(-0.1, 0.5, 'Reference', ha='center', va='center', 376 | transform=ax2.transAxes, rotation='vertical', fontsize=6) 377 | if iz == 0: 378 | fig.text(-0.4, 1, string.ascii_uppercase[i], fontsize=10, transform=ax1.transAxes) 379 | plt.savefig(f'grid3D.pdf', dpi=300) 380 | plt.close() 381 | 382 | # Plot prediction sequences 383 | for i, (_, pred_graphs, _, _, _, pred_sequence, box_borders) in enumerate(data): 384 | seq = pred_sequence[0] 385 | atom_pos = pred_graphs[0].array(xyz=True) 
386 | seq_len = len(seq) + 1 387 | x_seq = min(6, seq_len) 388 | y_seq = int(seq_len / (6+1e-6)) + 1 389 | fig = plt.figure(figsize=(16 / 2.54, (1.4 * heights[i] * y_seq + 0.4*(y_seq - 1)) / 2.54)) 390 | grid_seq = fig.add_gridspec(y_seq, x_seq, wspace=0.02, hspace=0.3/heights[i], left=0.04, right=0.99, 391 | bottom=0.03/y_seq, top=1-0.1/y_seq) 392 | for j in range(len(seq) + 1): 393 | 394 | x_grid = j % x_seq 395 | y_grid = j // x_seq 396 | ax = fig.add_subplot(grid_seq[y_grid, x_grid]) 397 | ax.spines[:].set_linewidth(0.3) 398 | 399 | s = get_marker_size(atom_pos[:,2], 8) 400 | ax.scatter(atom_pos[:,0], atom_pos[:,1], c='lightgray', s=s) 401 | 402 | if j > 0: 403 | mol, atom, bonds = seq[j-1] 404 | mol_pos = mol.array(xyz=True) 405 | atom_xyz = atom.array(xyz=True) 406 | c = class_colors[atom.class_index] 407 | s = get_marker_size(atom_xyz[2], 10) 408 | ax.scatter(atom_xyz[0], atom_xyz[1], c=c, s=s, edgecolors='k', zorder=2, linewidth=0.3) 409 | bonds = [i for i in range(len(bonds)) if bonds[i] > 0.5] 410 | for b in bonds: 411 | pos = np.vstack([mol_pos[b], atom_xyz]) 412 | ax.plot(pos[:,0], pos[:,1], 'k', linewidth=0.6, zorder=1) 413 | if j > 1: 414 | s = get_marker_size(mol_pos[:,2], 10) 415 | c = [class_colors[atom.class_index] for atom in mol.atoms] 416 | ax.scatter(mol_pos[:,0], mol_pos[:,1], c=c, s=s, edgecolors='k', zorder=2, linewidth=0.3) 417 | for b in mol.bonds: 418 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 419 | ax.plot(pos[:,0], pos[:,1], 'k', linewidth=0.6, zorder=1) 420 | 421 | ax.set_xlim(box_borders[0][0], box_borders[1][0]) 422 | ax.set_ylim(box_borders[0][1], box_borders[1][1]) 423 | ax.set_aspect('equal', 'box') 424 | 425 | ax.set_title(f'{j}', fontsize=7, pad=2) 426 | ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False) 427 | 428 | if j == 0: 429 | fig.text(-0.25, 1, string.ascii_uppercase[i], fontsize=10, transform=ax.transAxes) 430 | 431 | plt.savefig(f'pred_sequence_{i}.pdf', dpi=300) 432 | 
-------------------------------------------------------------------------------- /scripts/predict_random.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import glob 6 | import string 7 | import random 8 | import imageio 9 | import numpy as np 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | from torch import nn, optim 15 | 16 | sys.path.append('../ProbeParticleModel') # Make sure ProbeParticleModel is on PATH 17 | from pyProbeParticle import oclUtils as oclu 18 | from pyProbeParticle import fieldOCL as FFcl 19 | from pyProbeParticle import RelaxOpenCL as oclr 20 | from pyProbeParticle.AFMulatorOCL_Simple import AFMulator 21 | from pyProbeParticle.GeneratorOCL_Simple2 import InverseAFMtrainer 22 | 23 | sys.path.append('../src') # Add source code directory to Python PATH 24 | import utils 25 | import preprocessing as pp 26 | from models import load_pretrained_model 27 | 28 | # # Set matplotlib font rendering to use LaTeX 29 | # plt.rcParams.update({ 30 | # "text.usetex": True, 31 | # "font.family": "serif", 32 | # "font.serif": ["Computer Modern Roman"] 33 | # }) 34 | 35 | # Set random seeds for reproducibility 36 | random.seed(0) 37 | np.random.seed(0) 38 | torch.manual_seed(0) 39 | 40 | model_path = './model.pth' # Path to trained model weights 41 | save_path = './random_predictions.pdf' # File to save image in 42 | classes = [[1], [6, 14], [7, 15], [8, 16], [9, 17, 35]] # List of elements in each class 43 | box_res = (0.125, 0.125, 0.1) # Real-space voxel size for position distribution 44 | zmin = -0.8 # Maximum depth used for thresholding 45 | peak_std = 0.25 # Standard deviation of atom position peaks in angstroms 46 | sequence_order = None # Order for graph construction 47 | device = 'cuda' # Device to run inference on 48 | molecules_dir = './Molecules/' # Base directory for molecule data 49 | class_colors = ['w', 'dimgray', 'b', 'r',
'yellowgreen'] # Colors for classes 50 | marker_size = 15 51 | z_min_marker = -1.0 52 | z_max_marker = 0.0 53 | 54 | # Choose random molecules from the test set 55 | num_mols = 10 56 | molecules = [os.path.join(molecules_dir, f'test/{n}.xyz') 57 | for n in np.random.choice(range(35554), size=num_mols, replace=False)] 58 | print(molecules) 59 | 60 | def predict(model, molecules, box_borders): 61 | 62 | scan_dim = ( 63 | int((box_borders[1][0] - box_borders[0][0]) / box_res[0]), 64 | int((box_borders[1][1] - box_borders[0][1]) / box_res[1]), 65 | 20 66 | ) 67 | print(scan_dim) 68 | afmulator_args = { 69 | 'pixPerAngstrome' : 20, 70 | 'lvec' : np.array([ 71 | [ 0.0, 0.0, 0.0], 72 | [box_borders[1][0]+2.0, 0.0, 0.0], 73 | [ 0.0, box_borders[1][1]+2.0, 0.0], 74 | [ 0.0, 0.0, 6.0] 75 | ]), 76 | 'scan_dim' : scan_dim, 77 | 'scan_window' : (box_borders[0][:2] + (7.0,), box_borders[1][:2] + (9.0,)), 78 | 'amplitude' : 1.0, 79 | 'df_steps' : 10, 80 | 'initFF' : True 81 | } 82 | 83 | generator_kwargs = { 84 | 'batch_size' : 30, 85 | 'distAbove' : 5.3, 86 | 'iZPPs' : [8], 87 | 'Qs' : [[ -10, 20, -10, 0 ]], 88 | 'QZs' : [[ 0.1, 0, -0.1, 0 ]] 89 | } 90 | 91 | # Define AFMulator 92 | afmulator = AFMulator(**afmulator_args) 93 | afmulator.npbc = (0,0,0) 94 | 95 | # Define generator 96 | trainer = InverseAFMtrainer(afmulator, [], molecules, **generator_kwargs) 97 | 98 | # Generate batch 99 | batch = next(iter(trainer)) 100 | X, ref_graphs, ref_dist, box_borders = apply_preprocessing(batch, box_borders) 101 | 102 | with torch.no_grad(): 103 | X_gpu = torch.from_numpy(X).unsqueeze(1).to(device) 104 | pred_graphs, pred_dist, pred_sequence, completed = model.predict_sequence(X_gpu, box_borders, order=sequence_order) 105 | pred_dist = pred_dist.cpu().numpy() 106 | 107 | return X, pred_graphs, pred_dist, ref_graphs, ref_dist, pred_sequence, box_borders 108 | 109 | def apply_preprocessing(batch, box_borders): 110 | 111 | X, Y, atoms = batch 112 | 113 | pp.add_norm(X) 114 | pp.add_noise(X, 
c=0.1, randomize_amplitude=False) 115 | X = X[0] 116 | 117 | atoms = pp.top_atom_to_zero(atoms) 118 | bonds = utils.find_bonds(atoms) 119 | mols = [utils.MoleculeGraph(a, b, classes=classes) for a, b in zip(atoms, bonds)] 120 | mols = utils.threshold_atoms_bonds(mols, zmin) 121 | 122 | ref_dist = utils.make_position_distribution([m.atoms for m in mols], box_borders, 123 | box_res=box_res, std=peak_std) 124 | 125 | return X, mols, ref_dist, box_borders 126 | 127 | def get_marker_size(z, max_size=marker_size): 128 | return max_size * (z - z_min_marker) / (z_max_marker - z_min_marker) 129 | 130 | def plot_xy(ax, mol, box_borders): 131 | 132 | if len(mol) > 0: 133 | 134 | mol_pos = mol.array(xyz=True) 135 | 136 | s = get_marker_size(mol_pos[:,2]) 137 | if (s < 0).any(): 138 | raise ValueError('Encountered atom z position(s) below box borders.') 139 | 140 | c = [class_colors[atom.class_index] for atom in mol.atoms] 141 | 142 | ax.scatter(mol_pos[:,0], mol_pos[:,1], c=c, s=s, edgecolors='k', zorder=2, linewidth=0.5) 143 | for b in mol.bonds: 144 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 145 | ax.plot(pos[:,0], pos[:,1], 'k', linewidth=1, zorder=1) 146 | 147 | ax.set_xlim(box_borders[0][0], box_borders[1][0]) 148 | ax.set_ylim(box_borders[0][1], box_borders[1][1]) 149 | ax.set_aspect('equal', 'box') 150 | 151 | # Initialize OpenCL environment on GPU 152 | env = oclu.OCLEnvironment( i_platform = 0 ) 153 | FFcl.init(env) 154 | oclr.init(env) 155 | 156 | # Download molecules if not already there 157 | utils.download_molecules(molecules_dir, verbose=1) 158 | 159 | # Load model 160 | model = load_pretrained_model('random', device=device) 161 | 162 | # Make predictions 163 | X, pred_graphs, pred_dist, ref_graphs, ref_dist, pred_sequence, box_borders = predict( 164 | model, molecules=molecules, box_borders=((2,2,-1.5),(18,18,0.5))) 165 | 166 | # Initialize figure 167 | fig = plt.figure(figsize=(16 / 2.54, 16.5 / 2.54)) 168 | fig_grid = fig.add_gridspec(num_mols // 2, 2, 
wspace=0.07, hspace=0.22, top=0.97, bottom=0.03, left=0.01, right=0.99) 169 | 170 | xticks = np.linspace(box_borders[0][0], box_borders[1][0], 5).astype(int) 171 | yticks = np.linspace(box_borders[0][1], box_borders[1][1], 5).astype(int) 172 | 173 | for i, (x, p, r) in enumerate(zip(X, pred_graphs, ref_graphs)): 174 | 175 | ix, iy = i % (num_mols // 2), 2 * i // num_mols 176 | sample_grid = fig_grid[ix, iy].subgridspec(1, 2, width_ratios=(1, 4), wspace=0.2) 177 | afm_axes = sample_grid[0, 0].subgridspec(2, 1, hspace=0.01).subplots() 178 | pred_ax, ref_ax = sample_grid[0, 1].subgridspec(1, 2, wspace=0.1).subplots() 179 | 180 | # Plot AFM 181 | for s, ax in zip([0, -1], afm_axes): 182 | ax.imshow(x[:, :, s].T, origin='lower', cmap='afmhot') 183 | ax.axis('off') 184 | 185 | # Plot graphs 186 | for ax, d in zip([pred_ax, ref_ax], [p, r]): 187 | plot_xy(ax, d, box_borders) 188 | ax.tick_params('both', length=1, width=0.5, pad=1, labelsize=6) 189 | ax.set_xticks(xticks) 190 | ax.set_xlabel('$x$(Å)', fontsize=6, labelpad=0) 191 | ax.spines[:].set_linewidth(0.3) 192 | pred_ax.set_yticks(yticks) 193 | pred_ax.set_ylabel('$y$(Å)', fontsize=6, labelpad=0) 194 | ref_ax.tick_params('y', left=False, labelleft=False) 195 | 196 | if (i % (num_mols // 2)) == 0: 197 | afm_axes[0].set_title('AFM sim.', fontsize=9, pad=4) 198 | pred_ax.set_title('Prediction', fontsize=9, pad=4) 199 | ref_ax.set_title('Reference', fontsize=9, pad=4) 200 | 201 | plt.savefig(save_path, dpi=300) 202 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | 2 | from train import * 3 | 4 | # Set random seeds for reproducibility 5 | np.random.seed(0) 6 | random.seed(0) 7 | torch.manual_seed(0) 8 | 9 | # Inference device 10 | device = 'cuda' 11 | 12 | # How many test set batches to make predictions on 13 | pred_batches = 2 14 | 15 | def batch_to_host(pred, batch): 16 | X, 
node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs = batch 17 | pred_nodes, pred_edges, pred_dist = pred 18 | X = X.cpu() 19 | node_inputs = node_inputs.cpu() 20 | edges = edges.cpu() 21 | node_rem = [n.cpu() for n in node_rem] 22 | edge_rem = [e.cpu() for e in edge_rem] 23 | terminate = terminate.cpu() 24 | ref_dist = ref_dist.cpu() 25 | ref_atoms = ref_atoms.cpu() 26 | pred_nodes = pred_nodes.cpu() 27 | pred_dist = pred_dist.cpu() 28 | pred_edges = [e.cpu() for e in pred_edges] 29 | X, mols, pred, ref, xyz, ref_dist, pred_dist, ref_graphs, ref_atoms = dl.uncollate( 30 | (pred_nodes, pred_edges, pred_dist), 31 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs) 32 | ) 33 | return X, mols, pred, ref, xyz, ref_dist, pred_dist, ref_graphs, ref_atoms 34 | 35 | if __name__ == "__main__": 36 | 37 | start_time = time.time() 38 | 39 | # Check checkpoint directory 40 | checkpoint_dir = os.path.join(model_dir, 'CheckPoints/') 41 | if not os.path.exists(checkpoint_dir): 42 | raise RuntimeError('No checkpoint directory. 
Cannot load model for testing.') 43 | 44 | # Define model, optimizer, and loss 45 | model, criterion, optimizer = make_model(device) 46 | 47 | print(f'CUDA is available = {torch.cuda.is_available()}') 48 | print(f'Model total parameters: {utils.count_parameters(model)}') 49 | 50 | # Create dataset and dataloader 51 | test_set, test_loader, _ = make_dataloader('test', world_size=1, rank=device) 52 | 53 | # Load checkpoint 54 | for last_epoch in reversed(range(1, epochs+1)): 55 | if os.path.exists( state_file := os.path.join(checkpoint_dir, f'model_{last_epoch}.pth') ): 56 | state = torch.load(state_file, map_location={'cuda:0': device}) 57 | model.load_state_dict(state['model_params']) 58 | optimizer.load_state_dict(state['optim_params']) 59 | break 60 | 61 | print(f'\n ========= Testing with model from epoch {last_epoch}') 62 | 63 | stats = analysis.GraphPredStats(len(classes)) 64 | seq_stats = analysis.GraphSeqStats(len(classes)) 65 | eval_losses = [] 66 | eval_start = time.time() 67 | if timings: t0 = eval_start 68 | 69 | model.eval() 70 | with torch.no_grad(): 71 | 72 | for ib, batch in enumerate(test_loader): 73 | 74 | # Transfer batch to device 75 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, 76 | Ns, xyz, ref_graphs, box_borders) = batch_to_device(batch, device) 77 | 78 | if timings: 79 | torch.cuda.synchronize() 80 | t1 = time.time() 81 | 82 | # Forward 83 | pred = model(X, node_inputs, edges, Ns, ref_atoms, box_borders, return_attention=True) 84 | pred_nodes, pred_edges, pred_dist, unet_attention_maps, encoding_attention_maps = pred 85 | losses, min_inds = criterion( 86 | (pred_nodes, pred_edges, pred_dist), 87 | (node_rem, edge_rem, terminate, ref_dist), 88 | separate_loss_factors=True 89 | ) 90 | 91 | if timings: 92 | torch.cuda.synchronize() 93 | t2 = time.time() 94 | 95 | # Predict full molecule graphs in sequence 96 | pred_graphs, pred_dist, pred_sequence, completed = model.predict_sequence(X, box_borders, 97 | 
order=sequence_order) 98 | 99 | if timings: 100 | torch.cuda.synchronize() 101 | t3 = time.time() 102 | 103 | # Back to host 104 | batch = (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs) 105 | X, mols, pred, ref, xyz, ref_dist, pred_dist, ref_graphs, ref_atoms = batch_to_host( 106 | (pred_nodes, pred_edges, pred_dist), batch) 107 | 108 | # Gather statistical information 109 | ref = [r[i] + (t, d) for i, (r, t, d) in zip(min_inds, ref)] 110 | stats.add_batch_grid(pred, ref) 111 | seq_stats.add_batch(pred_graphs, ref_graphs) 112 | eval_losses.append([loss.item() for loss in losses]) 113 | 114 | if (ib+1) % print_interval == 0: 115 | eta = (time.time() - eval_start) / (ib + 1) * (len(test_loader) - (ib + 1)) 116 | print(f'Test Batch {ib+1}/{len(test_loader)} - ETA: {eta:.2f}s') 117 | 118 | if timings: 119 | torch.cuda.synchronize() 120 | t4 = time.time() 121 | print('(Test) t0/Load Batch/Forward/Seq prediction/Stats: ' 122 | f'{t0:6f}/{t1-t0:6f}/{t2-t1:6f}/{t3-t2:6f}/{t4-t3:6f}') 123 | t0 = t4 124 | 125 | # Save statistical information 126 | stats_dir1 = os.path.join(model_dir,'stats_single') 127 | stats_dir2 = os.path.join(model_dir,'stats_sequence') 128 | stats.plot(stats_dir1) 129 | stats.report(stats_dir1) 130 | seq_stats.plot(stats_dir2) 131 | seq_stats.report(stats_dir2) 132 | 133 | # Average losses and print 134 | eval_loss = np.mean(eval_losses, axis=0) 135 | print(f'Test set loss: {loss_str(eval_loss)}') 136 | 137 | # Save test set loss to file 138 | with open(os.path.join(model_dir, 'test_loss.txt'),'w') as f: 139 | f.write(';'.join([str(l) for l in eval_loss])) 140 | 141 | # Make predictions 142 | print(f'\n ========= Predict on {pred_batches} batches from the test set') 143 | counter = 0 144 | pred_dir = os.path.join(model_dir, 'predictions/') 145 | pred_dir2 = os.path.join(model_dir, 'predictions_sequence/') 146 | 147 | with torch.no_grad(): 148 | 149 | for ib, batch in enumerate(test_loader): 150 | 151 | 
if ib >= pred_batches: break 152 | 153 | # Transfer batch to device 154 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, 155 | Ns, xyz, ref_graphs, box_borders) = batch_to_device(batch, device) 156 | 157 | # Forward 158 | pred = model(X, node_inputs, edges, Ns, ref_atoms, box_borders, return_attention=True) 159 | pred_nodes, pred_edges, pred_dist, unet_attention_maps, encoding_attention_maps = pred 160 | _, min_inds = criterion( 161 | (pred_nodes, pred_edges, pred_dist), 162 | (node_rem, edge_rem, terminate, ref_dist), 163 | separate_loss_factors=True 164 | ) 165 | 166 | # Predict full molecule graphs in sequence 167 | pred_graphs, pred_dist, pred_sequence, completed = model.predict_sequence(X, box_borders, 168 | order=sequence_order) 169 | 170 | # Back to host 171 | batch = (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs) 172 | X, mols, pred, ref, xyz, ref_dist, pred_dist, ref_graphs, ref_atoms = batch_to_host( 173 | (pred_nodes, pred_edges, pred_dist), batch) 174 | unet_attention_maps = [a.cpu() for a in unet_attention_maps] 175 | encoding_attention_maps = [a.cpu() for a in encoding_attention_maps] 176 | 177 | # Save xyzs 178 | utils.batch_write_xyzs(xyz, outdir=pred_dir, start_ind=counter) 179 | utils.batch_write_xyzs(xyz, outdir=pred_dir2, start_ind=counter) 180 | utils.save_graphs_to_xyzs(pred_graphs, classes, 181 | outfile_format=os.path.join(pred_dir2, '{ind}_graph_pred.xyz'), start_ind=counter) 182 | utils.save_graphs_to_xyzs(ref_graphs, classes, 183 | outfile_format=os.path.join(pred_dir2, '{ind}_graph_ref.xyz'), start_ind=counter) 184 | 185 | # Visualize predictions 186 | ref = [r[i] + (t, d) for i, (r, t, d) in zip(min_inds, ref)] 187 | vis.visualize_graph_grid(mols, pred, ref, box_borders=box_borders, 188 | outdir=pred_dir, start_ind=counter) 189 | vis.plot_graph_sequence_grid(pred_graphs, ref_graphs, pred_sequence, box_borders=box_borders, 190 | outdir=pred_dir2, 
start_ind=counter, classes=classes, class_colors=class_colors) 191 | vis.plot_distribution_grid(pred_dist, ref_dist, box_borders=box_borders, outdir=pred_dir2, 192 | start_ind=counter) 193 | vis.make_input_plots([X], outdir=pred_dir, start_ind=counter) 194 | vis.make_input_plots([X], outdir=pred_dir2, start_ind=counter) 195 | 196 | counter += len(mols) 197 | 198 | print(f'Done. Total time: {time.time() - start_time:.0f}s') 199 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import copy 5 | import time 6 | import random 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, optim 11 | 12 | sys.path.append('../src') # Add source code directory to Python PATH 13 | import utils 14 | import analysis 15 | import preprocessing as pp 16 | import data_loading as dl 17 | import visualization as vis 18 | from models import AttentionEncoderUNet, GNN, GridGraphImgNet, GridMultiAtomLoss 19 | 20 | hdf5_path = './graph_dataset.hdf5' # Path to HDF5 database where data is read from 21 | model_dir = './model_random_order' # Directory where all model files are saved 22 | epochs = 50 # Number of epochs to train 23 | batch_size = 32 # Number of samples per batch 24 | classes = [[1], [6, 14], [7, 15], [8, 16], [9, 17, 35]] # List of elements in each class 25 | box_borders = ((2,2,-1.5),(18,18,0.5)) # Real-space extent of plotting region 26 | box_res = (0.125, 0.125, 0.1) # Real-space voxel size for position distribution 27 | zmin = -0.8 # Maximum depth used for thresholding 28 | peak_std = 0.25 # Standard deviation of atom position peaks in angstroms 29 | sequence_order = None # Order for graph construction 30 | num_workers = 8 # Number of parallel workers 31 | timings = False # Print timings for each batch 32 | device = 'cuda' # Device to use 33 | print_interval = 10 # Losses will be printed every 
print_interval batches 34 | class_colors = ['w', 'dimgray', 'b', 'r', 'yellowgreen'] # Colors for classes 35 | loss_weights = { # Weights for loss components 36 | 'pos_factor' : 100.0, 37 | 'class_factor' : 1.0, 38 | 'edge_factor' : 1.0 39 | } 40 | 41 | def make_model(device): 42 | gnn = GNN( 43 | hidden_size = 64, 44 | iters = 3, 45 | n_node_features = 20, 46 | n_edge_features = 20 47 | ) 48 | cnn = AttentionEncoderUNet( 49 | conv3d_in_channels = 1, 50 | conv3d_block_channels = [4, 8, 16, 32], 51 | conv3d_block_depth = 2, 52 | encoding_block_channels = [4, 8, 16, 32], 53 | encoding_block_depth = 2, 54 | upscale_block_channels = [32, 16, 8], 55 | upscale_block_depth = 2, 56 | upscale_block_channels2 = [32, 16, 8], 57 | upscale_block_depth2 = 2, 58 | attention_channels = [32, 32, 32], 59 | query_size = 64, 60 | res_connections = True, 61 | hidden_dense_units = [], 62 | out_units = 128, 63 | activation = 'relu', 64 | padding_mode = 'zeros', 65 | pool_type = 'avg', 66 | pool_z_strides = [2, 1, 2], 67 | decoder_z_sizes = [4, 10, 20], 68 | attention_activation = 'softmax' 69 | ) 70 | model = GridGraphImgNet( 71 | cnn = cnn, 72 | gnn = gnn, 73 | n_classes = len(classes), 74 | expansion_hidden = 32, 75 | expanded_size = 128, 76 | query_hidden = 64, 77 | class_hidden = 32, 78 | edge_hidden = 32, 79 | peak_std = peak_std, 80 | match_method = 'msd_norm', 81 | match_threshold = 0.7, 82 | dist_threshold = 0.5, 83 | teacher_forcing_rate = 1.0, 84 | device = device 85 | ) 86 | criterion = GridMultiAtomLoss(**loss_weights) 87 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 88 | return model, criterion, optimizer 89 | 90 | def apply_preprocessing(batch): 91 | 92 | X, Y, atoms = batch 93 | 94 | pp.add_norm(X) 95 | pp.add_noise(X, c=0.1, randomize_amplitude=True, normal_amplitude=True) 96 | pp.rand_shift_xy_trend(X, shift_step_max=0.02, max_shift_total=0.04) 97 | pp.add_cutout(X, n_holes=5) 98 | X = X[0] 99 | 100 | atoms = pp.top_atom_to_zero(atoms) 101 | xyz = atoms.copy() 
102 | bonds = utils.find_bonds(atoms) 103 | mols = [utils.MoleculeGraph(a, b, classes=classes) for a, b in zip(atoms, bonds)] 104 | mols = utils.threshold_atoms_bonds(mols, zmin) 105 | ref_graphs = copy.deepcopy(mols) 106 | mols, removed = utils.remove_atoms(mols, order=sequence_order) 107 | 108 | ref_dist = utils.make_position_distribution([m.atoms for m in ref_graphs], box_borders, 109 | box_res=box_res, std=peak_std) 110 | 111 | if sequence_order: 112 | order_ind = 'xyz'.index(sequence_order) 113 | ref_atoms = [sorted(r, key=lambda x: x[0].xyz[order_ind])[-1][0] if len(r) > 0 else None 114 | for r in removed] 115 | else: 116 | ref_atoms = [random.choice(r)[0] if len(r) > 0 else None for r in removed] 117 | utils.randomize_atom_positions(ref_atoms, std=[0.2, 0.2, 0.05], cutoff=0.5) 118 | 119 | return X, mols, removed, xyz, ref_graphs, ref_dist, ref_atoms, box_borders 120 | 121 | def make_dataloader(mode, world_size=1, rank=0): 122 | dataset, dataloader, sampler = dl.make_hdf5_dataloader( 123 | hdf5_path, 124 | apply_preprocessing, 125 | mode=mode, 126 | collate_fn=dl.collate, 127 | batch_size=batch_size // world_size, 128 | shuffle=True, 129 | num_workers=num_workers, 130 | world_size=world_size, 131 | rank=rank 132 | ) 133 | return dataset, dataloader, sampler 134 | 135 | def batch_to_device(batch, device=device): 136 | X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs, box_borders = batch 137 | X = X.to(device) 138 | node_inputs = node_inputs.to(device) 139 | edges = edges.to(device) 140 | terminate = terminate.to(device) 141 | ref_dist = ref_dist.to(device) 142 | node_rem = [n.to(device) for n in node_rem] 143 | edge_rem = [e.to(device) for e in edge_rem] 144 | ref_atoms = ref_atoms.to(device) 145 | return X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs, box_borders 146 | 147 | def loss_str(losses): 148 | loss_msg = ( 149 | f'{losses[0]:.4f} (' 150 | f'MSE (Pos.): 
{losses[1]:.4f} x {loss_weights["pos_factor"]}, ' 151 | f'NLL (Class): {losses[2]:.4f} x {loss_weights["class_factor"]}, ' 152 | f'NLL (Edge): {losses[3]:.4f} x {loss_weights["edge_factor"]})' 153 | ) 154 | return loss_msg 155 | 156 | if __name__ == '__main__': 157 | 158 | start_time = time.time() 159 | 160 | model, criterion, optimizer = make_model(device) 161 | 162 | print(f'CUDA is available = {torch.cuda.is_available()}') 163 | print(f'Model total parameters: {utils.count_parameters(model)}') 164 | 165 | # Create model directory 166 | if not os.path.exists(model_dir): 167 | os.makedirs(model_dir) 168 | 169 | # Create datasets and dataloaders 170 | train_set, train_loader, _ = make_dataloader('train') 171 | val_set, val_loader, _ = make_dataloader('val') 172 | test_set, test_loader, _ = make_dataloader('test') 173 | 174 | # Create a folder for model checkpoints 175 | checkpoint_dir = os.path.join(model_dir, 'CheckPoints/') 176 | if not os.path.exists(checkpoint_dir): 177 | os.makedirs(checkpoint_dir) 178 | 179 | # Load checkpoint if available 180 | for init_epoch in reversed(range(1, epochs+1)): 181 | if os.path.exists( state_file := os.path.join(checkpoint_dir, f'model_{init_epoch}.pth') ): 182 | utils.load_checkpoint(model, optimizer, state_file) 183 | init_epoch += 1 184 | break 185 | 186 | if init_epoch <= epochs: 187 | print(f'\n ========= Starting training from epoch {init_epoch}') 188 | else: 189 | print('Model already trained') 190 | 191 | # Setup logging 192 | log_path = os.path.join(model_dir, 'loss_log.csv') 193 | plot_path = os.path.join(model_dir, 'loss_history.png') 194 | loss_labels = ['Total', 'MSE (Pos.)', 'NLL (Class)', 'NLL (Edge)'] 195 | logger = utils.LossLogPlot(log_path, plot_path, loss_labels, 196 | ['', loss_weights['pos_factor'], loss_weights['class_factor'], loss_weights['edge_factor']] 197 | ) 198 | 199 | for epoch in range(init_epoch, epochs+1): 200 | 201 | print(f'\n === Epoch {epoch}') 202 | 203 | # Train 204 | train_losses = [] 
205 | epoch_start = time.time() 206 | if timings: t0 = epoch_start 207 | 208 | model.train() 209 | for ib, batch in enumerate(train_loader): 210 | 211 | # Transfer batch to device 212 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, 213 | ref_graphs, box_borders) = batch_to_device(batch) 214 | 215 | if timings: 216 | if device == 'cuda': torch.cuda.synchronize() 217 | t1 = time.time() 218 | 219 | # Forward 220 | pred = model(X, node_inputs, edges, Ns, ref_atoms, box_borders) 221 | losses, _ = criterion(pred, (node_rem, edge_rem, terminate, ref_dist), separate_loss_factors=True) 222 | loss = losses[0] 223 | 224 | if timings: 225 | if device == 'cuda': torch.cuda.synchronize() 226 | t2 = time.time() 227 | 228 | # Backward 229 | optimizer.zero_grad() 230 | loss.backward() 231 | optimizer.step() 232 | 233 | train_losses.append([loss.item() for loss in losses]) 234 | 235 | if ib == len(train_loader)-1 or (ib+1) % print_interval == 0: 236 | eta = (time.time() - epoch_start) / (ib + 1) * ((len(train_loader)+len(val_loader)) - (ib + 1)) 237 | mean_loss = np.mean(train_losses[-print_interval:], axis=0) 238 | print(f'Epoch {epoch}, Train Batch {ib+1}/{len(train_loader)} - Loss: {loss_str(mean_loss)} - ETA: {eta:.2f}s') 239 | 240 | if timings: 241 | t3 = time.time() 242 | print(f'(Train) t0/Load Batch/Forward/Backward: {t0}/{t1-t0}/{t2-t1}/{t3-t2}') 243 | t0 = t3 244 | 245 | # Validate 246 | val_losses = [] 247 | val_start = time.time() 248 | if timings: t0 = val_start 249 | 250 | model.eval() 251 | with torch.no_grad(): 252 | 253 | for ib, batch in enumerate(val_loader): 254 | 255 | # Transfer batch to device 256 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, 257 | ref_graphs, box_borders) = batch_to_device(batch) 258 | 259 | if timings: 260 | if device == 'cuda': torch.cuda.synchronize() 261 | t1 = time.time() 262 | 263 | # Forward 264 | pred = model(X, node_inputs, edges, Ns, ref_atoms, 
box_borders) 265 | losses, _ = criterion(pred, (node_rem, edge_rem, terminate, ref_dist), separate_loss_factors=True) 266 | 267 | val_losses.append([loss.item() for loss in losses]) 268 | 269 | if (ib+1) % print_interval == 0: 270 | eta = (time.time() - epoch_start) / (len(train_loader) + ib + 1) * (len(val_loader) - (ib + 1)) 271 | print(f'Epoch {epoch}, Val Batch {ib+1}/{len(val_loader)} - ETA: {eta:.2f}s') 272 | 273 | if timings: 274 | t2 = time.time() 275 | print(f'(Val) t0/Load Batch/Forward: {t0}/{t1-t0}/{t2-t1}') 276 | t0 = t2 277 | 278 | train_loss = np.mean(train_losses, axis=0) 279 | val_loss = np.mean(val_losses, axis=0) 280 | print(f'End of epoch {epoch}') 281 | print(f'Train loss: {loss_str(train_loss)}') 282 | print(f'Val loss: {loss_str(val_loss)}') 283 | 284 | epoch_end = time.time() 285 | train_step = (val_start - epoch_start) / len(train_loader) 286 | val_step = (epoch_end - val_start) / len(val_loader) 287 | print(f'Epoch time: {epoch_end - epoch_start:.2f}s - Train step: {train_step:.5f}s - Val step: {val_step:.5f}s') 288 | 289 | # Add losses to log 290 | logger.add_losses(train_loss, val_loss) 291 | logger.plot_history() 292 | 293 | # Save checkpoint 294 | utils.save_checkpoint(model, optimizer, epoch, checkpoint_dir) 295 | 296 | # Save final model 297 | torch.save(model, save_path := os.path.join(model_dir, 'model.pth')) 298 | print(f'\nModel saved to {save_path}') 299 | 300 | print(f'Done. 
Total time: {time.time() - start_time:.0f}s') -------------------------------------------------------------------------------- /scripts/train_distributed.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import time 5 | import copy 6 | import random 7 | import numpy as np 8 | 9 | import torch 10 | from torch import nn, optim 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | from torch.nn.parallel import DistributedDataParallel 14 | from torch.utils.data.distributed import DistributedSampler 15 | 16 | sys.path.append('../src') # Add source code directory to Python PATH 17 | import utils 18 | import analysis 19 | import preprocessing as pp 20 | import data_loading as dl 21 | import visualization as vis 22 | from models import AttentionEncoderUNet, GNN, GridGraphImgNet, GridMultiAtomLoss 23 | 24 | hdf5_path = './graph_dataset.hdf5' # Path to HDF5 database where data is read from 25 | model_dir = './model_random_order' # Directory where all model files are saved 26 | epochs = 50 # Number of epochs to train 27 | batch_size = 32 # Number of samples per batch 28 | classes = [[1], [6, 14], [7, 15], [8, 16], [9, 17, 35]] # List of elements in each class 29 | box_borders = ((2,2,-1.5),(18,18,0.5)) # Real-space extent of plotting region 30 | box_res = (0.125, 0.125, 0.1) # Real-space voxel size for position distribution 31 | zmin = -0.8 # Maximum depth used for thresholding 32 | peak_std = 0.25 # Standard deviation of atom position peaks in angstroms 33 | sequence_order = None # Order for graph construction 34 | num_workers = 4 # Number of parallel workers for each GPU 35 | timings = False # Print timings for each batch 36 | device = 'cuda' # Device to use 37 | comm_backend = 'nccl' # Backend for interprocess communications 38 | print_interval = 10 # Losses will be printed every print_interval batches 39 | class_colors = ['w', 'dimgray', 'b', 'r', 'yellowgreen'] # Colors 
for classes 40 | loss_weights = { # Weights for loss components 41 | 'pos_factor' : 100.0, 42 | 'class_factor' : 1.0, 43 | 'edge_factor' : 1.0 44 | } 45 | 46 | def make_model(device): 47 | gnn = GNN( 48 | hidden_size = 64, 49 | iters = 3, 50 | n_node_features = 20, 51 | n_edge_features = 20 52 | ) 53 | cnn = AttentionEncoderUNet( 54 | conv3d_in_channels = 1, 55 | conv3d_block_channels = [4, 8, 16, 32], 56 | conv3d_block_depth = 2, 57 | encoding_block_channels = [4, 8, 16, 32], 58 | encoding_block_depth = 2, 59 | upscale_block_channels = [32, 16, 8], 60 | upscale_block_depth = 2, 61 | upscale_block_channels2 = [32, 16, 8], 62 | upscale_block_depth2 = 2, 63 | attention_channels = [32, 32, 32], 64 | query_size = 64, 65 | res_connections = True, 66 | hidden_dense_units = [], 67 | out_units = 128, 68 | activation = 'relu', 69 | padding_mode = 'zeros', 70 | pool_type = 'avg', 71 | pool_z_strides = [2, 1, 2], 72 | decoder_z_sizes = [4, 10, 20], 73 | attention_activation = 'softmax' 74 | ) 75 | model = GridGraphImgNet( 76 | cnn = cnn, 77 | gnn = gnn, 78 | n_classes = len(classes), 79 | expansion_hidden = 32, 80 | expanded_size = 128, 81 | query_hidden = 64, 82 | class_hidden = 32, 83 | edge_hidden = 32, 84 | peak_std = peak_std, 85 | match_method = 'msd_norm', 86 | match_threshold = 0.7, 87 | dist_threshold = 0.5, 88 | teacher_forcing_rate = 1.0, 89 | device = device 90 | ) 91 | criterion = GridMultiAtomLoss(**loss_weights) 92 | optimizer = optim.Adam(model.parameters(), lr=1e-3) 93 | return model, criterion, optimizer 94 | 95 | def apply_preprocessing(batch): 96 | 97 | X, Y, atoms = batch 98 | 99 | pp.add_norm(X) 100 | pp.add_noise(X, c=0.1, randomize_amplitude=True, normal_amplitude=True) 101 | pp.rand_shift_xy_trend(X, shift_step_max=0.02, max_shift_total=0.04) 102 | pp.add_cutout(X, n_holes=5) 103 | X = X[0] 104 | 105 | atoms = pp.top_atom_to_zero(atoms) 106 | xyz = atoms.copy() 107 | bonds = utils.find_bonds(atoms) 108 | mols = [utils.MoleculeGraph(a, b, 
classes=classes) for a, b in zip(atoms, bonds)] 109 | mols = utils.threshold_atoms_bonds(mols, zmin) 110 | ref_graphs = copy.deepcopy(mols) 111 | mols, removed = utils.remove_atoms(mols, order=sequence_order) 112 | 113 | ref_dist = utils.make_position_distribution([m.atoms for m in ref_graphs], box_borders, 114 | box_res=box_res, std=peak_std) 115 | 116 | if sequence_order: 117 | order_ind = 'xyz'.index(sequence_order) 118 | ref_atoms = [sorted(r, key=lambda x: x[0].xyz[order_ind])[-1][0] if len(r) > 0 else None 119 | for r in removed] 120 | else: 121 | ref_atoms = [random.choice(r)[0] if len(r) > 0 else None for r in removed] 122 | utils.randomize_atom_positions(ref_atoms, std=[0.2, 0.2, 0.05], cutoff=0.5) 123 | 124 | return X, mols, removed, xyz, ref_graphs, ref_dist, ref_atoms, box_borders 125 | 126 | def make_dataloader(mode, world_size, rank): 127 | dataset, dataloader, sampler = dl.make_hdf5_dataloader( 128 | hdf5_path, 129 | apply_preprocessing, 130 | mode=mode, 131 | collate_fn=dl.collate, 132 | batch_size=batch_size // world_size, 133 | shuffle=False, 134 | num_workers=num_workers, 135 | world_size=world_size, 136 | rank=rank 137 | ) 138 | return dataset, dataloader, sampler 139 | 140 | def batch_to_device(batch, device): 141 | X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs, box_borders = batch 142 | X = X.to(device) 143 | node_inputs = node_inputs.to(device) 144 | edges = edges.to(device) 145 | terminate = terminate.to(device) 146 | ref_dist = ref_dist.to(device) 147 | node_rem = [n.to(device) for n in node_rem] 148 | edge_rem = [e.to(device) for e in edge_rem] 149 | ref_atoms = ref_atoms.to(device) 150 | return X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs, box_borders 151 | 152 | def loss_str(losses): 153 | loss_msg = ( 154 | f'{losses[0]:.4f} (' 155 | f'MSE (Pos.): {losses[1]:.4f} x {loss_weights["pos_factor"]}, ' 156 | f'NLL (Class): {losses[2]:.4f} x 
{loss_weights["class_factor"]}, ' 157 | f'NLL (Edge): {losses[3]:.4f} x {loss_weights["edge_factor"]})' 158 | ) 159 | return loss_msg 160 | 161 | def recv_losses(losses, world_size, rank=0): 162 | losses_ = torch.zeros(world_size, len(losses)).float().to(rank) 163 | losses_[0] = torch.tensor(losses).float().to(rank) 164 | for i in range(1, world_size): 165 | dist.recv(losses_[i], src=i) 166 | losses = losses_.mean(dim=0) 167 | return losses 168 | 169 | def run(rank, world_size, q): 170 | 171 | # Initialize the distributed environment. 172 | os.environ['MASTER_ADDR'] = 'localhost' 173 | os.environ['MASTER_PORT'] = '12355' 174 | dist.init_process_group(comm_backend, rank=rank, world_size=world_size) 175 | 176 | start_time = time.time() 177 | 178 | # Create directories 179 | checkpoint_dir = os.path.join(model_dir, 'CheckPoints/') 180 | if rank == 0: 181 | if not os.path.exists(model_dir): 182 | os.makedirs(model_dir) 183 | if not os.path.exists(checkpoint_dir): 184 | os.makedirs(checkpoint_dir) 185 | 186 | # Define model, optimizer, and loss 187 | model, criterion, optimizer = make_model(rank) 188 | 189 | if rank == 0: 190 | print(f'CUDA is available = {torch.cuda.is_available()}') 191 | print(f'Model total parameters: {utils.count_parameters(model)}') 192 | 193 | # Create datasets and dataloaders 194 | train_set, train_loader, sampler = make_dataloader('train', world_size, rank) 195 | val_set, val_loader, _ = make_dataloader('val', world_size, rank) 196 | 197 | # Load checkpoint if available 198 | dist.barrier() 199 | for init_epoch in reversed(range(1, epochs+1)): 200 | if os.path.exists( state_file := os.path.join(checkpoint_dir, f'model_{init_epoch}.pth') ): 201 | state = torch.load(state_file, map_location={'cuda:0': f'cuda:{rank}'}) 202 | model.load_state_dict(state['model_params']) 203 | optimizer.load_state_dict(state['optim_params']) 204 | init_epoch += 1 205 | break 206 | 207 | # Wrap model in DistributedDataParallel. 
208 | model = DistributedDataParallel(model, device_ids=[rank], find_unused_parameters=True) 209 | 210 | # Setup logging 211 | if rank == 0: 212 | log_path = os.path.join(model_dir, 'loss_log.csv') 213 | plot_path = os.path.join(model_dir, 'loss_history.png') 214 | loss_labels = ['Total', 'MSE (Pos.)', 'NLL (Class)', 'NLL (Edge)'] 215 | logger = utils.LossLogPlot(log_path, plot_path, loss_labels, 216 | ['', loss_weights['pos_factor'], loss_weights['class_factor'], loss_weights['edge_factor']]) 217 | 218 | if rank == 0: 219 | if init_epoch <= epochs: 220 | print(f'\n ========= Starting training from epoch {init_epoch}') 221 | else: 222 | print('Model already trained') 223 | 224 | for epoch in range(init_epoch, epochs+1): 225 | 226 | if rank == 0: print(f'\n === Epoch {epoch}') 227 | 228 | if sampler: 229 | sampler.set_epoch(epoch) 230 | 231 | # Train 232 | if rank == 0: 233 | train_losses = [] 234 | epoch_start = time.time() 235 | if timings: t0 = epoch_start 236 | 237 | model.train() 238 | for ib, batch in enumerate(train_loader): 239 | 240 | # Transfer batch to device 241 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, 242 | Ns, xyz, ref_graphs, box_borders) = batch_to_device(batch, rank) 243 | 244 | if timings and rank == 0: 245 | torch.cuda.synchronize() 246 | t1 = time.time() 247 | 248 | # Forward 249 | pred = model(X, node_inputs, edges, Ns, ref_atoms, box_borders) 250 | losses, _ = criterion(pred, (node_rem, edge_rem, terminate, ref_dist), separate_loss_factors=True) 251 | loss = losses[0] 252 | 253 | if timings and rank == 0: 254 | torch.cuda.synchronize() 255 | t2 = time.time() 256 | 257 | # Backward 258 | optimizer.zero_grad() 259 | loss.backward() 260 | optimizer.step() 261 | 262 | if rank == 0: 263 | 264 | losses = recv_losses(losses, world_size) # Get losses from other ranks 265 | train_losses.append([loss.item() for loss in losses]) 266 | 267 | if ib == len(train_loader)-1 or (ib+1) % print_interval == 0: 268 | eta = 
(time.time() - epoch_start) / (ib + 1) * ((len(train_loader)+len(val_loader)) - (ib + 1)) 269 | mean_loss = np.mean(train_losses[-print_interval:], axis=0) 270 | loss_msg = f'Loss: {loss_str(mean_loss)}' 271 | print(f'Epoch {epoch}, Train Batch {ib+1}/{len(train_loader)} - {loss_msg} - ETA: {eta:.2f}s') 272 | 273 | if timings: 274 | torch.cuda.synchronize() 275 | t3 = time.time() 276 | print(f'(Train) t0/Load Batch/Forward/Backward: {t0}/{t1-t0}/{t2-t1}/{t3-t2}') 277 | t0 = t3 278 | else: 279 | 280 | # Send loss to rank 0 281 | dist.send(torch.tensor(losses).float().to(rank), dst=0) 282 | 283 | # Validate 284 | if rank == 0: 285 | val_losses = [] 286 | val_start = time.time() 287 | if timings: t0 = val_start 288 | 289 | model.eval() 290 | with torch.no_grad(): 291 | 292 | for ib, batch in enumerate(val_loader): 293 | 294 | # Transfer batch to device 295 | (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, 296 | Ns, xyz, ref_graphs, box_borders) = batch_to_device(batch, rank) 297 | 298 | if timings: 299 | torch.cuda.synchronize() 300 | t1 = time.time() 301 | 302 | # Forward 303 | pred = model(X, node_inputs, edges, Ns, ref_atoms, box_borders) 304 | losses, _ = criterion(pred, (node_rem, edge_rem, terminate, ref_dist), separate_loss_factors=True) 305 | 306 | if rank == 0: 307 | 308 | losses = recv_losses(losses, world_size) # Get losses from other ranks 309 | val_losses.append([loss.item() for loss in losses]) 310 | 311 | if (ib+1) % print_interval == 0: 312 | eta = (time.time() - epoch_start) / (len(train_loader) + ib + 1) * (len(val_loader) - (ib + 1)) 313 | print(f'Epoch {epoch}, Val Batch {ib+1}/{len(val_loader)} - ETA: {eta:.2f}s') 314 | 315 | if timings: 316 | torch.cuda.synchronize() 317 | t2 = time.time() 318 | print(f'(Val) t0/Load Batch/Forward: {t0}/{t1-t0}/{t2-t1}') 319 | t0 = t2 320 | 321 | else: 322 | 323 | # Send loss to rank 0 324 | dist.send(torch.tensor(losses).float().to(rank), dst=0) 325 | 326 | if rank == 0: 327 | 328 | 
train_loss = np.mean(train_losses, axis=0) 329 | val_loss = np.mean(val_losses, axis=0) 330 | 331 | print(f'End of epoch {epoch}') 332 | print(f'Train loss: {loss_str(train_loss)}') 333 | print(f'Val loss: {loss_str(val_loss)}') 334 | 335 | epoch_end = time.time() 336 | train_step = (val_start - epoch_start) / len(train_loader) 337 | val_step = (epoch_end - val_start) / len(val_loader) 338 | print(f'Epoch time: {epoch_end - epoch_start:.2f}s - Train step: {train_step:.5f}s ' 339 | f'- Val step: {val_step:.5f}s') 340 | 341 | # Add losses to log 342 | logger.add_losses(train_loss, val_loss) 343 | logger.plot_history() 344 | 345 | # Save checkpoint 346 | utils.save_checkpoint(model, optimizer, epoch, checkpoint_dir) 347 | 348 | # Save final model 349 | if rank == 0: 350 | torch.save(model.module, save_path := os.path.join(model_dir, 'model.pth')) 351 | print(f'\nModel saved to {save_path}') 352 | 353 | print(f'Done at rank {rank}. Total time: {time.time() - start_time:.0f}s') 354 | 355 | dist.barrier() 356 | dist.destroy_process_group() 357 | 358 | if __name__ == "__main__": 359 | mp.set_start_method('spawn') 360 | q = mp.Queue(maxsize=10) 361 | world_size = torch.cuda.device_count() 362 | mp.spawn(run, args=(world_size, q), nprocs=world_size, join=True) -------------------------------------------------------------------------------- /src/c/bindings.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import multiprocessing as mp 5 | 6 | import ctypes 7 | from ctypes import c_int, c_float, POINTER 8 | 9 | N_POOL_PROC = mp.cpu_count() 10 | 11 | this_dir = os.path.dirname(os.path.abspath(__file__)) 12 | clib = ctypes.CDLL(os.path.join(this_dir, 'lib.so')) 13 | 14 | fp_p = POINTER(c_float) 15 | clib.match_template_mad.argtypes = [ 16 | c_int, c_int, c_int, fp_p, 17 | c_int, c_int, c_int, fp_p, 18 | fp_p 19 | ] 20 | clib.match_template_msd.argtypes = clib.match_template_mad.argtypes 21 | 
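The C library bound above via ctypes computes template-matching distance maps: at each position it records the mean absolute difference (MAD) or mean squared difference (MSD) between the array and a centered, odd-sized template, normalizing by the overlap size at the borders. A pure-NumPy reference of the 2D MAD variant is useful for sanity-checking the compiled extension; the function name here is illustrative and not part of the repository:

```python
import numpy as np

def match_template_mad_2d_ref(array, template):
    """Reference mean-absolute-difference template match (2D).

    The odd-sized template is centered on each pixel; at the borders only
    the overlapping region contributes, and the sum is normalized by the
    overlap size, mirroring the boundary handling in matching.cpp.
    """
    nax, nay = array.shape
    ntx, nty = template.shape
    if ntx % 2 == 0 or nty % 2 == 0:
        raise ValueError('Template dimensions must be odd.')
    ti_mid, tj_mid = (ntx - 1) // 2, (nty - 1) // 2
    dist = np.empty((nax, nay), dtype=np.float32)
    for i in range(nax):
        a_i0, t_i0 = max(0, i - ti_mid), max(0, ti_mid - i)
        i_len = min(i + ti_mid + 1, nax) - a_i0
        for j in range(nay):
            a_j0, t_j0 = max(0, j - tj_mid), max(0, tj_mid - j)
            j_len = min(j + tj_mid + 1, nay) - a_j0
            a = array[a_i0:a_i0 + i_len, a_j0:a_j0 + j_len]
            t = template[t_i0:t_i0 + i_len, t_j0:t_j0 + j_len]
            dist[i, j] = np.abs(a - t).mean()
    return dist
```

Comparing this against `match_template(..., method='mad')` (restricted to 2D via `match_template_2d`) on random inputs is a quick way to validate the compiled library.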
clib.match_template_mad_norm.argtypes = clib.match_template_mad.argtypes 22 | clib.match_template_msd_norm.argtypes = clib.match_template_mad.argtypes 23 | clib.match_template_mad_2d.argtypes = [ 24 | c_int, c_int, fp_p, 25 | c_int, c_int, fp_p, 26 | fp_p 27 | ] 28 | clib.match_template_msd_2d.argtypes = clib.match_template_mad_2d.argtypes 29 | clib.match_template_mad_norm_2d.argtypes = clib.match_template_mad_2d.argtypes 30 | clib.match_template_msd_norm_2d.argtypes = clib.match_template_mad_2d.argtypes 31 | 32 | def match_template(array, template, method='mad'): 33 | 34 | nax, nay, naz = array.shape 35 | ntx, nty, ntz = template.shape 36 | array_c = array.astype(np.float32).ctypes.data_as(fp_p) 37 | template_c = template.astype(np.float32).ctypes.data_as(fp_p) 38 | dist_array = np.empty((nax, nay, naz), dtype=np.float32) 39 | dist_array_c = dist_array.ctypes.data_as(fp_p) 40 | 41 | if method == 'mad': 42 | clib.match_template_mad( 43 | nax, nay, naz, array_c, 44 | ntx, nty, ntz, template_c, 45 | dist_array_c 46 | ) 47 | elif method == 'msd': 48 | clib.match_template_msd( 49 | nax, nay, naz, array_c, 50 | ntx, nty, ntz, template_c, 51 | dist_array_c 52 | ) 53 | elif method == 'mad_norm': 54 | clib.match_template_mad_norm( 55 | nax, nay, naz, array_c, 56 | ntx, nty, ntz, template_c, 57 | dist_array_c 58 | ) 59 | elif method == 'msd_norm': 60 | clib.match_template_msd_norm( 61 | nax, nay, naz, array_c, 62 | ntx, nty, ntz, template_c, 63 | dist_array_c 64 | ) 65 | else: 66 | raise ValueError(f'Unknown matching method `{method}`.') 67 | 68 | return dist_array 69 | 70 | def match_template_pool(arrays, template, method='mad'): 71 | inp = [(array, template, method) for array in arrays] 72 | with mp.Pool(processes=N_POOL_PROC) as pool: 73 | dist_arrays = pool.starmap(match_template, inp) 74 | dist_arrays = np.stack(dist_arrays, axis=0) 75 | return dist_arrays 76 | 77 | def match_template_2d(array, template, method='mad'): 78 | 79 | nax, nay = array.shape 80 | ntx, nty = 
template.shape 81 | array_c = array.astype(np.float32).ctypes.data_as(fp_p) 82 | template_c = template.astype(np.float32).ctypes.data_as(fp_p) 83 | dist_array = np.empty((nax, nay), dtype=np.float32) 84 | dist_array_c = dist_array.ctypes.data_as(fp_p) 85 | 86 | if method == 'mad': 87 | clib.match_template_mad_2d( 88 | nax, nay, array_c, 89 | ntx, nty, template_c, 90 | dist_array_c 91 | ) 92 | elif method == 'msd': 93 | clib.match_template_msd_2d( 94 | nax, nay, array_c, 95 | ntx, nty, template_c, 96 | dist_array_c 97 | ) 98 | elif method == 'mad_norm': 99 | clib.match_template_mad_norm_2d( 100 | nax, nay, array_c, 101 | ntx, nty, template_c, 102 | dist_array_c 103 | ) 104 | elif method == 'msd_norm': 105 | clib.match_template_msd_norm_2d( 106 | nax, nay, array_c, 107 | ntx, nty, template_c, 108 | dist_array_c 109 | ) 110 | else: 111 | raise ValueError(f'Unknown matching method `{method}`.') 112 | 113 | return dist_array 114 | -------------------------------------------------------------------------------- /src/c/matching.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include <cstdio> 3 | #include <cstdlib> 4 | #include <cmath> 5 | #include <vector> 6 | #include <algorithm> 7 | #include <stdexcept> 8 | 9 | extern "C" { 10 | 11 | void match_template_mad( 12 | int nax, int nay, int naz, float *array, 13 | int ntx, int nty, int ntz, float *temp, 14 | float *dist_array 15 | ) { 16 | 17 | if (ntx % 2 == 0 || nty % 2 == 0 || ntz % 2 == 0) { 18 | throw std::invalid_argument("Template dimensions must all be odd size."); 19 | } 20 | 21 | int ti_middle = (ntx - 1) / 2; 22 | int tj_middle = (nty - 1) / 2; 23 | int tk_middle = (ntz - 1) / 2; 24 | 25 | for (int i = 0; i < nax; i++) { 26 | int ai_start = std::max(0, i - ti_middle); 27 | int ti_start = std::max(0, ti_middle - i); 28 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 29 | for (int j = 0; j < nay; j++) { 30 | int aj_start = std::max(0, j - tj_middle); 31 | int tj_start = std::max(0, tj_middle - j); 32 | int jj_end =
std::min(j + tj_middle + 1, nay) - aj_start; 33 | for (int k = 0; k < naz; k++) { 34 | int ak_start = std::max(0, k - tk_middle); 35 | int tk_start = std::max(0, tk_middle - k); 36 | int kk_end = std::min(k + tk_middle + 1, naz) - ak_start; 37 | 38 | float sum_array = 0; 39 | for (int ii = 0; ii < ii_end; ii++) { 40 | int ti = ti_start + ii; 41 | int ai = ai_start + ii; 42 | for (int jj = 0; jj < jj_end; jj++) { 43 | int tj = tj_start + jj; 44 | int aj = aj_start + jj; 45 | for (int kk = 0; kk < kk_end; kk++) { 46 | int tk = tk_start + kk; 47 | int ak = ak_start + kk; 48 | int a_ind = ai*nay*naz + aj*naz + ak; 49 | int t_ind = ti*nty*ntz + tj*ntz + tk; 50 | float diff = std::abs(array[a_ind] - temp[t_ind]); 51 | sum_array += diff; 52 | } 53 | } 54 | } 55 | 56 | dist_array[i*nay*naz + j*naz + k] = sum_array / (ii_end * jj_end * kk_end); 57 | 58 | } 59 | } 60 | } 61 | 62 | } 63 | 64 | void match_template_mad_norm( 65 | int nax, int nay, int naz, float *array, 66 | int ntx, int nty, int ntz, float *temp, 67 | float *dist_array 68 | ) { 69 | 70 | if (ntx % 2 == 0 || nty % 2 == 0 || ntz % 2 == 0) { 71 | throw std::invalid_argument("Template dimensions must all be odd size."); 72 | } 73 | 74 | int ti_middle = (ntx - 1) / 2; 75 | int tj_middle = (nty - 1) / 2; 76 | int tk_middle = (ntz - 1) / 2; 77 | 78 | for (int i = 0; i < nax; i++) { 79 | int ai_start = std::max(0, i - ti_middle); 80 | int ti_start = std::max(0, ti_middle - i); 81 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 82 | for (int j = 0; j < nay; j++) { 83 | int aj_start = std::max(0, j - tj_middle); 84 | int tj_start = std::max(0, tj_middle - j); 85 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 86 | for (int k = 0; k < naz; k++) { 87 | int ak_start = std::max(0, k - tk_middle); 88 | int tk_start = std::max(0, tk_middle - k); 89 | int kk_end = std::min(k + tk_middle + 1, naz) - ak_start; 90 | 91 | float sum_diff = 0; 92 | float sum_temp = 0; 93 | for (int ii = 0; ii < ii_end; ii++) 
{ 94 | int ti = ti_start + ii; 95 | int ai = ai_start + ii; 96 | for (int jj = 0; jj < jj_end; jj++) { 97 | int tj = tj_start + jj; 98 | int aj = aj_start + jj; 99 | for (int kk = 0; kk < kk_end; kk++) { 100 | int tk = tk_start + kk; 101 | int ak = ak_start + kk; 102 | int a_ind = ai*nay*naz + aj*naz + ak; 103 | int t_ind = ti*nty*ntz + tj*ntz + tk; 104 | float diff = std::abs(array[a_ind] - temp[t_ind]); 105 | sum_diff += diff; 106 | sum_temp += temp[t_ind]; 107 | } 108 | } 109 | } 110 | 111 | float mean_diff = sum_diff / (ii_end * jj_end * kk_end); 112 | float mean_temp = sum_temp / (ii_end * jj_end * kk_end); 113 | dist_array[i*nay*naz + j*naz + k] = mean_diff / mean_temp; 114 | 115 | } 116 | } 117 | } 118 | 119 | } 120 | 121 | void match_template_msd( 122 | int nax, int nay, int naz, float *array, 123 | int ntx, int nty, int ntz, float *temp, 124 | float *dist_array 125 | ) { 126 | 127 | if (ntx % 2 == 0 || nty % 2 == 0 || ntz % 2 == 0) { 128 | throw std::invalid_argument("Template dimensions must all be odd size."); 129 | } 130 | 131 | int ti_middle = (ntx - 1) / 2; 132 | int tj_middle = (nty - 1) / 2; 133 | int tk_middle = (ntz - 1) / 2; 134 | 135 | for (int i = 0; i < nax; i++) { 136 | int ai_start = std::max(0, i - ti_middle); 137 | int ti_start = std::max(0, ti_middle - i); 138 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 139 | for (int j = 0; j < nay; j++) { 140 | int aj_start = std::max(0, j - tj_middle); 141 | int tj_start = std::max(0, tj_middle - j); 142 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 143 | for (int k = 0; k < naz; k++) { 144 | int ak_start = std::max(0, k - tk_middle); 145 | int tk_start = std::max(0, tk_middle - k); 146 | int kk_end = std::min(k + tk_middle + 1, naz) - ak_start; 147 | 148 | float sum_array = 0; 149 | for (int ii = 0; ii < ii_end; ii++) { 150 | int ti = ti_start + ii; 151 | int ai = ai_start + ii; 152 | for (int jj = 0; jj < jj_end; jj++) { 153 | int tj = tj_start + jj; 154 | int aj = 
aj_start + jj; 155 | for (int kk = 0; kk < kk_end; kk++) { 156 | int tk = tk_start + kk; 157 | int ak = ak_start + kk; 158 | int a_ind = ai*nay*naz + aj*naz + ak; 159 | int t_ind = ti*nty*ntz + tj*ntz + tk; 160 | float diff = array[a_ind] - temp[t_ind]; 161 | sum_array += diff * diff; 162 | } 163 | } 164 | } 165 | 166 | dist_array[i*nay*naz + j*naz + k] = sum_array / (ii_end * jj_end * kk_end); 167 | 168 | } 169 | } 170 | } 171 | 172 | } 173 | 174 | void match_template_msd_norm( 175 | int nax, int nay, int naz, float *array, 176 | int ntx, int nty, int ntz, float *temp, 177 | float *dist_array 178 | ) { 179 | 180 | if (ntx % 2 == 0 || nty % 2 == 0 || ntz % 2 == 0) { 181 | throw std::invalid_argument("Template dimensions must all be odd size."); 182 | } 183 | 184 | int ti_middle = (ntx - 1) / 2; 185 | int tj_middle = (nty - 1) / 2; 186 | int tk_middle = (ntz - 1) / 2; 187 | 188 | for (int i = 0; i < nax; i++) { 189 | int ai_start = std::max(0, i - ti_middle); 190 | int ti_start = std::max(0, ti_middle - i); 191 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 192 | for (int j = 0; j < nay; j++) { 193 | int aj_start = std::max(0, j - tj_middle); 194 | int tj_start = std::max(0, tj_middle - j); 195 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 196 | for (int k = 0; k < naz; k++) { 197 | int ak_start = std::max(0, k - tk_middle); 198 | int tk_start = std::max(0, tk_middle - k); 199 | int kk_end = std::min(k + tk_middle + 1, naz) - ak_start; 200 | 201 | float sum_diff = 0; 202 | float sum_temp = 0; 203 | for (int ii = 0; ii < ii_end; ii++) { 204 | int ti = ti_start + ii; 205 | int ai = ai_start + ii; 206 | for (int jj = 0; jj < jj_end; jj++) { 207 | int tj = tj_start + jj; 208 | int aj = aj_start + jj; 209 | for (int kk = 0; kk < kk_end; kk++) { 210 | int tk = tk_start + kk; 211 | int ak = ak_start + kk; 212 | int a_ind = ai*nay*naz + aj*naz + ak; 213 | int t_ind = ti*nty*ntz + tj*ntz + tk; 214 | float diff = array[a_ind] - temp[t_ind]; 215 | 
sum_diff += diff * diff; 216 | sum_temp += temp[t_ind] * temp[t_ind]; 217 | } 218 | } 219 | } 220 | 221 | float mean_diff = sum_diff / (ii_end * jj_end * kk_end); 222 | float mean_temp = sum_temp / (ii_end * jj_end * kk_end); 223 | dist_array[i*nay*naz + j*naz + k] = mean_diff / mean_temp; 224 | 225 | } 226 | } 227 | } 228 | 229 | } 230 | 231 | void match_template_mad_2d( 232 | int nax, int nay, float *array, 233 | int ntx, int nty, float *temp, 234 | float *dist_array 235 | ) { 236 | 237 | if (ntx % 2 == 0 || nty % 2 == 0) { 238 | throw std::invalid_argument("Template dimensions must all be odd size."); 239 | } 240 | 241 | int ti_middle = (ntx - 1) / 2; 242 | int tj_middle = (nty - 1) / 2; 243 | 244 | for (int i = 0; i < nax; i++) { 245 | int ai_start = std::max(0, i - ti_middle); 246 | int ti_start = std::max(0, ti_middle - i); 247 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 248 | for (int j = 0; j < nay; j++) { 249 | int aj_start = std::max(0, j - tj_middle); 250 | int tj_start = std::max(0, tj_middle - j); 251 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 252 | 253 | float sum_array = 0; 254 | for (int ii = 0; ii < ii_end; ii++) { 255 | int ti = ti_start + ii; 256 | int ai = ai_start + ii; 257 | for (int jj = 0; jj < jj_end; jj++) { 258 | int tj = tj_start + jj; 259 | int aj = aj_start + jj; 260 | int a_ind = ai*nay + aj; 261 | int t_ind = ti*nty + tj; 262 | float diff = std::abs(array[a_ind] - temp[t_ind]); 263 | sum_array += diff; 264 | } 265 | } 266 | 267 | dist_array[i*nay + j] = sum_array / (ii_end * jj_end); 268 | 269 | } 270 | } 271 | 272 | } 273 | 274 | void match_template_mad_norm_2d( 275 | int nax, int nay, float *array, 276 | int ntx, int nty, float *temp, 277 | float *dist_array 278 | ) { 279 | 280 | if (ntx % 2 == 0 || nty % 2 == 0) { 281 | throw std::invalid_argument("Template dimensions must all be odd size."); 282 | } 283 | 284 | int ti_middle = (ntx - 1) / 2; 285 | int tj_middle = (nty - 1) / 2; 286 | 287 | for 
(int i = 0; i < nax; i++) { 288 | int ai_start = std::max(0, i - ti_middle); 289 | int ti_start = std::max(0, ti_middle - i); 290 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 291 | for (int j = 0; j < nay; j++) { 292 | int aj_start = std::max(0, j - tj_middle); 293 | int tj_start = std::max(0, tj_middle - j); 294 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 295 | 296 | float sum_diff = 0; 297 | float sum_temp = 0; 298 | for (int ii = 0; ii < ii_end; ii++) { 299 | int ti = ti_start + ii; 300 | int ai = ai_start + ii; 301 | for (int jj = 0; jj < jj_end; jj++) { 302 | int tj = tj_start + jj; 303 | int aj = aj_start + jj; 304 | int a_ind = ai*nay + aj; 305 | int t_ind = ti*nty + tj; 306 | float diff = std::abs(array[a_ind] - temp[t_ind]); 307 | sum_diff += diff; 308 | sum_temp += temp[t_ind]; 309 | } 310 | } 311 | 312 | float mean_diff = sum_diff / (ii_end * jj_end); 313 | float mean_temp = sum_temp / (ii_end * jj_end); 314 | dist_array[i*nay + j] = mean_diff / mean_temp; 315 | 316 | } 317 | } 318 | 319 | } 320 | 321 | void match_template_msd_2d( 322 | int nax, int nay, float *array, 323 | int ntx, int nty, float *temp, 324 | float *dist_array 325 | ) { 326 | 327 | if (ntx % 2 == 0 || nty % 2 == 0) { 328 | throw std::invalid_argument("Template dimensions must all be odd size."); 329 | } 330 | 331 | int ti_middle = (ntx - 1) / 2; 332 | int tj_middle = (nty - 1) / 2; 333 | 334 | for (int i = 0; i < nax; i++) { 335 | int ai_start = std::max(0, i - ti_middle); 336 | int ti_start = std::max(0, ti_middle - i); 337 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 338 | for (int j = 0; j < nay; j++) { 339 | int aj_start = std::max(0, j - tj_middle); 340 | int tj_start = std::max(0, tj_middle - j); 341 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 342 | 343 | float sum_array = 0; 344 | for (int ii = 0; ii < ii_end; ii++) { 345 | int ti = ti_start + ii; 346 | int ai = ai_start + ii; 347 | for (int jj = 0; jj < jj_end; jj++) 
{ 348 | int tj = tj_start + jj; 349 | int aj = aj_start + jj; 350 | int a_ind = ai*nay + aj; 351 | int t_ind = ti*nty + tj; 352 | float diff = array[a_ind] - temp[t_ind]; 353 | sum_array += diff * diff; 354 | } 355 | } 356 | 357 | dist_array[i*nay + j] = sum_array / (ii_end * jj_end); 358 | 359 | } 360 | } 361 | 362 | } 363 | 364 | void match_template_msd_norm_2d( 365 | int nax, int nay, float *array, 366 | int ntx, int nty, float *temp, 367 | float *dist_array 368 | ) { 369 | 370 | if (ntx % 2 == 0 || nty % 2 == 0) { 371 | throw std::invalid_argument("Template dimensions must all be odd size."); 372 | } 373 | 374 | int ti_middle = (ntx - 1) / 2; 375 | int tj_middle = (nty - 1) / 2; 376 | 377 | for (int i = 0; i < nax; i++) { 378 | int ai_start = std::max(0, i - ti_middle); 379 | int ti_start = std::max(0, ti_middle - i); 380 | int ii_end = std::min(i + ti_middle + 1, nax) - ai_start; 381 | for (int j = 0; j < nay; j++) { 382 | int aj_start = std::max(0, j - tj_middle); 383 | int tj_start = std::max(0, tj_middle - j); 384 | int jj_end = std::min(j + tj_middle + 1, nay) - aj_start; 385 | 386 | float sum_diff = 0; 387 | float sum_temp = 0; 388 | for (int ii = 0; ii < ii_end; ii++) { 389 | int ti = ti_start + ii; 390 | int ai = ai_start + ii; 391 | for (int jj = 0; jj < jj_end; jj++) { 392 | int tj = tj_start + jj; 393 | int aj = aj_start + jj; 394 | int a_ind = ai*nay + aj; 395 | int t_ind = ti*nty + tj; 396 | float diff = array[a_ind] - temp[t_ind]; 397 | sum_diff += diff * diff; 398 | sum_temp += temp[t_ind] * temp[t_ind]; 399 | } 400 | } 401 | 402 | float mean_diff = sum_diff / (ii_end * jj_end); 403 | float mean_temp = sum_temp / (ii_end * jj_end); 404 | dist_array[i*nay + j] = mean_diff / mean_temp; 405 | 406 | } 407 | } 408 | 409 | } 410 | 411 | } -------------------------------------------------------------------------------- /src/cuda/bindings.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import 
torch 4 | from torch.utils.cpp_extension import load 5 | 6 | compile_ok = False 7 | if torch.cuda.is_available(): 8 | this_dir = os.path.dirname(os.path.abspath(__file__)) 9 | peak_find = load(name='peak_find', sources=[ 10 | os.path.join(this_dir, 'peak_find.cpp'), 11 | os.path.join(this_dir, 'matching_cuda.cu'), 12 | os.path.join(this_dir, 'ccl_cuda.cu')], 13 | extra_include_paths=[this_dir] 14 | ) 15 | tm_methods = peak_find.TMMethod.__members__ 16 | compile_ok = True 17 | 18 | def _check_compile(): 19 | if not compile_ok: 20 | raise RuntimeError('CUDA extensions were not compiled correctly. Is CUDA available?') 21 | 22 | def match_template(array, template, method='mad'): 23 | _check_compile() 24 | if not (isinstance(array, torch.Tensor) and isinstance(template, torch.Tensor)): 25 | raise ValueError('Input arrays must be torch tensors.') 26 | try: 27 | method = tm_methods[method] 28 | except KeyError: 29 | raise ValueError(f'Unknown matching method `{method}`.') 30 | return peak_find.match_template(array, template, method) 31 | 32 | def ccl(array): 33 | _check_compile() 34 | return peak_find.ccl(array.int()) 35 | 36 | def find_label_min(array, labels): 37 | _check_compile() 38 | assert array.shape == labels.shape 39 | return peak_find.find_label_min(array, labels) 40 | -------------------------------------------------------------------------------- /src/cuda/ccl_cuda.cu: -------------------------------------------------------------------------------- 1 | 2 | #include <cuda.h> 3 | #include <cuda_runtime.h> 4 | 5 | #include <torch/extension.h> 6 | #include <vector> 7 | 8 | #include "defs.h" 9 | 10 | __device__ int find_root(const int* __restrict__ labels, int a) { 11 | int l = labels[a]; 12 | while (l != a) { 13 | a = l; 14 | l = labels[a]; 15 | } 16 | return a; 17 | } 18 | 19 | __device__ void tree_union(int* __restrict__ labels, int a, int b) { 20 | bool done = false; 21 | int old; 22 | while (!done) { 23 | a = find_root(labels, a); 24 | b = find_root(labels, b); 25 | if (a < b) { 26 | old = atomicMin(&labels[b], a); 27 | 
done = (old == b); 28 | b = old; 29 | } else if (b < a){ 30 | old = atomicMin(&labels[a], b); 31 | done = (old == a); 32 | a = old; 33 | } else { 34 | break; 35 | } 36 | } 37 | } 38 | 39 | __global__ void init_labels( 40 | const int* __restrict__ array, 41 | int* __restrict__ labels, 42 | int d, int w, int h 43 | ) { 44 | 45 | int i = blockDim.x * blockIdx.x + threadIdx.x; 46 | int j = blockDim.y * blockIdx.y + threadIdx.y; 47 | int k = blockDim.z * blockIdx.z + threadIdx.z; 48 | 49 | if (i < d && j < w && k < h) { 50 | int a = i*w*h + j*h + k; 51 | labels[a] = (array[a] == 1) ? a : -1; 52 | } 53 | 54 | 55 | } 56 | 57 | __global__ void merge( 58 | const int* __restrict__ array, 59 | int* __restrict__ labels, 60 | int d, int w, int h 61 | ) { 62 | 63 | int i = blockDim.x * blockIdx.x + threadIdx.x; 64 | int j = blockDim.y * blockIdx.y + threadIdx.y; 65 | int k = blockDim.z * blockIdx.z + threadIdx.z; 66 | 67 | if (i < d && j < w && k < h) { 68 | 69 | int wh = w*h; 70 | int a = i*wh + j*h + k; 71 | 72 | bool js = (j > 0); 73 | bool ks = (k > 0); 74 | bool je = (j < (w - 1)); 75 | bool ke = (k < (h - 1)); 76 | 77 | int n; 78 | if (array[a] == 1) { 79 | if (i > 0) { 80 | n = a - wh; if (array[n] == 1) {tree_union(labels, a, n);} 81 | n = a - wh - h; if (js && array[n] == 1) {tree_union(labels, a, n);} 82 | n = a - wh + h; if (je && array[n] == 1) {tree_union(labels, a, n);} 83 | if (ks) { 84 | n = a - wh - 1; if (array[n] == 1) {tree_union(labels, a, n);} 85 | n = a - wh - h - 1; if (js && array[n] == 1) {tree_union(labels, a, n);} 86 | n = a - wh + h - 1; if (je && array[n] == 1) {tree_union(labels, a, n);} 87 | } 88 | if (ke) { 89 | n = a - wh + 1; if (array[n] == 1) {tree_union(labels, a, n);} 90 | n = a - wh - h + 1; if (js && array[n] == 1) {tree_union(labels, a, n);} 91 | n = a - wh + h + 1; if (je && array[n] == 1) {tree_union(labels, a, n);} 92 | } 93 | } 94 | if (js) { 95 | n = a - h; if (array[n] == 1) {tree_union(labels, a, n);} 96 | n = a - h - 1; if (ks && 
array[n] == 1) {tree_union(labels, a, n);} 97 | n = a - h + 1; if (ke && array[n] == 1) {tree_union(labels, a, n);} 98 | } 99 | n = a - 1; if (ks && array[n] == 1) {tree_union(labels, a, n);} 100 | } 101 | 102 | } 103 | 104 | } 105 | 106 | __global__ void compress( 107 | const int* __restrict__ array, 108 | int* __restrict__ labels, 109 | int d, int w, int h 110 | ) { 111 | 112 | int i = blockDim.x * blockIdx.x + threadIdx.x; 113 | int j = blockDim.y * blockIdx.y + threadIdx.y; 114 | int k = blockDim.z * blockIdx.z + threadIdx.z; 115 | 116 | if (i < d && j < w && k < h) { 117 | int a = i*w*h + j*h + k; 118 | if (array[a] == 1) { 119 | int b = a; 120 | int l = labels[b]; 121 | while (l != b) { 122 | b = l; 123 | l = labels[b]; 124 | labels[a] = b; 125 | } 126 | } 127 | } 128 | 129 | } 130 | 131 | __global__ void find_min_inds( 132 | const float* __restrict__ array, 133 | const int* __restrict__ labels, 134 | int* __restrict__ max_ind_array, 135 | int d, int w, int h, 136 | int* __restrict__ counter, int* __restrict__ unique_labels 137 | ) { 138 | 139 | int i = blockDim.x * blockIdx.x + threadIdx.x; 140 | int j = blockDim.y * blockIdx.y + threadIdx.y; 141 | int k = blockDim.z * blockIdx.z + threadIdx.z; 142 | 143 | if (i < d && j < w && k < h) { 144 | 145 | int a = i*w*h + j*h + k; 146 | int root = labels[a]; 147 | 148 | if (root >= 0) { 149 | 150 | int *address = &max_ind_array[root]; 151 | int old_ind = max_ind_array[root]; 152 | float new_val = array[a]; 153 | int assumed; 154 | float old_val; 155 | do { 156 | if (old_ind < 0 || new_val < array[old_ind]) { 157 | assumed = old_ind; 158 | old_ind = atomicCAS(address, assumed, a); 159 | } else { 160 | break; 161 | } 162 | } while (assumed != old_ind); 163 | 164 | if (a == root) { 165 | int label_ind = atomicAdd(counter, 1); 166 | unique_labels[label_ind] = a; 167 | } 168 | 169 | } 170 | } 171 | 172 | } 173 | 174 | __global__ void cpy_max_inds( 175 | int* __restrict__ max_inds, 176 | int* __restrict__ max_ind_array, 
177 | int* __restrict__ unique_labels, 178 | int* __restrict__ counter, 179 | int d, int w, int h 180 | ) { 181 | 182 | for (int i = threadIdx.x; i < *counter; i += blockDim.x) { 183 | int ind = max_ind_array[unique_labels[i]]; 184 | max_inds[3*i ] = ind / (w*h); 185 | max_inds[3*i+1] = (ind / h) % w; 186 | max_inds[3*i+2] = ind % h; 187 | } 188 | 189 | } 190 | 191 | torch::Tensor ccl_cuda(torch::Tensor array) { 192 | 193 | torch::Tensor labels = torch::empty_like(array); 194 | 195 | dim3 threads(8, 8, 8); 196 | dim3 blocks( 197 | ceil(float(array.size(1)) / threads.x), 198 | ceil(float(array.size(2)) / threads.y), 199 | ceil(float(array.size(3)) / threads.z) 200 | ); 201 | 202 | int batch_size = array.size(0); 203 | int d = array.size(1); 204 | int w = array.size(2); 205 | int h = array.size(3); 206 | 207 | for (int b = 0; b < batch_size; b++) { 208 | 209 | int* array_ptr = &array.data<int>()[b*d*w*h]; 210 | int* label_ptr = &labels.data<int>()[b*d*w*h]; 211 | 212 | init_labels<<<blocks, threads>>>( 213 | array_ptr, label_ptr, 214 | d, w, h 215 | ); 216 | 217 | merge<<<blocks, threads>>>( 218 | array_ptr, label_ptr, 219 | d, w, h 220 | ); 221 | 222 | compress<<<blocks, threads>>>( 223 | array_ptr, label_ptr, 224 | d, w, h 225 | ); 226 | 227 | } 228 | 229 | return labels; 230 | 231 | } 232 | 233 | std::vector<torch::Tensor> find_label_min_cuda(torch::Tensor array, torch::Tensor labels) { 234 | 235 | cudaSetDevice(array.device().index()); 236 | 237 | dim3 threads(8, 8, 8); 238 | dim3 blocks( 239 | ceil(float(array.size(1)) / threads.x), 240 | ceil(float(array.size(2)) / threads.y), 241 | ceil(float(array.size(3)) / threads.z) 242 | ); 243 | 244 | int batch_size = labels.size(0); 245 | int d = labels.size(1); 246 | int w = labels.size(2); 247 | int h = labels.size(3); 248 | 249 | std::vector<torch::Tensor> max_inds; 250 | torch::Tensor max_ind_array = torch::full_like(labels, -1); 251 | 252 | for (int b = 0; b < batch_size; b++) { 253 | 254 | int counter = 0; 255 | int* counter_d; 256 | cudaMalloc(&counter_d, sizeof(int)); 257 | 
cudaMemcpy(counter_d, &counter, sizeof(int), cudaMemcpyHostToDevice); 258 | 259 | torch::Tensor unique_labels = torch::zeros(MAX_LABELS, 260 | torch::TensorOptions().dtype(torch::kInt32).device(array.device())); 261 | 262 | float* array_ptr = &array.data<float>()[b*d*w*h]; 263 | int* label_ptr = &labels.data<int>()[b*d*w*h]; 264 | int* max_ind_array_ptr = &max_ind_array.data<int>()[b*d*w*h]; 265 | 266 | find_min_inds<<<blocks, threads>>>( 267 | array_ptr, label_ptr, max_ind_array_ptr, 268 | d, w, h, counter_d, unique_labels.data<int>() 269 | ); 270 | 271 | cudaMemcpy(&counter, counter_d, sizeof(int), cudaMemcpyDeviceToHost); 272 | assert (counter <= MAX_LABELS); 273 | 274 | torch::Tensor max_inds_ = torch::zeros({counter, 3}, 275 | torch::TensorOptions().dtype(torch::kInt32).device(array.device())); 276 | 277 | cpy_max_inds<<<1, 32>>>( 278 | max_inds_.data<int>(), max_ind_array_ptr, 279 | unique_labels.data<int>(), counter_d, 280 | d, w, h 281 | ); 282 | cudaFree(counter_d); 283 | max_inds.push_back(max_inds_); 284 | 285 | } 286 | 287 | return max_inds; 288 | 289 | } -------------------------------------------------------------------------------- /src/cuda/defs.h: -------------------------------------------------------------------------------- 1 | 2 | // PyTorch C++/CUDA extension tutorial: 3 | // https://pytorch.org/tutorials/advanced/cpp_extension.html 4 | 5 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 8 | #define CHECK3D(x) TORCH_CHECK(x.dim() == 3, #x " must be a 3D tensor") 9 | #define CHECK4D(x) TORCH_CHECK(x.dim() == 4, #x " must be a 4D tensor") 10 | 11 | #define MAX_LABELS 1024 12 | 13 | enum TMMethod { 14 | TM_MAD, 15 | TM_MSD, 16 | TM_MAD_NORM, 17 | TM_MSD_NORM 18 | }; 19 | 20 | torch::Tensor match_template_cuda(torch::Tensor array, torch::Tensor temp, TMMethod method); 21 | torch::Tensor ccl_cuda(torch::Tensor
array); 22 | std::vector<torch::Tensor> find_label_min_cuda(torch::Tensor array, torch::Tensor labels); 23 | -------------------------------------------------------------------------------- /src/cuda/matching_cuda.cu: -------------------------------------------------------------------------------- 1 | 2 | #include <torch/extension.h> 3 | 4 | #include <cuda.h> 5 | #include <cuda_runtime.h> 6 | 7 | #include "defs.h" 8 | 9 | __device__ float sum_mad( 10 | int ti_start, int tj_start, int tk_start, 11 | int ai_start, int aj_start, int ak_start, 12 | int i_range, int j_range, int k_range, 13 | int aj_size, int ak_size, float *array, 14 | int tj_size, int tk_size, float *temp 15 | ) { 16 | 17 | float sum_array = 0; 18 | for (int ii = 0; ii < i_range; ii++) { 19 | int ti = ti_start + ii; 20 | int ai = ai_start + ii; 21 | for (int jj = 0; jj < j_range; jj++) { 22 | int tj = tj_start + jj; 23 | int aj = aj_start + jj; 24 | for (int kk = 0; kk < k_range; kk++) { 25 | int tk = tk_start + kk; 26 | int ak = ak_start + kk; 27 | sum_array += abs( 28 | array[ai*ak_size*aj_size + aj*ak_size + ak] 29 | - temp[ti*tj_size*tk_size + tj*tk_size + tk]); 30 | } 31 | } 32 | } 33 | 34 | return sum_array / (i_range * j_range * k_range); 35 | 36 | } 37 | 38 | __device__ float sum_msd( 39 | int ti_start, int tj_start, int tk_start, 40 | int ai_start, int aj_start, int ak_start, 41 | int i_range, int j_range, int k_range, 42 | int aj_size, int ak_size, float *array, 43 | int tj_size, int tk_size, float *temp 44 | ) { 45 | 46 | float sum_array = 0; 47 | for (int ii = 0; ii < i_range; ii++) { 48 | int ti = ti_start + ii; 49 | int ai = ai_start + ii; 50 | for (int jj = 0; jj < j_range; jj++) { 51 | int tj = tj_start + jj; 52 | int aj = aj_start + jj; 53 | for (int kk = 0; kk < k_range; kk++) { 54 | int tk = tk_start + kk; 55 | int ak = ak_start + kk; 56 | float diff = array[ai*ak_size*aj_size + aj*ak_size + ak] 57 | - temp[ti*tj_size*tk_size + tj*tk_size + tk]; 58 | sum_array += diff*diff; 59 | } 60 | } 61 | } 62 | 63 | return sum_array / (i_range * j_range *
k_range); 64 | 65 | } 66 | 67 | __device__ float sum_mad_norm( 68 | int ti_start, int tj_start, int tk_start, 69 | int ai_start, int aj_start, int ak_start, 70 | int i_range, int j_range, int k_range, 71 | int aj_size, int ak_size, float *array, 72 | int tj_size, int tk_size, float *temp 73 | ) { 74 | 75 | float sum_diff = 0; 76 | float sum_temp = 0; 77 | for (int ii = 0; ii < i_range; ii++) { 78 | int ti = ti_start + ii; 79 | int ai = ai_start + ii; 80 | for (int jj = 0; jj < j_range; jj++) { 81 | int tj = tj_start + jj; 82 | int aj = aj_start + jj; 83 | for (int kk = 0; kk < k_range; kk++) { 84 | int tk = tk_start + kk; 85 | int ak = ak_start + kk; 86 | int a_ind = ai*ak_size*aj_size + aj*ak_size + ak; 87 | int t_ind = ti*tj_size*tk_size + tj*tk_size + tk; 88 | sum_diff += abs(array[a_ind] - temp[t_ind]); 89 | sum_temp += temp[t_ind]; 90 | } 91 | } 92 | } 93 | 94 | float mean_diff = sum_diff / (i_range * j_range * k_range); 95 | float mean_temp = sum_temp / (i_range * j_range * k_range); 96 | 97 | return mean_diff / mean_temp; 98 | 99 | } 100 | 101 | __device__ float sum_msd_norm( 102 | int ti_start, int tj_start, int tk_start, 103 | int ai_start, int aj_start, int ak_start, 104 | int i_range, int j_range, int k_range, 105 | int aj_size, int ak_size, float *array, 106 | int tj_size, int tk_size, float *temp 107 | ) { 108 | 109 | float sum_diff = 0; 110 | float sum_temp = 0; 111 | for (int ii = 0; ii < i_range; ii++) { 112 | int ti = ti_start + ii; 113 | int ai = ai_start + ii; 114 | for (int jj = 0; jj < j_range; jj++) { 115 | int tj = tj_start + jj; 116 | int aj = aj_start + jj; 117 | for (int kk = 0; kk < k_range; kk++) { 118 | int tk = tk_start + kk; 119 | int ak = ak_start + kk; 120 | int a_ind = ai*ak_size*aj_size + aj*ak_size + ak; 121 | int t_ind = ti*tj_size*tk_size + tj*tk_size + tk; 122 | float diff = array[a_ind] - temp[t_ind]; 123 | sum_diff += diff * diff; 124 | sum_temp += temp[t_ind] * temp[t_ind]; 125 | } 126 | } 127 | } 128 | 129 | float 
mean_diff = sum_diff / (i_range * j_range * k_range); 130 | float mean_temp = sum_temp / (i_range * j_range * k_range); 131 | 132 | return mean_diff / mean_temp; 133 | 134 | } 135 | 136 | template <typename scalar_t> 137 | __global__ void match_template_mad( 138 | const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> array, 139 | const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> temp, 140 | torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> dist_array 141 | ) { 142 | 143 | int nax = array.size(1); 144 | int nay = array.size(2); 145 | int naz = array.size(3); 146 | 147 | int ib = blockIdx.x * blockDim.x + threadIdx.x; 148 | int xthreads_per_batch = ceil(float(nax) / blockDim.x) * blockDim.x; 149 | 150 | int b = ib / xthreads_per_batch; 151 | int i = ib % xthreads_per_batch; 152 | int j = blockIdx.y * blockDim.y + threadIdx.y; 153 | int k = blockIdx.z * blockDim.z + threadIdx.z; 154 | 155 | if (i < nax && j < nay && k < naz) { 156 | 157 | int ntx = temp.size(0); 158 | int nty = temp.size(1); 159 | int ntz = temp.size(2); 160 | 161 | int ti_middle = (ntx - 1) / 2; 162 | int tj_middle = (nty - 1) / 2; 163 | int tk_middle = (ntz - 1) / 2; 164 | 165 | int ai_start = max(0, i - ti_middle); 166 | int aj_start = max(0, j - tj_middle); 167 | int ak_start = max(0, k - tk_middle); 168 | 169 | int ti_start = max(0, ti_middle - i); 170 | int tj_start = max(0, tj_middle - j); 171 | int tk_start = max(0, tk_middle - k); 172 | 173 | int ii_end = min(i + ti_middle + 1, nax) - ai_start; 174 | int jj_end = min(j + tj_middle + 1, nay) - aj_start; 175 | int kk_end = min(k + tk_middle + 1, naz) - ak_start; 176 | 177 | float sum_array = 0; 178 | for (int ii = 0; ii < ii_end; ii++) { 179 | int ti = ti_start + ii; 180 | int ai = ai_start + ii; 181 | for (int jj = 0; jj < jj_end; jj++) { 182 | int tj = tj_start + jj; 183 | int aj = aj_start + jj; 184 | for (int kk = 0; kk < kk_end; kk++) { 185 | int tk = tk_start + kk; 186 | int ak = ak_start + kk; 187 | float diff = abs(array[b][ai][aj][ak] - temp[ti][tj][tk]); 188 | sum_array += diff; 189 | } 190 | } 191 | }
192 | 193 | float d = sum_array / (ii_end * jj_end * kk_end); 194 | dist_array[b][i][j][k] = d; 195 | 196 | } 197 | 198 | } 199 | 200 | template <typename scalar_t> 201 | __global__ void match_template_kernel( 202 | const torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> array, 203 | const torch::PackedTensorAccessor32<scalar_t, 3, torch::RestrictPtrTraits> temp, 204 | torch::PackedTensorAccessor32<scalar_t, 4, torch::RestrictPtrTraits> dist_array, 205 | int b, TMMethod method 206 | ) { 207 | 208 | int nax = array.size(1); 209 | int nay = array.size(2); 210 | int naz = array.size(3); 211 | 212 | int ntx = temp.size(0); 213 | int nty = temp.size(1); 214 | int ntz = temp.size(2); 215 | 216 | int i = blockIdx.x * blockDim.x + threadIdx.x; 217 | int j = blockIdx.y * blockDim.y + threadIdx.y; 218 | int k = blockIdx.z * blockDim.z + threadIdx.z; 219 | 220 | int ti_middle = (ntx - 1) / 2; 221 | int tj_middle = (nty - 1) / 2; 222 | int tk_middle = (ntz - 1) / 2; 223 | 224 | // Shared memory 225 | extern __shared__ float shmem[]; 226 | float *temp_s = (float*) shmem; 227 | float *array_s = (float*) &shmem[ntx*nty*ntz]; 228 | 229 | // Copy template to shared memory 230 | for (int ti = threadIdx.x; ti < ntx; ti += blockDim.x) { 231 | for (int tj = threadIdx.y; tj < nty; tj += blockDim.y) { 232 | for (int tk = threadIdx.z; tk < ntz; tk += blockDim.z) { 233 | temp_s[ti*nty*ntz + tj*ntz + tk] = temp[ti][tj][tk]; 234 | } 235 | } 236 | } 237 | 238 | // Copy subarray to shared memory 239 | int i_start = blockIdx.x * blockDim.x; 240 | int j_start = blockIdx.y * blockDim.y; 241 | int k_start = blockIdx.z * blockDim.z; 242 | int ai_start = max(0, i_start - ti_middle); 243 | int aj_start = max(0, j_start - tj_middle); 244 | int ak_start = max(0, k_start - tk_middle); 245 | int i_range = min(i_start + ti_middle + int(blockDim.x), nax) - ai_start; 246 | int j_range = min(j_start + tj_middle + int(blockDim.y), nay) - aj_start; 247 | int k_range = min(k_start + tk_middle + int(blockDim.z), naz) - ak_start; 248 | for (int ii = threadIdx.x; ii < i_range; ii += blockDim.x) { 249 | int ai = ai_start + ii; 250
| for (int jj = threadIdx.y; jj < j_range; jj += blockDim.y) { 251 | int aj = aj_start + jj; 252 | for (int kk = threadIdx.z; kk < k_range; kk += blockDim.z) { 253 | int ak = ak_start + kk; 254 | array_s[ii*k_range*j_range + jj*k_range + kk] = array[b][ai][aj][ak]; 255 | } 256 | } 257 | } 258 | 259 | __syncthreads(); 260 | 261 | if (i < nax && j < nay && k < naz) { 262 | 263 | int tsi_start = max(0, ti_middle - i); 264 | int tsj_start = max(0, tj_middle - j); 265 | int tsk_start = max(0, tk_middle - k); 266 | 267 | int asi_start = max(0, int(threadIdx.x) + min(0, i_start - ti_middle)); 268 | int asj_start = max(0, int(threadIdx.y) + min(0, j_start - tj_middle)); 269 | int ask_start = max(0, int(threadIdx.z) + min(0, k_start - tk_middle)); 270 | 271 | int si_range = min(i + ti_middle + 1, nax) - max(0, i - ti_middle); 272 | int sj_range = min(j + tj_middle + 1, nay) - max(0, j - tj_middle); 273 | int sk_range = min(k + tk_middle + 1, naz) - max(0, k - tk_middle); 274 | 275 | if (method == TM_MAD) { 276 | dist_array[b][i][j][k] = sum_mad( 277 | tsi_start, tsj_start, tsk_start, 278 | asi_start, asj_start, ask_start, 279 | si_range, sj_range, sk_range, 280 | j_range, k_range, array_s, 281 | nty, ntz, temp_s 282 | ); 283 | } else if (method == TM_MSD) { 284 | dist_array[b][i][j][k] = sum_msd( 285 | tsi_start, tsj_start, tsk_start, 286 | asi_start, asj_start, ask_start, 287 | si_range, sj_range, sk_range, 288 | j_range, k_range, array_s, 289 | nty, ntz, temp_s 290 | ); 291 | } else if (method == TM_MAD_NORM) { 292 | dist_array[b][i][j][k] = sum_mad_norm( 293 | tsi_start, tsj_start, tsk_start, 294 | asi_start, asj_start, ask_start, 295 | si_range, sj_range, sk_range, 296 | j_range, k_range, array_s, 297 | nty, ntz, temp_s 298 | ); 299 | } else if (method == TM_MSD_NORM) { 300 | dist_array[b][i][j][k] = sum_msd_norm( 301 | tsi_start, tsj_start, tsk_start, 302 | asi_start, asj_start, ask_start, 303 | si_range, sj_range, sk_range, 304 | j_range, k_range, array_s, 305 | nty, 
ntz, temp_s 306 | ); 307 | } 308 | 309 | 310 | } 311 | 312 | } 313 | 314 | torch::Tensor match_template_cuda(torch::Tensor array, torch::Tensor temp, TMMethod method) { 315 | 316 | cudaSetDevice(array.device().index()); 317 | 318 | if (temp.size(0) % 2 == 0 || temp.size(1) % 2 == 0 || temp.size(2) % 2 == 0) { 319 | throw std::invalid_argument("Template dimensions must all be odd size."); 320 | } 321 | 322 | auto dist_array = torch::zeros_like(array); 323 | 324 | dim3 threads(10, 10, 2); 325 | dim3 blocks( 326 | ceil(float(array.size(1)) / threads.x), 327 | ceil(float(array.size(2)) / threads.y), 328 | ceil(float(array.size(3)) / threads.z) 329 | ); 330 | int batch_size = array.size(0); 331 | 332 | int temp_size = temp.size(0) * temp.size(1) * temp.size(2); 333 | int subarray_size = (temp.size(0)+threads.x-1) * (temp.size(1)+threads.y-1) * (temp.size(2)+threads.z-1); 334 | 335 | for (int batch_ind = 0; batch_ind < batch_size; batch_ind++) { 336 | AT_DISPATCH_FLOATING_TYPES(array.type(), "match_template_kernel", ([&] { 337 | match_template_kernel<scalar_t><<<blocks, threads, (temp_size + subarray_size) * sizeof(float)>>>( 338 | array.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(), 339 | temp.packed_accessor32<scalar_t, 3, torch::RestrictPtrTraits>(), 340 | dist_array.packed_accessor32<scalar_t, 4, torch::RestrictPtrTraits>(), 341 | batch_ind, method 342 | ); 343 | })); 344 | } 345 | 346 | return dist_array; 347 | 348 | } 349 | -------------------------------------------------------------------------------- /src/cuda/peak_find.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include <torch/extension.h> 3 | 4 | #include "defs.h" 5 | 6 | torch::Tensor match_template(torch::Tensor array, torch::Tensor temp, TMMethod method) { 7 | CHECK_INPUT(array); 8 | CHECK_INPUT(temp); 9 | CHECK4D(array); 10 | CHECK3D(temp); 11 | return match_template_cuda(array, temp, method); 12 | } 13 | 14 | torch::Tensor ccl(torch::Tensor array) { 15 | CHECK_INPUT(array); 16 | CHECK4D(array); 17 | return ccl_cuda(array); 18 | } 19 | 20 | std::vector<torch::Tensor> find_label_min(torch::Tensor array, torch::Tensor labels) { 21 | CHECK_INPUT(array); 22 |
CHECK_INPUT(labels); 23 | CHECK4D(array); 24 | CHECK4D(labels); 25 | return find_label_min_cuda(array, labels); 26 | } 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("match_template", &match_template, "Match template"); 30 | m.def("ccl", &ccl, "Connected-component labelling"); 31 | m.def("find_label_min", &find_label_min, "Find min indices in labelled regions"); 32 | py::enum_<TMMethod>(m, "TMMethod") 33 | .value("mad", TMMethod::TM_MAD) 34 | .value("msd", TMMethod::TM_MSD) 35 | .value("mad_norm", TMMethod::TM_MAD_NORM) 36 | .value("msd_norm", TMMethod::TM_MSD_NORM) 37 | .export_values(); 38 | m.attr("MAX_LABELS") = MAX_LABELS; 39 | } 40 | -------------------------------------------------------------------------------- /src/data_loading.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import os 4 | import h5py 5 | import glob 6 | import time 7 | import random 8 | import numpy as np 9 | 10 | import torch 11 | from torch.utils.data import Dataset, DataLoader, DistributedSampler 12 | from utils import Atom, MoleculeGraph 13 | 14 | class HDF5Dataset(Dataset): 15 | ''' 16 | Pytorch dataset for AFM data using HDF5 database. 17 | 18 | Arguments: 19 | hdf5_path: str. Path to HDF5 database file. 20 | mode: 'train', 'val', or 'test'. Which dataset to use.
21 | ''' 22 | def __init__(self, hdf5_path, mode='train'): 23 | self.hdf5_path = hdf5_path 24 | self.mode = mode 25 | if self.mode not in ['train', 'val', 'test']: 26 | raise ValueError(f'mode should be one of "train", "val", or "test", but got {self.mode}') 27 | 28 | def __len__(self): 29 | with h5py.File(self.hdf5_path, 'r') as f: 30 | length = len(f[self.mode]['X']) 31 | return length 32 | 33 | def __getitem__(self, idx): 34 | with h5py.File(self.hdf5_path, 'r') as f: 35 | dataset = f[self.mode] 36 | X = dataset['X'][idx] 37 | Y = dataset['Y'][idx] if 'Y' in dataset.keys() else [] 38 | xyz = self._unpad_xyz(dataset['xyz'][idx]) 39 | return X, Y, xyz 40 | 41 | def _unpad_xyz(self, xyz): 42 | return xyz[xyz[:,-1] > 0] 43 | 44 | def _worker_init_fn(worker_id): 45 | np.random.seed(int((time.time() % 1e5)*1000) + worker_id) 46 | 47 | class _collate_wrapper(): 48 | 49 | def __init__(self, collate_fn, preproc_fn): 50 | self.collate_fn = collate_fn 51 | self.preproc_fn = preproc_fn 52 | 53 | def __call__(self, batch): 54 | 55 | # Combine samples into a batch 56 | Xs = [] 57 | Ys = [] 58 | xyzs = [] 59 | for X, Y, xyz in batch: 60 | Xs.append(X) 61 | Ys.append(Y) 62 | xyzs.append(xyz) 63 | Xs = list(np.stack(Xs, axis=0).transpose(1, 0, 2, 3, 4)) 64 | if len(Ys[0]) > 0: 65 | Ys = list(np.stack(Ys, axis=0).transpose(1, 0, 2, 3)) 66 | 67 | # Run preprocessing and collate 68 | batch = (Xs, Ys, xyzs) 69 | batch = self.preproc_fn(batch) 70 | batch = self.collate_fn(batch) 71 | 72 | return batch 73 | 74 | def collate(batch): 75 | ''' 76 | Collate graph samples into a batch. 77 | 78 | Arguments: 79 | batch: tuple (X, mols, removed, xyz, ref_graphs, ref_dist, ref_atoms, box_borders) 80 | X: np.ndarray of shape (batch_size, x, y, z). Input AFM image. 81 | mols: list of MoleculeGraph. Input molecules. 82 | removed: list of tuples (atom, bonds), where atom is an Atom object and bonds is a list of 83 | 0s and 1s indicating the existence of bond connection to atoms in mols. 
84 | xyz: list of np.ndarray of shape (num_atoms, 5). List of original molecules. 85 | ref_graphs: list of MoleculeGraph. Complete reference graphs before atom removals. 86 | ref_dist: np.ndarray of shape (batch_size, x, y, z). Reference position distribution. 87 | ref_atoms: list of Atom. Reference atom positions and classes for teacher forcing. 88 | box_borders: tuple ((x_start, y_start, z_start), (x_end, y_end, z_end)). Real-space extent of the 89 | position distribution region in angstroms. 90 | 91 | Returns: tuple (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs, box_borders) 92 | X: torch.Tensor of shape (batch_size, 1, x, y, z). Input AFM images. 93 | node_inputs: torch.Tensor of shape (total_atoms, 3+n_classes), where total_atoms is the total number 94 | of atoms in input molecules. Graph node inputs. 95 | edges: torch.Tensor of shape (2, total_edges), where total_edges is the total number of edge connections 96 | in input molecules. Edge connections between input nodes. 97 | node_rem: list of torch.Tensor of shape (n_nodes, 4). Coordinates and classes of nodes to be predicted. Each batch 98 | item has a varying number of nodes n_nodes that can be predicted. 99 | edge_rem: list of torch.Tensor of shape (num_atoms,). Indicators for bond connections of the node to be predicted. 100 | terminate: torch.Tensor of shape (batch_size,). Indicator list for whether the molecule graph is complete. 101 | ref_dist: torch.Tensor of shape (batch_size, x, y, z). Reference position distribution. 102 | ref_atoms: torch.Tensor of shape (batch_size, 4). Reference atom positions and classes for teacher forcing. 103 | Ns: list of ints. Number of input nodes in each batch item. 104 | xyz: list of np.ndarray of shape (num_atoms, 5). Unchanged from input argument. 105 | ref_graphs: list of MoleculeGraph. Unchanged from input argument. 106 | box_borders: tuple. Unchanged from input argument.
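The batching scheme documented here packs the nodes of all graphs into one flat array and offsets each graph's edge indices by the number of nodes packed so far (the `ind_count` bookkeeping in `collate`). A minimal standalone sketch of the idea, with hypothetical toy data rather than the repository's own types:

```python
import numpy as np

# Two toy graphs with 2 and 3 nodes, and edges given in graph-local indices.
node_arrays = [np.zeros((2, 4)), np.zeros((3, 4))]
local_edges = [[(0, 1)], [(0, 1), (1, 2)]]

edges, ind_count = [], 0
for nodes, bonds in zip(node_arrays, local_edges):
    # Offset local indices by the number of nodes packed so far, so that
    # all graphs can share a single concatenated node array.
    edges += [(a + ind_count, b + ind_count) for a, b in bonds]
    ind_count += len(nodes)

node_inputs = np.concatenate(node_arrays, axis=0)  # shape (5, 4)
# edges is now [(0, 1), (2, 3), (3, 4)] in global indices
```

The per-graph node counts (the `Ns` list above) are what allow the flat array to be split back into individual graphs later, as `uncollate` does with `torch.split`.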
107 | ''' 108 | 109 | X, mols, removed, xyz, ref_graphs, ref_dist, ref_atoms, box_borders = batch 110 | assert len(X) == len(mols) == len(removed) == len(xyz) == len(ref_graphs) == len(ref_dist) 111 | 112 | mol_arrays = [] 113 | edges = [] 114 | edge_rem = [] 115 | node_rem = [] 116 | terminate = [] 117 | ind_count = 0 118 | Ns = [] 119 | 120 | for i, (mol, rem) in enumerate(zip(mols, removed)): 121 | 122 | if (mol_array := mol.array(xyz=True, class_weights=True)) != []: 123 | mol_arrays.append(mol_array) 124 | edges += [[b[0]+ind_count, b[1]+ind_count] for b in mol.bonds] 125 | 126 | if len(rem) > 0: 127 | e = []; n = []; t = 0 128 | for atom, bond_rem in rem: 129 | e.append(torch.tensor(bond_rem).float()) 130 | n.append(atom.array(xyz=True, class_index=True)) 131 | else: 132 | e = [torch.zeros(len(mol))] 133 | n = [np.zeros(4)] 134 | t = 1 135 | edge_rem.append(torch.from_numpy(np.stack(e, axis=0)).float()) 136 | node_rem.append(torch.from_numpy(np.stack(n, axis=0)).float()) 137 | terminate.append(t) 138 | 139 | ind_count += len(mol) 140 | Ns.append(len(mol)) 141 | 142 | terminate = torch.tensor(terminate).float() 143 | 144 | X = torch.from_numpy(X).float() 145 | if X.ndim == 4: 146 | X = X.unsqueeze(1) 147 | 148 | ref_dist = torch.from_numpy(ref_dist).float() 149 | 150 | if len(mol_arrays) > 0: 151 | node_inputs = torch.from_numpy(np.concatenate(mol_arrays, axis=0)).float() 152 | else: 153 | node_inputs = torch.empty((0)) 154 | edges = torch.tensor(edges).long().T 155 | 156 | ref_atoms = np.stack([r.array(xyz=True, class_index=True) if r else np.zeros(4) 157 | for r in ref_atoms], axis=0) 158 | ref_atoms = torch.from_numpy(ref_atoms).float() 159 | 160 | return X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs, box_borders 161 | 162 | def uncollate(pred, batch): 163 | ''' 164 | Convert graph batch back into separated format. 
165 | Arguments: 166 | pred: tuple (pred_nodes, pred_edges, pred_dist) 167 | pred_nodes: torch.Tensor of shape (batch_size, 3+n_classes). Predicted nodes. 168 | pred_edges: list of torch.Tensor of shape (num_atoms,). Predicted probabilities for edge connections. 169 | pred_dist: torch.Tensor of shape (batch_size, x, y, z). Predicted distribution for atom position. 170 | batch: tuple (X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs). 171 | Same as the return values of collate, without box_borders. 172 | Returns: tuple (X, mols, pred, ref, xyz, ref_dist, pred_dist, ref_graphs, ref_atoms) 173 | X: np.ndarray of shape (batch_size, x, y, z). Input AFM image. 174 | mols: list of MoleculeGraph. Input molecules. 175 | pred: list of tuples (atom, bond, pred_dist), where atom is an Atom object, bond is a list, 176 | and pred_dist is np.ndarray of shape (x, y, z). 177 | ref: list of tuples (ref_pairs, terminate, ref_dist), where ref_pairs is a list of (atom, bonds) tuples, 178 | terminate is an int, and ref_dist is np.ndarray of shape (x, y, z). 179 | xyz: list of np.ndarray of shape (num_atoms, 5). Same as input xyz. 180 | ref_dist: np.ndarray of shape (batch_size, x, y, z). Reference position distribution. 181 | ref_graphs: list of MoleculeGraph. Unchanged from input.
182 | ''' 183 | X, node_inputs, edges, node_rem, edge_rem, terminate, ref_dist, ref_atoms, Ns, xyz, ref_graphs = batch 184 | 185 | n_classes = pred[0].size(1) - 3 186 | 187 | X = X.squeeze().numpy() 188 | node_inputs = [n.numpy() for n in torch.split(node_inputs, split_size_or_sections=Ns)] 189 | edges = edges.numpy() 190 | 191 | node_rem = [n.numpy() for n in node_rem] 192 | edge_rem = [e.numpy() for e in edge_rem] 193 | terminate = terminate.numpy() 194 | ref_dist = ref_dist.numpy() 195 | ref_atoms = ref_atoms.numpy() 196 | 197 | pred_nodes, pred_edges, pred_dist = pred 198 | pred_class_weights = torch.nn.functional.softmax(pred_nodes[:,3:], dim=1).numpy() 199 | pred_xyz = pred_nodes[:,:3].numpy() 200 | pred_edges = [list(e.numpy()) for e in pred_edges] 201 | pred_dist = pred_dist.numpy() 202 | 203 | mols = [] 204 | ref = [] 205 | pred = [] 206 | count = 0 207 | prev_ind = 0 208 | 209 | for i, N in enumerate(Ns): 210 | 211 | atoms = [Atom(a[:3], class_weights=a[3:]) for a in node_inputs[i]] 212 | bonds = [] 213 | if edges.size > 0: 214 | ind = np.searchsorted(edges[0], count+N) 215 | for j in range(prev_ind, ind): 216 | bonds.append((edges[0,j]-count, edges[1,j]-count)) 217 | prev_ind = ind 218 | count += N 219 | mols.append(MoleculeGraph(atoms, bonds)) 220 | 221 | pred_atom = Atom(pred_xyz[i], class_weights=pred_class_weights[i]) 222 | pred.append((pred_atom, pred_edges[i], pred_dist[i])) 223 | 224 | ref_pairs = [] 225 | for a, b in zip(node_rem[i], edge_rem[i]): 226 | ref_atom = Atom(a[:3], class_weights=np.eye(n_classes)[int(a[3])]) 227 | ref_bonds = [int(v) for v in b] 228 | ref_pairs.append((ref_atom, ref_bonds)) 229 | ref.append((ref_pairs, int(terminate[i]), ref_dist[i])) 230 | 231 | return X, mols, pred, ref, xyz, ref_dist, pred_dist, ref_graphs, ref_atoms 232 | 233 | def make_hdf5_dataloader(datadir, preproc_fn, collate_fn=collate, mode='train', batch_size=30, 234 | shuffle=True, num_workers=6, world_size=1, rank=0): 235 | ''' 236 | Produce a dataset and 
dataloader from data directory. 237 | 238 | Arguments: 239 | datadir: str. Path to HDF5 database file. 240 | preproc_fn: Python function. Preprocessing function to apply to each batch. 241 | collate_fn: Python function. Collate function that returns each batch. 242 | mode: 'train', 'val', or 'test'. Which dataset to use. 243 | batch_size: int. Number of samples in each batch. 244 | shuffle: bool. Whether to shuffle the sample order on each epoch. 245 | num_workers: int. Number of parallel processes for data loading. 246 | world_size: int. Number of parallel processes if using distributed training. 247 | rank: int. Index of current process if using distributed training. 248 | 249 | Returns: tuple (dataset, dataloader, sampler) 250 | dataset: HDF5Dataset. 251 | dataloader: DataLoader. 252 | sampler: DistributedSampler or None if world size == 1. 253 | ''' 254 | dataset = HDF5Dataset(datadir, mode) 255 | if world_size > 1: 256 | sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=(mode == 'train')) 257 | shuffle = False 258 | else: 259 | sampler = None 260 | dataloader = DataLoader( 261 | dataset, 262 | batch_size=batch_size, 263 | shuffle=shuffle, 264 | collate_fn=_collate_wrapper(collate_fn, preproc_fn), 265 | sampler=sampler, 266 | num_workers=num_workers, 267 | worker_init_fn=_worker_init_fn, 268 | timeout=300, 269 | pin_memory=True 270 | ) 271 | return dataset, dataloader, sampler 272 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def _get_padding(kernel_size, nd): 7 | if isinstance(kernel_size, int): 8 | kernel_size = (kernel_size, )*nd 9 | padding = [] 10 | for i in range(nd): 11 | padding += [(kernel_size[i]-1) // 2] 12 | return tuple(padding) 13 | 14 | class _ConvNdBlock(nn.Module): 15 | 16 | def
__init__(self, 17 | in_channels, 18 | out_channels, 19 | nd, 20 | kernel_size=3, 21 | depth=2, 22 | padding_mode='zeros', 23 | res_connection=True, 24 | activation=None, 25 | last_activation=True 26 | ): 27 | 28 | assert depth >= 1 29 | 30 | if nd == 2: 31 | conv = nn.Conv2d 32 | elif nd == 3: 33 | conv = nn.Conv3d 34 | else: 35 | raise ValueError(f'Invalid convolution dimensionality {nd}.') 36 | 37 | super().__init__() 38 | 39 | self.res_connection = res_connection 40 | if not activation: 41 | self.act = nn.ReLU() 42 | else: 43 | self.act = activation 44 | 45 | if last_activation: 46 | self.acts = [self.act] * depth 47 | else: 48 | self.acts = [self.act] * (depth-1) + [self._identity] 49 | 50 | padding = _get_padding(kernel_size, nd) 51 | self.convs = nn.ModuleList([conv(in_channels, out_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)]) 52 | for i in range(depth-1): 53 | self.convs.append(conv(out_channels, out_channels, kernel_size=kernel_size, padding=padding, padding_mode=padding_mode)) 54 | if res_connection and in_channels != out_channels: 55 | self.res_conv = conv(in_channels, out_channels, kernel_size=1) 56 | else: 57 | self.res_conv = None 58 | 59 | def _identity(self, x): 60 | return x 61 | 62 | def forward(self, x_in): 63 | x = x_in 64 | for conv, act in zip(self.convs, self.acts): 65 | x = act(conv(x)) 66 | if self.res_connection: 67 | if self.res_conv: 68 | x = x + self.res_conv(x_in) 69 | else: 70 | x = x + x_in 71 | return x 72 | 73 | class Conv3dBlock(_ConvNdBlock): 74 | ''' 75 | Pytorch 3D convolution block module. 76 | 77 | Arguments: 78 | in_channels: int. Number of channels entering the first convolution layer. 79 | out_channels: int. Number of output channels in each layer of the block. 80 | kernel_size: int or tuple. Size of convolution kernel. 81 | depth: int >= 1. Number of convolution layers in the block. 82 | padding_mode: str. Type of padding in each convolution layer. 
'zeros', 'reflect', 'replicate' or 'circular'. 83 | res_connection: Boolean. Whether to use residual connection over the block (f(x) = h(x) + x). 84 | If in_channels != out_channels, a 1x1x1 convolution is applied to the res connection 85 | to make the channel numbers match. 86 | activation: torch.nn.Module. Activation function to use after every layer in block. If None, 87 | defaults to ReLU. 88 | last_activation: Bool. Whether to apply the activation after the last conv layer (before res connection). 89 | ''' 90 | def __init__(self, 91 | in_channels, 92 | out_channels, 93 | kernel_size=3, 94 | depth=2, 95 | padding_mode='zeros', 96 | res_connection=True, 97 | activation=None, 98 | last_activation=True 99 | ): 100 | super().__init__(in_channels, out_channels, 3, kernel_size, depth, padding_mode, res_connection, activation, last_activation) 101 | 102 | class Conv2dBlock(_ConvNdBlock): 103 | ''' 104 | Pytorch 2D convolution block module. 105 | 106 | Arguments: 107 | in_channels: int. Number of channels entering the first convolution layer. 108 | out_channels: int. Number of output channels in each layer of the block. 109 | kernel_size: int or tuple. Size of convolution kernel. 110 | depth: int >= 1. Number of convolution layers in the block. 111 | padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. 112 | res_connection: Boolean. Whether to use residual connection over the block (f(x) = h(x) + x). 113 | If in_channels != out_channels, a 1x1 convolution is applied to the res connection 114 | to make the channel numbers match. 115 | activation: torch.nn.Module. Activation function to use after every layer in block. If None, 116 | defaults to ReLU. 117 | last_activation: Bool. Whether to apply the activation after the last conv layer (before res connection). 
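The residual pattern described above, output f(x) = h(x) + x with a 1x1 projection when channel counts differ, can be illustrated with a tiny NumPy sketch; the two elementwise affine maps below are hypothetical stand-ins for the conv layers, not the module itself:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def res_block(x, w1, w2):
    # Two "layers" stand in for the convolutions; the block output is
    # h(x) + x, the residual pattern used by _ConvNdBlock above.
    h = relu(relu(x * w1) * w2)
    return h + x

x = np.array([1.0, -2.0, 3.0])
y = res_block(x, 0.5, 2.0)  # h(x) = [1, 0, 3], so y = [2, -2, 6]
```

The skip connection lets the block learn a correction on top of the identity, which eases optimization of deeper stacks.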
118 | ''' 119 | def __init__(self, 120 | in_channels, 121 | out_channels, 122 | kernel_size=3, 123 | depth=2, 124 | padding_mode='zeros', 125 | res_connection=True, 126 | activation=None, 127 | last_activation=True 128 | ): 129 | super().__init__(in_channels, out_channels, 2, kernel_size, depth, padding_mode, res_connection, activation, last_activation) 130 | 131 | class _AttentionConvQueryNd(nn.Module): 132 | 133 | def __init__(self, in_channels, global_channels, query_features, kernel_size, 134 | padding_mode, activation, upsample_mode, nd): 135 | super().__init__() 136 | self.nd = nd 137 | if nd == 2: 138 | conv = nn.Conv2d 139 | self.sum_dims = [2,3] 140 | elif nd == 3: 141 | conv = nn.Conv3d 142 | self.sum_dims = [2,3,4] 143 | else: 144 | raise ValueError(f'Invalid convolution dimensionality {nd}.') 145 | padding = _get_padding(kernel_size, nd) 146 | self.a_conv = conv(in_channels, query_features, kernel_size=kernel_size, 147 | padding=padding, padding_mode=padding_mode) 148 | self.g_conv = conv(global_channels, in_channels, kernel_size=kernel_size, 149 | padding=padding, padding_mode=padding_mode) 150 | self.softmax = nn.Softmax(dim=1) 151 | self.sigmoid = nn.Sigmoid() 152 | self.activation = activation 153 | self.upsample_mode = upsample_mode 154 | 155 | def _unsqueeze_query(self, q): 156 | if self.nd == 2: 157 | q = q[:,:,None,None] 158 | else: 159 | q = q[:,:,None,None,None] 160 | return q 161 | 162 | def forward(self, x, q, g=None): 163 | 164 | if g is not None: 165 | # Add global features to x 166 | g = F.interpolate(g, size=x.size()[2:], mode=self.upsample_mode, align_corners=False) 167 | g = F.relu(self.g_conv(g)) 168 | a = x+g 169 | else: 170 | a = x 171 | 172 | # Get attention map from input and query vector 173 | a = self.a_conv(a) # (batch, query_features, width, height, (depth)) 174 | q = self.softmax(q) # (batch, query_features) 175 | q = self._unsqueeze_query(q) # (batch, query_features, 1, 1, (1)) 176 | a = torch.sum(a*q, dim=1, keepdim=True) # 
(batch, 1, width, height, (depth)) 177 | 178 | # Apply activation to normalize attention map 179 | if self.activation == 'softmax': 180 | size = a.size() 181 | a = self.softmax(a.reshape(a.size(0), -1)).reshape(size) 182 | elif self.activation == 'sigmoid': 183 | a = self.sigmoid(a) 184 | else: 185 | raise ValueError(f'Unrecognized attention map activation {self.activation}.') 186 | 187 | # Get attention-gated features from x 188 | x = torch.sum(a*x, dim=self.sum_dims) 189 | 190 | return x, a.squeeze(dim=1) 191 | 192 | class AttentionConvQuery2d(_AttentionConvQueryNd): 193 | ''' 194 | Pytorch attention layer for 2D convolution with a query 1D vector. 195 | 196 | Arguments: 197 | in_channels: int. Number of channels in the attended feature map. 198 | global_channels: int. Number of channels in the global feature map. 199 | query_features: int. Number of features in query vector. 200 | kernel_size: int. Size of convolution kernel. 201 | padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. 202 | activation: str. Type of activation to use for attention map. 'sigmoid' or 'softmax'. 203 | upsample_mode: str. Algorithm for upsampling global feature map to the attended 204 | feature map size. See torch.nn.functional.interpolate. 205 | ''' 206 | def __init__(self, in_channels, global_channels, query_features, kernel_size, 207 | padding_mode='zeros', activation='softmax', upsample_mode='bilinear'): 208 | super().__init__(in_channels, global_channels, query_features, kernel_size, 209 | padding_mode, activation, upsample_mode, 2) 210 | 211 | class AttentionConvQuery3d(_AttentionConvQueryNd): 212 | ''' 213 | Pytorch attention layer for 3D convolution with a query 1D vector. 214 | 215 | Arguments: 216 | in_channels: int. Number of channels in the attended feature map. 217 | global_channels: int. Number of channels in the global feature map. 218 | query_features: int. Number of features in query vector. 
219 | kernel_size: int. Size of convolution kernel. 220 | padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. 221 | activation: str. Type of activation to use for attention map. 'sigmoid' or 'softmax'. 222 | upsample_mode: str. Algorithm for upsampling global feature map to the attended 223 | feature map size. See torch.nn.functional.interpolate. 224 | ''' 225 | def __init__(self, in_channels, global_channels, query_features, kernel_size, 226 | padding_mode='zeros', activation='softmax', upsample_mode='trilinear'): 227 | super().__init__(in_channels, global_channels, query_features, kernel_size, 228 | padding_mode, activation, upsample_mode, 3) 229 | 230 | class UNetAttentionConv(nn.Module): 231 | ''' 232 | Pytorch attention layer for U-net model upsampling stage. 233 | 234 | Arguments: 235 | in_channels: int. Number of channels in the attended feature map. 236 | query_channels: int. Number of channels in query feature map. 237 | attention_channels: int. Number of channels in hidden convolution layer before computing attention. 238 | kernel_size: int. Size of convolution kernel. 239 | padding_mode: str. Type of padding in each convolution layer. 'zeros', 'reflect', 'replicate' or 'circular'. 240 | conv_activation: nn.Module. Activation function to use after convolution layers 241 | attention_activation: str. Type of activation to use for attention map. 'sigmoid' or 'softmax'. 242 | upsample_mode: str. Algorithm for upsampling query feature map to the attended 243 | feature map size. See torch.nn.functional.interpolate. 244 | ndim: 2 or 3. Dimensionality of convolution. 
245 | References: 246 | https://arxiv.org/abs/1804.03999 247 | ''' 248 | def __init__(self, 249 | in_channels, 250 | query_channels, 251 | attention_channels, 252 | kernel_size, 253 | padding_mode='zeros', 254 | conv_activation=nn.ReLU(), 255 | attention_activation='softmax', 256 | upsample_mode='bilinear', 257 | ndim=2 258 | ): 259 | super().__init__() 260 | 261 | self.ndim = ndim 262 | if ndim == 2: 263 | conv = nn.Conv2d 264 | elif ndim == 3: 265 | conv = nn.Conv3d 266 | else: 267 | raise ValueError(f'Invalid convolution dimensionality {ndim}.') 268 | 269 | if attention_activation == 'softmax': 270 | self.attention_activation = self._softmax 271 | elif attention_activation == 'sigmoid': 272 | self.attention_activation = self._sigmoid 273 | else: 274 | raise ValueError(f'Unrecognized attention map activation {attention_activation}.') 275 | 276 | padding = _get_padding(kernel_size, ndim) 277 | self.x_conv = conv(in_channels, attention_channels, kernel_size=kernel_size, 278 | padding=padding, padding_mode=padding_mode) 279 | self.q_conv = conv(query_channels, attention_channels, kernel_size=kernel_size, 280 | padding=padding, padding_mode=padding_mode) 281 | self.a_conv = conv(attention_channels, 1, kernel_size=kernel_size, 282 | padding=padding, padding_mode=padding_mode) 283 | self.softmax = nn.Softmax(dim=1) 284 | self.sigmoid = nn.Sigmoid() 285 | self.upsample_mode = upsample_mode 286 | self.conv_activation = conv_activation 287 | 288 | def _softmax(self, a): 289 | shape = a.shape 290 | return self.softmax(a.reshape(shape[0], -1)).reshape(shape) 291 | 292 | def _sigmoid(self, a): 293 | return self.sigmoid(a) 294 | 295 | def forward(self, x, q): 296 | 297 | # Upsample query q to the size of input x and convolve 298 | q = F.interpolate(q, size=x.size()[2:], mode=self.upsample_mode, align_corners=False) 299 | q = self.conv_activation(self.q_conv(q)) 300 | 301 | # Convolve input x and sum with q 302 | a = self.conv_activation(self.x_conv(x)) 303 | a = 
self.conv_activation(a + q) 304 | 305 | # Get attention map and mix it with x 306 | a = self.attention_activation(self.a_conv(a)) 307 | x = a * x 308 | 309 | return x, a.squeeze(dim=1) 310 | -------------------------------------------------------------------------------- /src/preprocessing.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import numpy as np 4 | import scipy.ndimage as nimg 5 | from PIL import Image 6 | 7 | def top_atom_to_zero(xyzs): 8 | ''' 9 | Set the z coordinate of the highest atom in each molecule to 0. 10 | Arguments: 11 | xyzs: list of np.ndarray of shape (num_atoms, :). First three elements in axis 1 are xyz. 12 | Returns: new list of np.ndarrays of same shape as xyzs. 13 | ''' 14 | new_xyzs = [] 15 | for xyz in xyzs: 16 | xyz[:,2] -= xyz[:,2].max() 17 | new_xyzs.append(xyz) 18 | return new_xyzs 19 | 20 | def add_noise(Xs, c=0.1, randomize_amplitude=False, normal_amplitude=False): 21 | ''' 22 | Add uniform random noise to arrays. In-place operation. 23 | Arguments: 24 | Xs: list of np.ndarray of shape (batch_size, ...). 25 | c: float. Amplitude of noise. Is multiplied by (max-min) of sample. 26 | randomize_amplitude: Boolean. If True, noise amplitude is uniform random in [0,c] 27 | for each sample in the batch. 28 | normal_amplitude: Boolean. If True and randomize_amplitude=True, then the noise amplitude 29 | is distributed like the absolute value of a normally distributed variable 30 | with zero mean and standard deviation equal to c. 
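In the fixed-amplitude case, `add_noise` adds uniform noise in [-0.5, 0.5] scaled by `c` times each sample's value range. A compact sketch of that branch, with a hypothetical toy input:

```python
import numpy as np

rng = np.random.default_rng(0)
X = np.linspace(0.0, 1.0, 16).reshape(1, 4, 4)  # one sample, value range 1.0
c = 0.1

# Uniform noise in [-0.5, 0.5], scaled by c times the sample's (max - min),
# as in the fixed-amplitude branch of add_noise.
noise = (rng.random(X.shape) - 0.5) * c * (X[0].max() - X[0].min())
X_noisy = X + noise
# The perturbation is bounded by c/2 times the value range, here 0.05.
```

Scaling by the per-sample range makes the augmentation strength independent of the absolute signal level of each AFM image.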
31 | ''' 32 | for X in Xs: 33 | sh = X.shape 34 | R = np.random.rand(*sh) - 0.5 35 | if randomize_amplitude: 36 | if normal_amplitude: 37 | amp = np.abs(np.random.normal(0, c, sh[0])) 38 | else: 39 | amp = np.random.uniform(0.0, 1.0, sh[0]) * c 40 | else: 41 | amp = [c] * sh[0] 42 | for j in range(sh[0]): 43 | X[j] += R[j] * amp[j]*(X[j].max()-X[j].min()) 44 | 45 | def add_norm(Xs, per_layer=True): 46 | ''' 47 | Normalize arrays by subtracting the mean and dividing by the standard deviation. In-place operation. 48 | Arguments: 49 | Xs: list of np.ndarray of shape (batch_size, ...). 50 | per_layer: Boolean. If True, normalize separately for each element in the last axis of Xs. 51 | ''' 52 | for X in Xs: 53 | sh = X.shape 54 | for j in range(sh[0]): 55 | if per_layer: 56 | for i in range(sh[-1]): 57 | X[j,...,i] = (X[j,...,i] - np.mean(X[j,...,i])) / np.std(X[j,...,i]) 58 | else: 59 | X[j] = (X[j] - np.mean(X[j])) / np.std(X[j]) 60 | 61 | def rand_shift_xy_trend(Xs, shift_step_max=0.02, max_shift_total=0.1): 62 | ''' 63 | Randomly shift z layers in x and y. Each shift is relative to the previous one. In-place operation. 64 | Arguments: 65 | Xs: list of np.ndarray of shape (batch_size, x_dim, y_dim, z_dim). 66 | shift_step_max: float in [0,1]. Maximum fraction of image size by which to shift for each step. 67 | max_shift_total: float in [0,1]. Maximum fraction of image size by which to shift in total.
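The layer-shift augmentation amounts to a clipped random walk over the z slices. A simplified sketch of the idea (accumulating forward from the first slice, whereas the repository version accumulates from the farthest slice; illustration only):

```python
import numpy as np

rng = np.random.default_rng(1)
n_slices, step_max, total_max = 6, 2, 3  # shifts in pixels

# Each slice shifts relative to the previous one; the accumulated drift
# is then clipped to the maximum total shift, as in rand_shift_xy_trend.
steps = rng.integers(-step_max, step_max + 1, size=(n_slices, 2))
steps[0] = 0  # reference slice does not move
shifts = np.clip(np.cumsum(steps, axis=0), -total_max, total_max)
```

Correlated drifts like this mimic slow tip or sample movement between the slices of a 3D AFM scan.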
68 | ''' 69 | for X in Xs: 70 | sh = X.shape 71 | # calculate max possible shift in pixels between neighboring slices 72 | max_slice_shift_pix = np.floor(np.maximum(sh[1],sh[2])*shift_step_max).astype(int) 73 | # calculate max total shift in pixels 74 | max_trend_pix = np.floor(np.maximum(sh[1],sh[2])*max_shift_total).astype(int) 75 | for j in range(sh[0]): 76 | rand_shift = np.zeros((sh[3],2)) 77 | # calculate random shift values for slices in reverse order 78 | # (zero for the closest slice, largest values for the farthest slices) 79 | for i in range(rand_shift.shape[0]-1, 0, -1): 80 | shift_values = [random.choice(np.arange(-max_slice_shift_pix,max_slice_shift_pix+1)), random.choice(np.arange(-max_slice_shift_pix,max_slice_shift_pix+1))] 81 | #print('shift_values = ', shift_values) 82 | for slice_ind in range(i): rand_shift[slice_ind,:] = rand_shift[slice_ind,:] + shift_values 83 | # clip shift values bigger than max_trend_pix 84 | rand_shift = np.clip(rand_shift, -max_trend_pix, max_trend_pix).astype(int) 85 | for i in range(sh[3]): 86 | shift_y = rand_shift[i,1] 87 | shift_x = rand_shift[i,0] 88 | X[j,:,:,i] = nimg.shift(X[j,:,:,i], (shift_y, shift_x), mode='mirror') 89 | 90 | def add_cutout(Xs, n_holes=5): 91 | ''' 92 | Randomly add cutouts (square patches of zeros) to images. In-place operation. 93 | Arguments: 94 | Xs: list of np.ndarray of shape (batch_size, x_dim, y_dim, z_dim). 95 | n_holes: int. Maximum number of cutouts to add. 
96 | ''' 97 | 98 | def get_random_eraser(input_img,p=0.2, s_l=0.001, s_h=0.01, r_1=0.1, r_2=1./0.1, v_l=0, v_h=0): 99 | ''' 100 | p : the probability that random erasing is performed 101 | s_l, s_h : minimum / maximum proportion of erased area against input image 102 | r_1, r_2 : minimum / maximum aspect ratio of erased area 103 | v_l, v_h : minimum / maximum value for erased area 104 | ''' 105 | 106 | sh = input_img.shape 107 | img_h, img_w = [sh[0], sh[1]] 108 | 109 | if np.random.uniform(0, 1) > p: 110 | return input_img 111 | 112 | while True: 113 | 114 | s = np.exp(np.random.uniform(np.log(s_l), np.log(s_h))) * img_h * img_w 115 | r = np.exp(np.random.uniform(np.log(r_1), np.log(r_2))) 116 | 117 | w = int(np.sqrt(s / r)) 118 | h = int(np.sqrt(s * r)) 119 | left = np.random.randint(0, img_w) 120 | top = np.random.randint(0, img_h) 121 | 122 | if left + w <= img_w and top + h <= img_h: 123 | break 124 | 125 | input_img[top:top + h, left:left + w] = 0.0 126 | 127 | return input_img 128 | 129 | for X in Xs: 130 | sh = X.shape 131 | for j in range(sh[0]): 132 | for i in range(sh[3]): 133 | for attempt in range(n_holes): 134 | X[j,:,:,i] = get_random_eraser(X[j,:,:,i]) 135 | -------------------------------------------------------------------------------- /src/visualization.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.mplot3d import Axes3D 6 | from matplotlib import cm, gridspec 7 | 8 | from utils import _calc_plot_dim, elements 9 | 10 | CLASS_COLORS = 'rkbgcmy' 11 | 12 | def _get_mol_bounding_box(mol, pred, ref, margin=0.5): 13 | mol = mol[:,:3] 14 | mol = np.append(mol, pred[0][None,:3], axis=0) 15 | mol = np.append(mol, ref[0][None,:3], axis=0) 16 | lims_min = np.min(mol, axis=0) 17 | lims_max = np.max(mol, axis=0) 18 | max_range = (lims_max-lims_min).max() 19 | center = (lims_min + lims_max) / 2 20 | radius = max_range / 2 21 
| lims = np.stack([center - radius - margin, center + radius + margin]).T 22 | return lims 23 | 24 | def _get_mol_bounding_box_multiple(mol_xyz, pred_xyz, ref_xyz, margin=0.5): 25 | xyz = np.concatenate([mol_xyz, pred_xyz, ref_xyz], axis=0) 26 | lims_min = np.min(xyz, axis=0) 27 | lims_max = np.max(xyz, axis=0) 28 | max_range = (lims_max-lims_min).max() 29 | center = (lims_min + lims_max) / 2 30 | radius = max_range / 2 31 | lims = np.stack([center - radius - margin, center + radius + margin]).T 32 | return lims 33 | 34 | def plot_input(X, constant_range=False, cmap='afmhot'): 35 | ''' 36 | Plot single stack of AFM images. 37 | Arguments: 38 | X: np.ndarray of shape (x, y, z). AFM image to plot. 39 | constant_range: Boolean. Whether the different slices should use the same value range or not. 40 | cmap: str or matplotlib colormap. Colormap to use for plotting. 41 | Returns: matplotlib.pyplot.figure. Figure on which the image was plotted. 42 | ''' 43 | rows, cols = _calc_plot_dim(X.shape[-1]) 44 | fig = plt.figure(figsize=(3.2*cols,2.5*rows)) 45 | vmax = X.max() 46 | vmin = X.min() 47 | for k in range(X.shape[-1]): 48 | fig.add_subplot(rows,cols,k+1) 49 | if constant_range: 50 | plt.imshow(X[:,:,k].T, cmap=cmap, vmin=vmin, vmax=vmax, origin="lower") 51 | else: 52 | plt.imshow(X[:,:,k].T, cmap=cmap, origin="lower") 53 | plt.colorbar() 54 | plt.tight_layout() 55 | return fig 56 | 57 | def make_input_plots(Xs, outdir='./predictions/', start_ind=0, constant_range=False, cmap='afmhot', verbose=1): 58 | ''' 59 | Plot multiple AFM images to files 0_input.png, 1_input.png, ... etc. 60 | Arguments: 61 | Xs: list of np.ndarray of shape (batch, x, y, z). Input AFM images to plot. 62 | outdir: str. Directory where images are saved. 63 | start_ind: int. Save index increments by one for each image. The first index is start_ind. 64 | constant_range: Boolean. Whether the different slices should use the same value range or not. 65 | cmap: str or matplotlib colormap. 
Colormap to use for plotting. 66 | verbose: int 0 or 1. Whether to print output information. 67 | ''' 68 | 69 | if not os.path.exists(outdir): 70 | os.makedirs(outdir) 71 | 72 | img_ind = start_ind 73 | for i in range(Xs[0].shape[0]): 74 | 75 | for j in range(len(Xs)): 76 | 77 | plot_input(Xs[j][i], constant_range, cmap=cmap) 78 | 79 | save_name = f'{img_ind}_input' 80 | if len(Xs) > 1: 81 | save_name += str(j+1) 82 | save_name = os.path.join(outdir, save_name) 83 | save_name += '.png' 84 | plt.savefig(save_name) 85 | plt.close() 86 | 87 | if verbose > 0: print(f'Input image saved to {save_name}') 88 | 89 | img_ind += 1 90 | 91 | def plot_confusion_matrix(ax, conf_mat, tick_labels=None): 92 | ''' 93 | Plot confusion matrix on matplotlib axes. 94 | Arguments: 95 | ax: matplotlib.axes.Axes. Axes object on which the confusion matrix is plotted. 96 | conf_mat: np.ndarray of shape (num_classes, num_classes). Confusion matrix counts. 97 | tick_labels: list of str. Labels for classes. 98 | ''' 99 | if tick_labels: 100 | assert len(conf_mat) == len(tick_labels) 101 | else: 102 | tick_labels = [str(i) for i in range(len(conf_mat))] 103 | 104 | conf_mat_norm = np.zeros_like(conf_mat, dtype=np.float64) 105 | for i, r in enumerate(conf_mat): 106 | conf_mat_norm[i] = r / np.sum(r) 107 | 108 | im = ax.imshow(conf_mat_norm, cmap=cm.Blues) 109 | plt.colorbar(im) 110 | ax.set_xticks(np.arange(conf_mat.shape[0])) 111 | ax.set_yticks(np.arange(conf_mat.shape[1])) 112 | ax.set_xlabel('Predicted class') 113 | ax.set_ylabel('True class') 114 | ax.set_xticklabels(tick_labels) 115 | ax.set_yticklabels(tick_labels, rotation='vertical', va='center') 116 | for i in range(conf_mat.shape[0]): 117 | for j in range(conf_mat.shape[1]): 118 | color = 'white' if conf_mat_norm[i,j] > 0.5 else 'black' 119 | label = '{:.3f}'.format(conf_mat_norm[i,j])+'\n('+'{:d}'.format(conf_mat[i,j])+')' 120 | ax.text(j, i, label, ha='center', va='center', color=color) 121 | 122 | 123 | def 
plot_graph_prediction_grid(pred, ref, mol, box_borders): 124 | ''' 125 | Plot 3D view and 2D top view of graph node prediction and reference. 126 | 127 | Arguments: 128 | pred: tuple (atom, bond, pred_dist), where atom is an Atom object, bond is a list, 129 | and pred_dist is np.ndarray of shape (x, y, z). 130 | ref: tuple (atom, bond, terminate, ref_dist), where atom is an Atom object, bond is a list, 131 | terminate is an int, and ref_dist is np.ndarray of shape (x, y, z). ref_dist is optional. 132 | mol: MoleculeGraph. Input molecule graph. 133 | box_borders: tuple ((x_start, y_start, z_start),(x_end, y_end, z_end)). Position of plotting region. 134 | 135 | Returns: matplotlib.pyplot.Figure. 136 | ''' 137 | 138 | fig = plt.figure(figsize=(10,10)) 139 | 140 | ref_dist = ref[3].mean(axis=-1) 141 | pred_dist = pred[2].mean(axis=-1) 142 | vmin = min(ref_dist.min(), pred_dist.min()) 143 | vmax = max(ref_dist.max(), pred_dist.max()) 144 | mol_pos = mol.array(xyz=True) if len(mol) > 0 else np.empty((0,3)) 145 | 146 | # Suppress atoms from a terminated graph 147 | ref, pred = (ref, pred) if ref[2] == 0 else (None, None) 148 | 149 | def plot(i, atom, bonds, label, pd): 150 | 151 | ax_grid = fig.add_subplot(221+i) 152 | ax_graph = fig.add_subplot(223+i) 153 | 154 | atom_pos = atom.array(xyz=True) if atom else [] 155 | 156 | z_min, z_max = box_borders[0][2], box_borders[1][2] 157 | s = 80*(mol_pos[:,2] - z_min) / (z_max - z_min) 158 | if (s < 0).any(): 159 | raise ValueError('Encountered atom z position(s) below box borders.') 160 | 161 | extent = [box_borders[0][0], box_borders[1][0], box_borders[0][1], box_borders[1][1]] 162 | ax_grid.imshow(pd.T, origin='lower', extent=extent, vmin=vmin, vmax=vmax) 163 | 164 | ax_graph.scatter(mol_pos[:,0], mol_pos[:,1], c='gray', s=s) 165 | for b in mol.bonds: 166 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 167 | ax_graph.plot(pos[:,0], pos[:,1], 'gray') 168 | if len(atom_pos) > 0: 169 | s = 80*(atom_pos[2] - z_min) / (z_max - 
z_min) 170 | ax_graph.scatter(atom_pos[0], atom_pos[1], c=CLASS_COLORS[atom.class_index], s=s) 171 | for b in bonds: 172 | pos = np.vstack([mol_pos[b,:3], atom_pos]) 173 | ax_graph.plot(pos[:,0], pos[:,1], c='r') 174 | 175 | ax_graph.set_xlim(box_borders[0][0], box_borders[1][0]) 176 | ax_graph.set_ylim(box_borders[0][1], box_borders[1][1]) 177 | ax_graph.set_title(f'{label}') 178 | 179 | return ax_grid, ax_graph 180 | 181 | # Prediction 182 | pred_atom, pred_bonds = (pred[0], pred[1]) if pred else (None, []) 183 | plot(0, pred_atom, pred_bonds, 'Prediction', pred_dist) 184 | 185 | # Reference 186 | ref_atom, ref_bonds = (ref[0], ref[1]) if ref else (None, []) 187 | ax_grid, _ = plot(1, ref_atom, ref_bonds, 'Reference', ref_dist) 188 | 189 | # Colorbar 190 | plt.tight_layout(rect=[0, 0, 0.9, 1]) 191 | pos = ax_grid.get_position() 192 | cax = fig.add_axes(rect=[0.9, pos.ymin, 0.03, pos.ymax - pos.ymin]) 193 | m = cm.ScalarMappable() 194 | m.set_array([vmin, vmax]) 195 | plt.colorbar(m, cax=cax) 196 | 197 | return fig 198 | 199 | def visualize_graph_grid(molecules, pred, ref, bond_threshold=0.5, box_borders=((2,2,-1.5),(18,18,0)), 200 | outdir='./graphs/', start_ind=0, show=False, verbose=1): 201 | ''' 202 | Plot grid model single-step predictions. 203 | 204 | Arguments: 205 | molecules: list of MoleculeGraph. Input molecule graphs. 206 | pred: list of tuples (atom, bond, pred_dist), where atom is an Atom object, bond is a list, 207 | and pred_dist is np.ndarray of shape (x, y, z). 208 | ref: list of tuples (atom, bond, terminate, ref_dist), where atom is an Atom object, bond is a list, 209 | terminate is an int, and ref_dist is np.ndarray of shape (x, y, z). 210 | bond_threshold: float in [0,1]. Predicted bonds with confidence level above bond_threshold are plotted. 211 | box_borders: tuple ((x_start, y_start, z_start),(x_end, y_end, z_end)). Position of plotting region. 212 | outdir: str. Directory where images are saved. 213 | start_ind: int. 
Save index increments by one for each graph. The first index is start_ind. 214 | show: Boolean. Whether to show an interactive window for each graph. 215 | verbose: int 0 or 1. Whether to print output information. 216 | ''' 217 | 218 | if outdir and not os.path.exists(outdir): 219 | os.makedirs(outdir) 220 | 221 | # Convert bond indicator lists into index lists 222 | pred = [(a, np.where(b >= np.array(bond_threshold))[0], p) for a, b, p in pred] 223 | ref = [(a, np.where(b)[0], t, d) for a, b, t, d in ref] 224 | 225 | counter = start_ind 226 | for mol, p, r in zip(molecules, pred, ref): 227 | 228 | plot_graph_prediction_grid(p, r, mol, box_borders) 229 | 230 | if outdir: 231 | plt.savefig(save_path:=os.path.join(outdir, f'{counter}_pred_graph.png')) 232 | if verbose > 0: print(f'Graph image saved to {save_path}.') 233 | if show: 234 | plt.show() 235 | plt.close() 236 | 237 | counter += 1 238 | 239 | def plot_graph_sequence_grid(pred_graphs, ref_graphs, pred_sequence, box_borders=((2,2,-1.5),(18,18,0)), 240 | classes=None, class_colors=CLASS_COLORS, z_min=None, z_max=None, outdir='./graphs/', start_ind=0, verbose=1): 241 | 242 | if not os.path.exists(outdir): 243 | os.makedirs(outdir) 244 | 245 | if classes: 246 | assert len(class_colors) >= len(classes), 'Not enough colors for classes' 247 | 248 | if z_min is None: z_min = box_borders[0][2] 249 | if z_max is None: z_max = box_borders[1][2] 250 | scatter_size = 160 251 | 252 | def get_marker_size(z, max_size): 253 | return max_size * (z - z_min) / (z_max - z_min) 254 | 255 | def plot_xy(ax, atom_pos, mol, atom, bonds, scatter_size): 256 | 257 | if atom_pos is not None: 258 | s = get_marker_size(atom_pos[:,2], scatter_size) 259 | ax.scatter(atom_pos[:,0], atom_pos[:,1], c='lightgray', s=s) 260 | 261 | if len(mol) > 0: 262 | 263 | mol_pos = mol.array(xyz=True) 264 | 265 | s = get_marker_size(mol_pos[:,2], scatter_size) 266 | if (s < 0).any(): 267 | raise ValueError('Encountered atom z position(s) below box borders.') 
268 | 269 | c = [class_colors[atom.class_index] for atom in mol.atoms] 270 | 271 | ax.scatter(mol_pos[:,0], mol_pos[:,1], c=c, s=s, edgecolors='k', zorder=2) 272 | for b in mol.bonds: 273 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 274 | ax.plot(pos[:,0], pos[:,1], 'k', linewidth=2, zorder=1) 275 | 276 | if atom is not None: 277 | atom_xyz = atom.array(xyz=True) 278 | c = class_colors[atom.class_index] 279 | s = get_marker_size(atom_xyz[2], scatter_size) 280 | ax.scatter(atom_xyz[0], atom_xyz[1], c=c, s=s, edgecolors='k', zorder=2) 281 | bonds = [i for i in range(len(bonds)) if bonds[i] > 0.5] 282 | for b in bonds: 283 | pos = np.vstack([mol_pos[b], atom_xyz]) 284 | ax.plot(pos[:,0], pos[:,1], 'k', linewidth=2, zorder=1) 285 | 286 | ax.set_xlim(box_borders[0][0], box_borders[1][0]) 287 | ax.set_ylim(box_borders[0][1], box_borders[1][1]) 288 | ax.set_aspect('equal', 'box') 289 | 290 | def plot_xz(ax, mol, scatter_size): 291 | 292 | if len(mol) > 0: 293 | 294 | order = list(np.argsort(mol.array(xyz=True)[:, 1])[::-1]) 295 | mol = mol.permute(order) 296 | mol_pos = mol.array(xyz=True) 297 | 298 | s = get_marker_size(mol_pos[:,2], scatter_size) 299 | if (s < 0).any(): 300 | raise ValueError('Encountered atom z position(s) below box borders.') 301 | 302 | c = [class_colors[atom.class_index] for atom in mol.atoms] 303 | 304 | for b in mol.bonds: 305 | pos = np.vstack([mol_pos[b[0]], mol_pos[b[1]]]) 306 | ax.plot(pos[:,0], pos[:,2], 'k', linewidth=2, zorder=1) 307 | ax.scatter(mol_pos[:,0], mol_pos[:,2], c=c, s=s, edgecolors='k', zorder=2) 308 | 309 | ax.set_xlim(box_borders[0][0], box_borders[1][0]) 310 | ax.set_ylim(box_borders[0][2], box_borders[1][2]) 311 | ax.set_aspect('equal', 'box') 312 | 313 | ind = start_ind 314 | for p, r, s in zip(pred_graphs, ref_graphs, pred_sequence): 315 | 316 | # Setup plotting grid 317 | seq_len = len(s) + 1 318 | x_seq = min(6, seq_len) 319 | y_seq = int(seq_len / 6.1) + 1 320 | fig_seq = plt.figure(figsize=(2.5*x_seq, 2.8*y_seq)) 
321 | grid_seq = gridspec.GridSpec(y_seq, x_seq) 322 | 323 | # Plot prediction sequence 324 | atom_pos = p.array(xyz=True) if len(p) > 0 else None 325 | for i in range(len(s) + 1): 326 | if i == 0: 327 | mol, atom, bonds = [], None, None 328 | else: 329 | mol, atom, bonds = s[i-1] 330 | x_grid = i % x_seq 331 | y_grid = i // x_seq 332 | ax = fig_seq.add_subplot(grid_seq[y_grid, x_grid]) 333 | plot_xy(ax, atom_pos, mol, atom, bonds, 100) 334 | ax.set_title(f'{i}') 335 | ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False) 336 | fig_seq.tight_layout() 337 | 338 | plt.savefig(save_path:=os.path.join(outdir, f'{ind}_pred_sequence.png')) 339 | if verbose > 0: print(f'Graph prediction sequence image saved to {save_path}') 340 | plt.close() 341 | 342 | # Plot final graph 343 | if classes: 344 | x_extra = 0.35 * max([len(c) for c in classes]) 345 | fig_final = plt.figure(figsize=(10+x_extra, 6.5)) 346 | fig_grid = gridspec.GridSpec(1, 2, width_ratios=(10, x_extra), wspace=1/(10+x_extra)) 347 | else: 348 | fig_final = plt.figure(figsize=(10, 6.5)) 349 | fig_grid = gridspec.GridSpec(1, 1) 350 | grid_final = fig_grid[0, 0].subgridspec(2, 2, height_ratios=(5, 1.5), hspace=0.1, wspace=0.2) 351 | 352 | # Prediction 353 | ax_xy_pred = fig_final.add_subplot(grid_final[0, 0]) 354 | ax_xz_pred = fig_final.add_subplot(grid_final[1, 0]) 355 | plot_xy(ax_xy_pred, None, p, None, None, scatter_size) 356 | plot_xz(ax_xz_pred, p, scatter_size) 357 | ax_xy_pred.set_xlabel('x (Å)', fontsize=12) 358 | ax_xy_pred.set_ylabel('y (Å)', fontsize=12) 359 | ax_xz_pred.set_xlabel('x (Å)', fontsize=12) 360 | ax_xz_pred.set_ylabel('z (Å)', fontsize=12) 361 | ax_xy_pred.set_title('Prediction', fontsize=20) 362 | 363 | # Reference 364 | ax_xy_ref = fig_final.add_subplot(grid_final[0, 1]) 365 | ax_xz_ref = fig_final.add_subplot(grid_final[1, 1]) 366 | plot_xy(ax_xy_ref, None, r, None, None, scatter_size) 367 | plot_xz(ax_xz_ref, r, scatter_size) 368 | ax_xy_ref.set_xlabel('x (Å)', 
fontsize=12) 369 | ax_xy_ref.set_ylabel('y (Å)', fontsize=12) 370 | ax_xz_ref.set_xlabel('x (Å)', fontsize=12) 371 | ax_xz_ref.set_ylabel('z (Å)', fontsize=12) 372 | ax_xy_ref.set_title('Reference', fontsize=20) 373 | 374 | if classes: 375 | 376 | # Plot legend 377 | ax_legend = fig_final.add_subplot(fig_grid[0, 1]) 378 | 379 | # Class colors 380 | dy = 0.08 381 | dx = 0.35 / x_extra 382 | y_start = 0.5 + dy * (len(classes) + 3) / 2 383 | for i, c in enumerate(classes): 384 | ax_legend.scatter(dx, y_start-dy*i, s=scatter_size, c=class_colors[i], edgecolors='k') 385 | ax_legend.text(2*dx, y_start-dy*i, ', '.join([elements[e-1] for e in c]), fontsize=16, 386 | ha='left', va='center_baseline') 387 | 388 | # Marker sizes 389 | y_start2 = y_start - (len(classes) + 1) * dy 390 | marker_zs = np.array([z_max, (z_min + z_max + 0.2) / 2, z_min + 0.2]) 391 | ss = get_marker_size(marker_zs, scatter_size) 392 | for i, (s, z) in enumerate(zip(ss, marker_zs)): 393 | ax_legend.scatter(dx, y_start2-dy*i, s=s, c='w', edgecolors='k') 394 | ax_legend.text(2*dx, y_start2-dy*i, f'z = {z}Å', fontsize=16, 395 | ha='left', va='center_baseline') 396 | 397 | ax_legend.set_xlim(0, 1) 398 | ax_legend.set_ylim(0, 1) 399 | ax_legend.axis('off') 400 | 401 | plt.savefig(save_path:=os.path.join(outdir, f'{ind}_pred_final.png')) 402 | if verbose > 0: print(f'Final graph prediction image saved to {save_path}') 403 | plt.close() 404 | 405 | ind += 1 406 | 407 | def plot_distribution_grid(pred_dist, ref_dist, box_borders=((2,2,-1.5),(18,18,0)), 408 | outdir='./graphs/', start_ind=0, verbose=1): 409 | 410 | assert pred_dist.shape == ref_dist.shape 411 | 412 | if not os.path.exists(outdir): 413 | os.makedirs(outdir) 414 | 415 | fontsize = 24 416 | 417 | z_start = box_borders[0][2] 418 | z_res = (box_borders[1][2] - box_borders[0][2]) / pred_dist.shape[-1] 419 | extent = [box_borders[0][0], box_borders[1][0], box_borders[0][1], box_borders[1][1]] 420 | 421 | ind = start_ind 422 | for p, r in 
zip(pred_dist, ref_dist): 423 | 424 | # Plot grid in 2D 425 | p_mean, r_mean = p.mean(axis=-1), r.mean(axis=-1) 426 | vmin = min(r_mean.min(), p_mean.min()) 427 | vmax = max(r_mean.max(), p_mean.max()) 428 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6)) 429 | ax1.imshow(p_mean.T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 430 | ax2.imshow(r_mean.T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 431 | ax1.set_title('Prediction') 432 | ax2.set_title('Reference') 433 | 434 | # Colorbar 435 | plt.tight_layout(rect=[0, 0, 0.9, 1]) 436 | pos = ax2.get_position() 437 | cax = fig.add_axes(rect=[0.9, pos.ymin, 0.03, pos.ymax - pos.ymin]) 438 | m = cm.ScalarMappable() 439 | m.set_array([vmin, vmax]) 440 | plt.colorbar(m, cax=cax) 441 | 442 | plt.savefig(save_path:=os.path.join(outdir, f'{ind}_pred_dist2D.png')) 443 | if verbose > 0: print(f'Position distribution 2D prediction image saved to {save_path}') 444 | plt.close() 445 | 446 | # Plot each z-slice separately 447 | vmin = min(r.min(), p.min()) 448 | vmax = max(r.max(), p.max()) 449 | nrows, ncols = _calc_plot_dim(p.shape[-1], f=0.5) 450 | fig = plt.figure(figsize=(4*ncols, 8.5*nrows)) 451 | fig_grid = fig.add_gridspec(nrows, ncols, wspace=0.05, hspace=0.15, 452 | left=0.03, right=0.98, bottom=0.02, top=0.98) 453 | for iz in range(p.shape[-1]): 454 | ix = iz % ncols 455 | iy = iz // ncols 456 | ax1, ax2 = fig_grid[iy, ix].subgridspec(2, 1, hspace=0.03).subplots() 457 | ax1.imshow(p[:,:,iz].T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 458 | ax2.imshow(r[:,:,iz].T, origin='lower', vmin=vmin, vmax=vmax, extent=extent) 459 | ax1.axis('off') 460 | ax2.axis('off') 461 | ax1.set_title(f'z = {z_start + (iz + 0.5) * z_res:.2f}Å', fontsize=fontsize) 462 | if ix == 0: 463 | ax1.text(-0.1, 0.5, 'Prediction', ha='center', va='center', 464 | transform=ax1.transAxes, rotation='vertical', fontsize=fontsize) 465 | ax2.text(-0.1, 0.5, 'Reference', ha='center', va='center', 466 | transform=ax2.transAxes, 
rotation='vertical', fontsize=fontsize) 467 | 468 | plt.savefig(save_path:=os.path.join(outdir, f'{ind}_pred_dist.png')) 469 | if verbose > 0: print(f'Position distribution prediction image saved to {save_path}') 470 | plt.close() 471 | 472 | ind += 1 473 | --------------------------------------------------------------------------------
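As a quick sanity check of the augmentation pipeline in `src/preprocessing.py`, the sketch below re-implements the in-place `add_noise` and `add_norm` logic (default options only, simplified; not the repository code itself) and applies it to a random batch in the `(batch_size, x_dim, y_dim, z_dim)` layout the module assumes:

```python
import numpy as np

def add_noise(Xs, c=0.1):
    # Uniform noise with amplitude c * (max - min) per sample, in place
    for X in Xs:
        R = np.random.rand(*X.shape) - 0.5
        for j in range(X.shape[0]):
            X[j] += R[j] * c * (X[j].max() - X[j].min())

def add_norm(Xs):
    # Zero-mean, unit-std normalization per sample and z-slice, in place
    for X in Xs:
        for j in range(X.shape[0]):
            for i in range(X.shape[-1]):
                X[j, ..., i] = (X[j, ..., i] - X[j, ..., i].mean()) / X[j, ..., i].std()

rng = np.random.default_rng(0)
X = rng.random((2, 16, 16, 4))  # fake batch of AFM image stacks
add_noise([X], c=0.1)
add_norm([X])
print(X[0, ..., 0].mean(), X[0, ..., 0].std())  # ~0 and ~1 after normalization
```

Because both functions mutate their inputs, normalization should be applied last, after all noise and geometric augmentations.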