├── .gitignore ├── LICENSE.txt ├── README.md ├── data └── prelabelled_anns.json ├── main.py ├── requirements.txt └── src ├── inpainting ├── __init__.py ├── config.py ├── inpaint.py ├── networks.py ├── run_inpaint.py └── util.py ├── postprocessing ├── figures.ipynb └── mergegifs.py ├── preprocessing ├── __init__.py ├── annotation_gui.py ├── import_data.py └── style.qss ├── slicegan ├── __init__.py ├── animations.py ├── model.py ├── networks.py ├── preprocessing.py ├── run_slicegan.py └── util.py └── style.qss /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | venv 3 | .env 4 | runs 5 | data/*runs 6 | data/micrographs* 7 | data/*temp* 8 | data/anns.json 9 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 tldr group, Imperial College London 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Microlib 2 | 3 | A repo for generating the dataset associated with microlib.io. 
4 | 
5 | Website: https://microlib.io/
6 | 
7 | Paper: https://www.nature.com/articles/s41597-022-01744-1
8 | 
9 | ## Folder structure
10 | 
11 | ```
12 | microlib
13 |  ┣ src
14 |  ┃ ┣ preprocessing
15 |  ┃ ┃ ┣ __init__.py
16 |  ┃ ┃ ┣ annotation_gui.py
17 |  ┃ ┃ ┣ import_data.py
18 |  ┃ ┃ ┣ style.qss
19 |  ┃ ┣ inpainting
20 |  ┃ ┃ ┣ __init__.py
21 |  ┃ ┃ ┣ config.py
22 |  ┃ ┃ ┣ inpaint.py
23 |  ┃ ┃ ┣ networks.py
24 |  ┃ ┃ ┣ run_inpaint.py
25 |  ┃ ┃ ┣ util.py
26 |  ┃ ┣ slicegan
27 |  ┃ ┃ ┣ __init__.py
28 |  ┃ ┃ ┣ animations.py
29 |  ┃ ┃ ┣ model.py
30 |  ┃ ┃ ┣ networks.py
31 |  ┃ ┃ ┣ preprocessing.py
32 |  ┃ ┃ ┣ run_slicegan.py
33 |  ┃ ┃ ┣ util.py
34 |  ┃ ┣ postprocessing
35 |  ┃ ┃ ┣ figures.ipynb
36 |  ┃ ┃ ┣ mergegifs.py
37 |  ┣ data
38 |  ┃ ┣ prelabelled_anns.json
39 |  ┃ ┣ micrographs_raw*
40 |  ┃ ┣ micrographs_final*
41 |  ┃ ┣ inpaint_runs*
42 |  ┃ ┣ slicegan_runs*
43 |  ┃ ┣ anns.json*
44 |  ┣ .gitignore
45 |  ┣ LICENSE.txt
46 |  ┣ main.py
47 |  ┣ README.md
48 |  ┗ requirements.txt
49 | 
50 | *folders and files generated during the processing steps
51 | ```
52 | 
53 | ## Repo setup
54 | 
55 | Prerequisites:
56 | 
57 | - conda
58 | - python3
59 | 
60 | Create a new conda environment, activate it, install PyTorch, then install the remaining requirements with pip.
61 | 
62 | _Note: the cudatoolkit version and PyTorch install depend on your system, see [PyTorch install](https://pytorch.org/get-started/locally/) for more info._
63 | 
64 | ```
65 | conda create --name microlib
66 | conda activate microlib
67 | conda install pytorch torchvision -c pytorch
68 | pip install -r requirements.txt
69 | ```
70 | 
71 | ## Dataset generation
72 | 
73 | You are now ready to run the repo. We will download images, annotate them, perform inpainting, run slicegan and finally generate some animations.
74 | 
75 | First, to download the images, run in import mode. This makes a series of requests to DoITPoMS. If you get certificate errors, go to src/preprocessing/import_data.py and add verify=False to line 18 *at your own risk* - a sketch of the modified call is shown after the command below.
76 | 
77 | ```
78 | python main.py import
79 | ```
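80 | 
81 | If you do add verify=False, the change is a one-line edit to the download request. A minimal, hypothetical sketch of what the modified call might look like (import_data.py is not reproduced here, so the URL and variable names below are placeholders, and requests will warn about the unverified connection):
82 | 
83 | ```
84 | import requests
85 | 
86 | # verify=False disables TLS certificate checks -- only use this on a trusted network
87 | url = "https://www.doitpoms.ac.uk/"  # placeholder; import_data.py builds the real micrograph URL
88 | response = requests.get(url, verify=False)
89 | with open("data/micrographs_raw/000001.png", "wb") as f:
90 |     f.write(response.content)
91 | ```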
92 | 
93 | Next, annotate the images by running in preprocess mode. You can skip this step and use our annotations by renaming data/prelabelled_anns.json to data/anns.json. If you quit the GUI and rerun, you will automatically continue from where you left off - to restart, just delete anns.json.
94 | 
95 | ```
96 | python main.py preprocess
97 | ```
98 | 
99 | The following are the controls at the different stages of the annotation GUI. At any time, press C to restart the current microstructure, or W to remove the current microstructure if it doesn't fit the exclusion criteria. The stage you are on is shown at the top of the GUI. These are the stages:
100 | 
101 | 1. Scale bar col: click on the scale bar then use the A and S keys to adjust thresholds, or press enter to skip.
102 | 2. Scale bar box: click on the top left then bottom right corners of the region containing the scale bar, or press enter to skip. You should not skip this if you have selected a scale bar col.
103 | 3. Crop region: click on the top left then bottom right corner to define the region you want to keep. Click a third time to reset. Press enter to skip.
104 | 4. Phases: click on the different phases to segment. Use A and S to adjust the threshold. Press enter to skip and select grayscale.
105 | 5. Voxel size: click on the left of the scale bar, then the right, then enter the scale bar size in microns.
106 | 
107 | Now run in inpaint mode. This creates a folder called data/micrographs_final with all the inpainted images ready for slicegan, as well as any images that didn't need inpainting.
108 | 
109 | ```
110 | python main.py inpaint
111 | ```
112 | 
113 | Run in slicegan mode to train 3D generators. This creates the data/slicegan_runs folder and a subfolder for each run that will contain the generator and discriminator, params, and the animations and 3D volumes generated in the next step.
114 | 
115 | ```
116 | python main.py slicegan
117 | ```
118 | 
119 | Finally, run in animate mode to generate a 3D volume from each trained generator and animate it slice by slice and rotating.
120 | 
121 | ```
122 | python main.py animate
123 | ```
124 | 
125 | ## Using pretrained generators
126 | 
127 | This repo is centred around how users can follow the steps we took to generate the full 3D dataset in microlib. If instead you are interested in using the pretrained generators to make more microstructures of different sizes or shapes, you should instead use the SliceGAN repo.
128 | 
129 | To do so, first clone SliceGAN from here: https://github.com/stke9/SliceGAN. Create the following additional folder within the repo, where microxxx.Gen and microxxx.params can be downloaded from microlib.io by clicking on a microstructure of interest:
130 | 
131 | ```
132 | SliceGAN
133 |  ┣ TrainedGenerators
134 |  ┃ ┣ microxxx
135 |  ┃ ┃ ┣ microxxx.Gen
136 |  ┃ ┃ ┣ microxxx.params
137 | ```
138 | 
-------------------------------------------------------------------------------- /data/prelabelled_anns.json: -------------------------------------------------------------------------------- 1 | {"0": {"data_path": "data/micrographs_png/000001.png", "barbox": [[95, 399], [386, 483]], "barcol": [255, 150], "phases": [135, 208], "vox_size": 0.14492753623188406, "data_type": "twophase"}, "1": {"data_path": "data/micrographs_png/000002.png", "barbox": [[124, 397], [374, 482]], "barcol": [255, 10], "phases": [122, 199], "vox_size": 0.12658227848101267, "data_type": "twophase"}, "2": {"data_path": "data/micrographs_png/000006.png", "barbox": [[102, 393], [392, 481]], "barcol": [255, 230], "phases": [98, 178], "vox_size": 0.7220216606498195, "data_type": "twophase"}, "3": {"data_path": "data/micrographs_png/000007.png", "barbox": [[102, 399], [391, 481]], "barcol": [255, 290], "phases": [130, 150], "vox_size": 0.1444043321299639, "data_type": "twophase"}, "5": {"data_path": "data/micrographs_png/000010.png", "barbox": [[102, 395], [394, 499]], "barcol": [255, 270], "phases": [104, 225], "vox_size": 1.444043321299639, "data_type": "twophase"}, "6": {"data_path": "data/micrographs_png/000011.png", "barbox": [[101, 399], [393, 482]], "barcol": [255, 190], "phases": [134, 226], "vox_size": 0.2898550724637681, "data_type": "twophase"}, "7": {"data_path": "data/micrographs_png/000016.png", "barbox": [[102, 405], [392, 481]], "barcol": [255, 220], "phases": [], "vox_size": 0.2888086642599278, "data_type": "grayscale"}, "8": {"data_path": "data/micrographs_png/000017.png", "barbox": [[104, 402], [391, 483]], "barcol": [255, 310], "phases": [108, 154], "vox_size": 1.444043321299639, "data_type": "twophase"}, "9": {"data_path": "data/micrographs_png/000021.png", "barbox": [[102, 397], [393, 481]], "barcol": [255, 159], "phases": [96, 222], "vox_size": 1.4492753623188406, "data_type": "twophase"}, "10": {"data_path": "data/micrographs_png/000031.png", "barbox": [[102, 398], [390, 482]], "barcol": [255, 159], "phases": [154, 157], "vox_size": 1.4492753623188406, "data_type": "twophase"}, 
"11": {"data_path": "data/micrographs_png/000039.png", "barbox": [[105, 394], [393, 482]], "barcol": [255, 99], "phases": [181, 212], "vox_size": 1.4336917562724014, "data_type": "twophase"}, "12": {"data_path": "data/micrographs_png/000043.png", "barbox": [[100, 400], [394, 481]], "barcol": [255, 240], "phases": [140, 159], "vox_size": 1.444043321299639, "data_type": "twophase"}, "13": {"data_path": "data/micrographs_png/000047.png", "barbox": [[100, 395], [390, 481]], "barcol": [255, 109], "phases": [41, 255], "vox_size": 1.444043321299639, "data_type": "twophase"}, "14": {"data_path": "data/micrographs_png/000048.png", "barbox": [[100, 396], [390, 482]], "barcol": [255, 69], "phases": [156, 222], "vox_size": 0.2888086642599278, "data_type": "twophase"}, "15": {"data_path": "data/micrographs_png/000051.png", "barbox": [[102, 403], [390, 481]], "barcol": [255, 200], "phases": [128, 157], "vox_size": 1.4492753623188406, "data_type": "twophase"}, "16": {"data_path": "data/micrographs_png/000053.png", "barbox": [[113, 404], [379, 482]], "barcol": [255, 109], "phases": [14, 141], "vox_size": 0.05976095617529881, "data_type": "twophase"}, "17": {"data_path": "data/micrographs_png/000054.png", "barbox": [[106, 401], [392, 481]], "barcol": [255, 60], "phases": [140, 246], "vox_size": 1.4492753623188406, "data_type": "twophase"}, "18": {"data_path": "data/micrographs_png/000060.png", "barbox": [[102, 400], [393, 481]], "barcol": [255, 129], "phases": [164, 193], "vox_size": 1.444043321299639, "data_type": "twophase"}, "20": {"data_path": "data/micrographs_png/000066.png", "barbox": [[101, 403], [390, 480]], "barcol": [255, 129], "phases": [184, 255], "vox_size": 1.4492753623188406, "data_type": "twophase"}, "21": {"data_path": "data/micrographs_png/000068.png", "barbox": [[102, 397], [394, 481]], "barcol": [255, 30], "phases": [62, 236], "vox_size": 0.7246376811594203, "data_type": "twophase"}, "22": {"data_path": "data/micrographs_png/000070.png", "barbox": [[104, 397], [390, 481]], "barcol": [255, 170], "phases": [147, 198], "vox_size": 0.7246376811594203, "data_type": "twophase"}, "23": {"data_path": "data/micrographs_png/000072.png", "barbox": [[101, 399], [392, 484]], "barcol": [255, 139], "phases": [175, 178], "vox_size": 1.444043321299639, "data_type": "twophase"}, "24": {"data_path": "data/micrographs_png/000084.png", "barbox": [[102, 394], [394, 483]], "barcol": [255, 119], "phases": [137, 221], "vox_size": 2.888086642599278, "data_type": "twophase"}, "25": {"data_path": "data/micrographs_png/000085.png", "barbox": [[100, 394], [394, 481]], "barcol": [255, 160], "phases": [155, 181], "vox_size": 0.7220216606498195, "data_type": "twophase"}, "26": {"data_path": "data/micrographs_png/000160.png", "barbox": [[447, 366], [577, 405]], "barcol": [255, 330], "phases": [93, 93], "vox_size": 0.819672131147541, "data_type": "twophase"}, "27": {"data_path": "data/micrographs_png/000161.png", "barbox": [[444, 376], [583, 418]], "barcol": [255, 330], "phases": [80, 84], "vox_size": 0.8333333333333334, "data_type": "twophase"}, "29": {"data_path": "data/micrographs_png/000177.png", "barbox": [[408, 426], [617, 480]], "barcol": [255, 320], "phases": [77, 120], "vox_size": 1.0416666666666667, "data_type": "twophase"}, "30": {"data_path": "data/micrographs_png/000186.png", "barbox": [[413, 455], [578, 518]], "barcol": [255, 360], "phases": [72, 145], "vox_size": 0.17006802721088435, "data_type": "twophase"}, "31": {"data_path": "data/micrographs_png/000188.png", "barbox": [[432, 474], [624, 525]], 
"barcol": [255, 230], "phases": [119, 123], "vox_size": 1.1363636363636365, "data_type": "twophase"}, "32": {"data_path": "data/micrographs_png/000189.png", "barbox": [[163, 327], [473, 409]], "barcol": [255, 139], "phases": [], "vox_size": 0.1, "data_type": "grayscale"}, "33": {"data_path": "data/micrographs_png/000205.png", "barbox": [[236, 442], [546, 521]], "barcol": [255, 90], "phases": [169, 194], "vox_size": 1.3840830449826989, "data_type": "twophase"}, "34": {"data_path": "data/micrographs_png/000209.png", "barbox": [[247, 435], [550, 517]], "barcol": [255, 360], "phases": [61, 131], "vox_size": 1.3793103448275863, "data_type": "twophase"}, "35": {"data_path": "data/micrographs_png/000210.png", "barbox": [[239, 437], [547, 525]], "barcol": [255, 410], "phases": [8, 133], "vox_size": 0.6872852233676976, "data_type": "twophase"}, "36": {"data_path": "data/micrographs_png/000211.png", "barbox": [[243, 435], [546, 518]], "barcol": [255, 410], "phases": [26, 155], "vox_size": 1.3793103448275863, "data_type": "twophase"}, "37": {"data_path": "data/micrographs_png/000213.png", "barbox": [[247, 443], [547, 521]], "barcol": [255, 70], "phases": [35, 258], "vox_size": 2.7491408934707904, "data_type": "twophase"}, "38": {"data_path": "data/micrographs_png/000217.png", "barbox": [[239, 439], [548, 521]], "barcol": [255, 160], "phases": [], "vox_size": 1.36986301369863, "data_type": "grayscale"}, "40": {"data_path": "data/micrographs_png/000227.png", "barbox": [[245, 434], [550, 514]], "barcol": [255, 50], "phases": [147, 214], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "41": {"data_path": "data/micrographs_png/000228.png", "barbox": [[246, 439], [549, 518]], "barcol": [255, 230], "phases": [105, 224], "vox_size": 0.3448275862068966, "data_type": "twophase"}, "43": {"data_path": "data/micrographs_png/000233.png", "barbox": [[244, 449], [547, 525]], "barcol": [255, 99], "phases": [188, 257], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "44": {"data_path": "data/micrographs_png/000235.png", "barbox": [[245, 446], [546, 520]], "barcol": [255, 60], "phases": [151, 237], "vox_size": 1.3793103448275863, "data_type": "twophase"}, "45": {"data_path": "data/micrographs_png/000237.png", "barbox": [[244, 435], [547, 512]], "barcol": [255, 70], "phases": [114, 243], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "46": {"data_path": "data/micrographs_png/000238.png", "barbox": [[246, 446], [547, 522]], "barcol": [255, 40], "phases": [148, 246], "vox_size": 0.3424657534246575, "data_type": "twophase"}, "49": {"data_path": "data/micrographs_png/000276.png", "barbox": [[240, 447], [546, 526]], "barcol": [255, 80], "phases": [98, 233], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "50": {"data_path": "data/micrographs_png/000277.png", "barbox": [[244, 443], [546, 519]], "barcol": [255, 80], "phases": [113, 220], "vox_size": 0.684931506849315, "data_type": "twophase"}, "52": {"data_path": "data/micrographs_png/000286.png", "barbox": [[243, 441], [552, 512]], "barcol": [255, 30], "phases": [26, 248], "vox_size": 2.711864406779661, "data_type": "twophase"}, "53": {"data_path": "data/micrographs_png/000319.png", "barbox": [[246, 453], [549, 528]], "barcol": [255, 80], "phases": [], "vox_size": 1.3888888888888888, "data_type": "grayscale"}, "55": {"data_path": "data/micrographs_png/000329.png", "barbox": [[247, 438], [549, 524]], "barcol": [255, 40], "phases": [140, 230], "vox_size": 0.3436426116838488, "data_type": "twophase"}, "60": {"data_path": 
"data/micrographs_png/000340.png", "barbox": [[245, 440], [556, 513]], "barcol": [255, 40], "phases": [137, 235], "vox_size": 1.342281879194631, "data_type": "twophase"}, "63": {"data_path": "data/micrographs_png/000360.png", "barbox": [[242, 441], [548, 523]], "barcol": [255, 40], "phases": [197, 253], "vox_size": 1.3651877133105803, "data_type": "twophase"}, "65": {"data_path": "data/micrographs_png/000365.png", "barbox": [[251, 443], [546, 522]], "barcol": [255, 90], "phases": [177, 223], "vox_size": 0.684931506849315, "data_type": "twophase"}, "66": {"data_path": "data/micrographs_png/000368.png", "barbox": [[249, 453], [546, 526]], "barcol": [255, 80], "phases": [172, 215], "vox_size": 1.36986301369863, "data_type": "twophase"}, "67": {"data_path": "data/micrographs_png/000370.png", "barbox": [[246, 448], [546, 523]], "barcol": [255, 70], "phases": [190, 241], "vox_size": 1.3793103448275863, "data_type": "twophase"}, "68": {"data_path": "data/micrographs_png/000372.png", "barbox": [[246, 447], [546, 521]], "barcol": [255, 80], "phases": [219, 232], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "69": {"data_path": "data/micrographs_png/000376.png", "barbox": [[242, 440], [548, 517]], "barcol": [255, 19], "phases": [199, 235], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "70": {"data_path": "data/micrographs_png/000378.png", "barbox": [[248, 442], [546, 518]], "barcol": [255, 119], "phases": [73, 192], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "71": {"data_path": "data/micrographs_png/000381.png", "barbox": [[245, 440], [548, 524]], "barcol": [255, 30], "phases": [46, 229], "vox_size": 1.3651877133105803, "data_type": "twophase"}, "72": {"data_path": "data/micrographs_png/000387.png", "barbox": [[249, 453], [546, 528]], "barcol": [255, 50], "phases": [113, 223], "vox_size": 1.3888888888888888, "data_type": "twophase"}, "73": {"data_path": "data/micrographs_png/000393.png", "barbox": [[240, 433], [544, 513]], "barcol": [255, 70], "phases": [110, 254], "vox_size": 1.3651877133105803, "data_type": "twophase"}, "74": {"data_path": "data/micrographs_png/000396.png", "barbox": [[242, 431], [548, 515]], "barcol": [255, 50], "phases": [48, 202], "vox_size": 1.36986301369863, "data_type": "twophase"}, "77": {"data_path": "data/micrographs_png/000406.png", "barbox": [[248, 425], [547, 508]], "barcol": [255, 50], "phases": [59, 234], "vox_size": 2.7586206896551726, "data_type": "twophase"}, "80": {"data_path": "data/micrographs_png/000420.png", "barbox": [[246, 437], [548, 520]], "barcol": [255, 50], "phases": [74, 253], "vox_size": 2.7303754266211606, "data_type": "twophase"}, "81": {"data_path": "data/micrographs_png/000429.png", "barbox": [[247, 451], [547, 528]], "barcol": [255, 70], "phases": [], "vox_size": 1.3745704467353952, "data_type": "grayscale"}, "83": {"data_path": "data/micrographs_png/000441.png", "barbox": [[245, 429], [547, 512]], "barcol": [255, 150], "phases": [122, 185], "vox_size": 1.3793103448275863, "data_type": "twophase"}, "84": {"data_path": "data/micrographs_png/000442.png", "barbox": [[245, 439], [547, 526]], "barcol": [255, 150], "phases": [82, 193], "vox_size": 0.3448275862068966, "data_type": "twophase"}, "86": {"data_path": "data/micrographs_png/000447.png", "barbox": [[247, 443], [548, 514]], "barcol": [255, 50], "phases": [152, 210], "vox_size": 1.3745704467353952, "data_type": "twophase"}, "87": {"data_path": "data/micrographs_png/000477.png", "barbox": [[247, 443], [547, 522]], "barcol": [255, 50], "phases": [117, 220], 
"vox_size": 2.7491408934707904, "data_type": "twophase"}, "89": {"data_path": "data/micrographs_png/000484.png", "barbox": [[249, 421], [547, 500]], "barcol": [255, 50], "phases": [169, 238], "vox_size": 2.73972602739726, "data_type": "twophase"}, "92": {"data_path": "data/micrographs_png/000516.png", "barbox": [[245, 414], [547, 492]], "barcol": [255, 50], "phases": [], "vox_size": 2.7586206896551726, "data_type": "grayscale"}, "104": {"data_path": "data/micrographs_png/000711.png", "barbox": [[615, 530], [798, 591]], "barcol": null, "phases": [], "vox_size": 6.289308176100629, "data_type": "grayscale"}, "105": {"data_path": "data/micrographs_png/000712.png", "barbox": [[522, 531], [798, 592]], "barcol": null, "phases": [], "vox_size": 1.6129032258064515, "data_type": "grayscale"}, "106": {"data_path": "data/micrographs_png/000716.png", "barbox": [[660, 544], [792, 584]], "barcol": null, "phases": [124, 164], "vox_size": 0.22123893805309736, "data_type": "twophase"}, "114": {"data_path": "data/micrographs_png/000736.png", "barbox": [[640, 39], [738, 90]], "barcol": null, "phases": [124, 246], "vox_size": 0.08108108108108109, "data_type": "twophase"}, "115": {"data_path": "data/micrographs_png/000737.png", "barbox": [[640, 35], [735, 90]], "barcol": null, "phases": [106, 255], "vox_size": 0.07692307692307693, "data_type": "twophase"}, "117": {"data_path": "data/micrographs_png/000740.png", "barbox": [[592, 38], [739, 94]], "barcol": null, "phases": [197, 237], "vox_size": 0.1652892561983471, "data_type": "twophase"}, "123": {"data_path": "data/micrographs_png/000760.png", "barbox": [[254, 385], [443, 431]], "barcol": null, "phases": [162, 240], "vox_size": 0.5747126436781609, "data_type": "twophase"}, "126": {"data_path": "data/micrographs_png/000782.png", "barbox": [[661, 491], [835, 541]], "barcol": null, "phases": [83, 226], "vox_size": 0.6756756756756757, "data_type": "twophase"}, "127": {"data_path": "data/micrographs_png/000784.png", "barbox": [[661, 495], [831, 543]], "barcol": null, "phases": [97, 154], "vox_size": 0.6711409395973155, "data_type": "twophase"}, "131": {"data_path": "data/micrographs_png/000860.png", "barbox": [[1595, 1376], [2045, 1495]], "barcol": null, "phases": [45, 186], "vox_size": 0.2695417789757412, "data_type": "twophase"}, "132": {"data_path": "data/micrographs_png/000971.png", "barbox": [[150, 652], [1113, 786]], "barcol": null, "phases": [33, 203], "vox_size": 1.0548523206751055, "data_type": "twophase"}} -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from src.preprocessing import import_data, annotation_gui 4 | from src.inpainting import run_inpaint 5 | from src.slicegan import run_slicegan, animations 6 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 7 | 8 | def main(mode): 9 | """[summary] 10 | 11 | :param mode: [mode to run in] 12 | :type mode: [str] 13 | 14 | """ 15 | 16 | # Initialise Config object 17 | 18 | 19 | if mode == 'import': 20 | import_data() 21 | elif mode == 'preprocess': 22 | annotation_gui.preprocess_gui() 23 | elif mode == 'inpaint': 24 | run_inpaint.inpaint_dataset('train') 25 | elif mode =='slicegan': 26 | run_slicegan.slicegan_dataset() 27 | elif mode == 'animate': 28 | animations.animate_dataset() 29 | 30 | else: 31 | raise ValueError("Mode not recognised") 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("mode") 37 | 
args = parser.parse_args() 38 | main(args.mode) 39 | # main('train', False, 'test') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.14.1 2 | matplotlib==3.5.1 3 | moviepy==1.0.3 4 | numpy==1.21.4 5 | Pillow==9.1.1 6 | plotoptix==0.14.2 7 | PyQt5==5.15.6 8 | python-dotenv==0.20.0 9 | requests==2.26.0 10 | scipy==1.7.3 11 | sympy==1.9 12 | tifffile==2021.11.2 13 | tqdm==4.63.0 14 | wandb==0.12.11 15 | -------------------------------------------------------------------------------- /src/inpainting/__init__.py: -------------------------------------------------------------------------------- 1 | from .networks import * 2 | from .config import * 3 | from .util import * 4 | -------------------------------------------------------------------------------- /src/inpainting/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class Config(): 5 | """Config class 6 | """ 7 | def __init__(self, tag): 8 | self.tag = tag 9 | self.cli = False 10 | self.path = f'data/inpaint_runs/{self.tag}' 11 | self.data_path = '' 12 | self.mask_coords = [] 13 | self.net_type = 'conv-resize' 14 | self.image_type = 'n-phase' 15 | self.l = 128 16 | self.n_phases = 2 17 | # Training hyperparams 18 | self.batch_size = 8 19 | self.beta1 = 0.9 20 | self.beta2 = 0.999 21 | self.num_epochs = 250 22 | self.iters = 100 23 | self.lrg = 0.0005 24 | self.lr = 0.0005 25 | self.Lambda = 10 26 | self.critic_iters = 5 27 | self.pw_coeff = 1e3 28 | self.lz = 7 29 | self.lf = 7 30 | self.dl = 32 31 | self.ngpu = 1 32 | if self.ngpu > 0: 33 | self.device_name = "cuda:0" 34 | else: 35 | self.device_name = 'cpu' 36 | self.conv_resize = True 37 | self.nz = 100 38 | # Architecture 39 | self.lays = 4 40 | self.laysd = 5 41 | # kernel sizes 42 | self.dk, self.gk = [4]*self.laysd, [4]*self.lays 43 | self.ds, self.gs = [2]*self.laysd, [2]*self.lays 44 | self.df, self.gf = [self.n_phases, 64, 128, 256, 512, 1], [ 45 | self.nz, 512, 256, 128, self.n_phases] 46 | self.dp = [1, 1, 1, 1, 1] 47 | self.gp = [1, 1, 1, 1] 48 | 49 | # self.gs[0] = 1 50 | 51 | def update_params(self): 52 | self.df[0] = self.n_phases 53 | self.gf[-1] = self.n_phases 54 | if self.net_type=='conv-resize': 55 | self.lays = 5 56 | self.gk = [3]*self.lays 57 | self.gs = [1]*self.lays 58 | self.gp = [1]*self.lays 59 | self.gf = [self.nz, 512, 256, 128, 64, self.n_phases] 60 | 61 | 62 | def save(self): 63 | j = {} 64 | for k, v in self.__dict__.items(): 65 | j[k] = v 66 | with open(f'{self.path}/config.json', 'w') as f: 67 | json.dump(j, f) 68 | 69 | def load(self): 70 | with open(f'{self.path}/config.json', 'r') as f: 71 | j = json.load(f) 72 | for k, v in j.items(): 73 | setattr(self, k, v) 74 | 75 | def get_net_params(self): 76 | return self.dk, self.ds, self.df, self.dp, self.gk, self.gs, self.gf, self.gp 77 | 78 | def get_train_params(self): 79 | return self.l, self.dl, self.batch_size, self.beta1, self.beta2, self.num_epochs, self.iters, self.lrg, self.lr, self.Lambda, self.critic_iters, self.lz, self.nz 80 | 81 | 82 | class ConfigPoly(Config): 83 | def __init__(self, tag): 84 | super(ConfigPoly, self).__init__(tag) 85 | self.l = 64 86 | self.lz = 4 87 | self.ngpu=1 88 | self.lays = 5 89 | self.laysd = 5 90 | # kernel sizes 91 | self.dk, self.gk = [4]*self.laysd, [4]*self.lays 92 | self.ds, self.gs = [2]*self.laysd, [2]*self.lays 93 | self.df, self.gf = [self.n_phases, 
128, 256, 512, 1024, 1], [ 94 | self.nz, 1024, 512, 256, 128, self.n_phases] 95 | self.df, self.gf = [self.n_phases, 64, 128, 256, 512, 1], [ 96 | self.nz, 512, 256, 128, 64, self.n_phases] 97 | self.dp = [1, 1, 1, 1, 0] 98 | self.gp = [2, 2, 2, 2, 3] 99 | def get_train_params(self): 100 | return self.l, self.batch_size, self.beta1, self.beta2, self.num_epochs, self.iters, self.lrg, self.lr, self.Lambda, self.critic_iters, self.lz, self.nz -------------------------------------------------------------------------------- /src/inpainting/inpaint.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from tqdm import tqdm 3 | import matplotlib.pyplot as plt 4 | # from inpainting import * 5 | import numpy as np 6 | import os 7 | import torch.optim as optim 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import time 12 | from copy import deepcopy 13 | import imageio 14 | from src.inpainting.networks import * 15 | from src.inpainting.config import * 16 | from src.inpainting.util import * 17 | 18 | 19 | def train(img, imtype, mask, mask_ip, rect, pth, tag): 20 | h, w = img.shape 21 | c = ConfigPoly(tag) 22 | c.data_path = pth 23 | 24 | # rh, rw = r.height() * sf, r.width() * sf 25 | x, y = np.meshgrid(np.arange(w), np.arange(h)) # make a canvas with coordinates 26 | x, y = x.flatten(), y.flatten() 27 | seeds_mask = np.zeros((h,w)) 28 | for x in range(c.l): 29 | for y in range(c.l): 30 | seeds_mask += np.roll(np.roll(mask, -x, 0), -y, 1) 31 | seeds_mask[seeds_mask>1]=1 32 | real_seeds = np.where(seeds_mask[:-c.l, :-c.l]==0) 33 | overwrite = True 34 | initialise_folders(tag, overwrite) 35 | 36 | if imtype == 'n-phase': 37 | c.n_phases = len(np.unique(plt.imread(c.data_path)[...,0])) 38 | c.conv_resize=True 39 | else: 40 | c.n_phases = 1 41 | c.image_type = imtype 42 | netD, netG = make_nets_poly(c, overwrite) 43 | real_seeds = real_seeds 44 | mask = mask 45 | c.save() 46 | l, batch_size, beta1, beta2, num_epochs, iters, lrg, lr, Lambda, critic_iters, lz, nz, = c.get_train_params() 47 | # Read in data 48 | ngpu = 1 49 | training_imgs, nc = preprocess(c.data_path, c.image_type) 50 | device = torch.device('cuda:0') 51 | # Define Generator network 52 | netG = netG().to(device) 53 | netD = netD().to(device) 54 | pth = f'data/inpaint_runs/{tag}/' 55 | 56 | try: 57 | os.mkdir(f'{pth}') 58 | os.mkdir(f'{pth}imgs') 59 | except: 60 | pass 61 | c.frames = 100 62 | c.iters_per_epoch = 1000 63 | c.opt_iters = 1000 64 | c.epochs = 50 if imtype == 'n-phase' else 100 65 | c.mask = mask 66 | c.mask_ip = mask_ip 67 | c.poly_rects = rect 68 | c.save_inpaint = int(c.opt_iters/c.frames) 69 | c.pth = pth 70 | if ('cuda' in str(device)) and (ngpu > 1): 71 | netD = (nn.DataParallel(netD, list(range(ngpu)))).to(device) 72 | netG = nn.DataParallel(netG, list(range(ngpu))).to(device) 73 | optD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2)) 74 | optG = optim.Adam(netG.parameters(), lr=lrg, betas=(beta1, beta2)) 75 | 76 | max_iters = c.epochs*c.iters_per_epoch 77 | for ep in tqdm(range(c.epochs)): 78 | torch.cuda.synchronize() 79 | for i in range(c.iters_per_epoch): 80 | times = [] 81 | # Discriminator Training 82 | if ('cuda' in str(device)) and (ngpu > 1): 83 | start_overall = torch.cuda.Event(enable_timing=True) 84 | end_overall = torch.cuda.Event(enable_timing=True) 85 | start_overall.record() 86 | else: 87 | start_overall = time.time() 88 | 89 | netD.zero_grad() 90 | 91 | noise = torch.randn(batch_size, nz, lz, lz, 
device=device) 92 | fake_data = netG(noise).detach() 93 | real_data = batch_real_poly(training_imgs, l, batch_size, real_seeds).to(device) 94 | 95 | # Train on real 96 | out_real = netD(real_data).mean() 97 | # train on fake images 98 | out_fake = netD(fake_data).mean() 99 | gradient_penalty = calc_gradient_penalty(netD, real_data, fake_data, batch_size, l, device, Lambda, nc) 100 | 101 | # Compute the discriminator loss and backprop 102 | disc_cost = out_fake - out_real + gradient_penalty 103 | disc_cost.backward() 104 | 105 | optD.step() 106 | 107 | # Generator training 108 | if i % int(critic_iters) == 0: 109 | netG.zero_grad() 110 | noise = torch.randn(batch_size, nz, lz, lz, device=device) 111 | # Forward pass through G with noise vector 112 | fake_data = netG(noise) 113 | output = -netD(fake_data).mean() 114 | 115 | # Calculate loss for G and backprop 116 | output.backward() 117 | optG.step() 118 | 119 | if ('cuda' in str(device)) and (ngpu > 1): 120 | end_overall.record() 121 | times.append(start_overall.elapsed_time(end_overall)) 122 | else: 123 | end_overall = time.time() 124 | times.append(end_overall-start_overall) 125 | 126 | 127 | # Every 50 iters log images and useful metrics 128 | if i==0: 129 | torch.save(netG.state_dict(), f'{pth}/Gen.pt') 130 | torch.save(netD.state_dict(), f'{pth}/Disc.pt') 131 | mse, img = inpaint(netG, c) 132 | times = [] 133 | plt.imsave(f'{pth}imgs/img{ep}.png', img) 134 | img_pth = f'data/micrographs_final/{tag}.png' 135 | plt.imsave(img_pth, img) 136 | 137 | def inpaint(netG, c): 138 | img = preprocess(c.data_path,c.image_type)[0] 139 | img = torch.nn.functional.pad(img, (32, 32, 32, 32), value=-1) 140 | mask_ip = torch.nn.functional.pad(torch.tensor(c.mask_ip), (32, 32, 32, 32), value=0) 141 | 142 | for rect in c.poly_rects: 143 | x0, y0, x1, y1 = (int(i)+32 for i in rect) 144 | w, h = x1-x0, y1-y0 145 | x1 += 32 - w%32 146 | y1 += 32 - h%32 147 | 148 | w, h = x1-x0, y1-y0 149 | im_crop = img[:, x0-16:x1+16, y0-16:y1+16] 150 | mask_crop = mask_ip[x0-16:x1+16, y0-16:y1+16] 151 | ch, w, h = im_crop.shape 152 | # print(im_crop.shape, mask_crop.shape, img.shape, mask_ip.shape, x0, y0, x1, y1) 153 | if c.conv_resize: 154 | lx, ly = int(w/16), int(h/16) 155 | else: 156 | lx, ly = int(w/32) + 2, int(h/32) + 2 157 | inpaints, mse = optimise_noise(c, lx, ly, im_crop, mask_crop, netG) 158 | frames = len(inpaints) 159 | 160 | if c.image_type =='n-phase': 161 | final_imgs = [torch.argmax(img, dim=0) for i in range(frames)] 162 | final_img_fresh = torch.argmax(img, dim=0) 163 | else: 164 | final_img_fresh = img.permute(1, 2, 0) 165 | final_imgs = [deepcopy(img.permute(1, 2, 0)) for i in range(frames)] 166 | for fimg, inpaint in enumerate(inpaints): 167 | final_imgs[fimg][x0:x1, y0:y1] = inpaint 168 | for i, final_img in enumerate(final_imgs): 169 | istr = f'00{i}' 170 | if c.image_type=='n-phase': 171 | final_img[mask_ip!=1] = final_img_fresh[mask_ip!=1] 172 | # final_img[mask_ip!=1] = 0.5 173 | 174 | final_img = (final_img.numpy()/final_img.max()) 175 | plt.imsave(f'data/temp/temp{istr[-3:]}.png', np.stack([final_img for i in range(3)], -1)) 176 | else: 177 | for ch in range(c.n_phases): 178 | final_img[:,:,ch][mask_ip==0] = final_img_fresh[:,:,ch][mask_ip==0] 179 | final_img[final_img==-1] = 0.5 180 | final_img-= final_img.min() 181 | final_img/= final_img.max() 182 | if c.image_type=='grayscale': 183 | plt.imsave(f'data/temp/temp{istr[-3:]}.png', np.concatenate([final_img for i in range(3)], -1)) 184 | filenames = sorted(os.listdir('data/temp')) 185 | with 
imageio.get_writer(f'{c.pth}/movie.gif', mode='I', duration=0.2) as writer: 186 | image = imageio.imread(f'data/temp/temp000.png') 187 | for ch in range(2): 188 | image[...,ch][mask_ip==1] = 0 189 | 190 | image[...,2][mask_ip==1] = 255 191 | for i in range(10): 192 | writer.append_data(image[32:-32,32:-32]) 193 | 194 | for filename in filenames: 195 | image = imageio.imread(f'data/temp/{filename}') 196 | writer.append_data(image[32:-32,32:-32]) 197 | 198 | return mse, image 199 | 200 | def optimise_noise(c, lx, ly, img, mask, netG): 201 | 202 | device = torch.device("cuda:0" if( 203 | torch.cuda.is_available() and c.ngpu > 0) else "cpu") 204 | target = img.to(device) 205 | for ch in range(c.n_phases): 206 | target[ch][mask==1] = -1 207 | # plt.imsave('test2.png', torch.cat([target.permute(1,2,0) for i in range(3)], -1).cpu().numpy()) 208 | # plt.imsave('test.png', np.stack([mask for i in range(3)], -1)) 209 | 210 | target = target.unsqueeze(0) 211 | noise = [torch.nn.Parameter(torch.randn(1, c.nz, lx, ly, requires_grad=True, device=device))] 212 | noise_opt = torch.optim.Adam(params=noise, lr=0.02, betas=(0.8, 0.8)) 213 | inpaints = [] 214 | # loss_min = 1000 215 | for i in range(c.opt_iters): 216 | raw = netG(noise[0]) 217 | # print(raw.shape, target.shape) 218 | loss = (raw - target)**4 219 | loss[target==-1] = 0 220 | loss = loss.mean() 221 | loss.backward() 222 | noise_opt.step() 223 | with torch.no_grad(): 224 | noise[0] -= torch.tile(torch.mean(noise[0], dim=[1]), (1, c.nz,1,1)) 225 | noise[0] /= torch.tile(torch.std(noise[0], dim=[1]), (1, c.nz,1,1)) 226 | if c.image_type == 'n-phase': 227 | raw = torch.argmax(raw[0], dim=0)[16:-16, 16:-16].detach().cpu() 228 | else: 229 | raw = raw[0].permute(1,2,0)[16:-16, 16:-16].detach().cpu() 230 | 231 | if (i%c.save_inpaint==0) or (i <20): 232 | inpaints.append(raw) 233 | # if (loss < loss_min): 234 | # best_img = deepcopy(raw) 235 | # loss_min = loss 236 | 237 | 238 | 239 | # inpaints.append(best_img) 240 | return inpaints, loss.item() -------------------------------------------------------------------------------- /src/inpainting/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pickle 5 | 6 | 7 | def make_nets_rect(config, training=True): 8 | """Creates Generator and Discriminator class objects from params either loaded from config object or params file. 9 | 10 | :param config: a Config class object 11 | :type config: Config 12 | :param training: if training is True, params are loaded from Config object. 
If False, params are loaded from file, defaults to True 13 | :type training: bool, optional 14 | :return: Discriminator and Generator class objects 15 | :rtype: Discriminator, Generator 16 | """ 17 | # save/load params 18 | if training: 19 | config.save() 20 | else: 21 | config.load() 22 | 23 | dk, ds, df, dp, gk, gs, gf, gp = config.get_net_params() 24 | 25 | # Make nets 26 | if config.net_type == 'gan': 27 | class Generator(nn.Module): 28 | def __init__(self): 29 | super(Generator, self).__init__() 30 | self.convs = nn.ModuleList() 31 | self.bns = nn.ModuleList() 32 | for lay, (k, s, p) in enumerate(zip(gk, gs, gp)): 33 | self.convs.append(nn.ConvTranspose2d( 34 | gf[lay], gf[lay+1], k, s, p, bias=False)) 35 | self.bns.append(nn.BatchNorm2d(gf[lay+1])) 36 | 37 | def forward(self, x): 38 | # x = torch.cat((x, mask[:,-1].reshape(x.shape)), dim=1) 39 | for conv, bn in zip(self.convs[:-1], self.bns[:-1]): 40 | x = F.relu_(bn(conv(x))) 41 | if config.image_type == 'n-phase': 42 | out = torch.softmax(self.convs[-1](x), dim=1) 43 | else: 44 | out = torch.sigmoid(self.convs[-1](x)) # bs x 1 x 1 45 | # out = torch.where((mask[:,-1]==0).unsqueeze(1).repeat(1,3,1,1,1), mask[:,0:3], out) 46 | return out # bs x n x imsize x imsize x imsize 47 | 48 | class Discriminator(nn.Module): 49 | def __init__(self): 50 | super(Discriminator, self).__init__() 51 | self.convs = nn.ModuleList() 52 | for lay, (k, s, p) in enumerate(zip(dk, ds, dp)): 53 | self.convs.append( 54 | nn.Conv2d(df[lay], df[lay + 1], k, s, p, bias=False)) 55 | 56 | def forward(self, x): 57 | for conv in self.convs[:-1]: 58 | x = F.relu_(conv(x)) 59 | return x 60 | else: 61 | class Generator(nn.Module): 62 | def __init__(self): 63 | super(Generator, self).__init__() 64 | self.convs = nn.ModuleList() 65 | self.bns = nn.ModuleList() 66 | self.up = nn.Upsample(scale_factor=2) 67 | 68 | for lay, (k, s, p) in enumerate(zip(gk, gs, gp)): 69 | self.convs.append(nn.Conv2d( 70 | gf[lay], gf[lay+1], 3, 1, 1, bias=False, padding_mode='reflect')) 71 | self.bns.append(nn.BatchNorm2d(gf[lay+1])) 72 | 73 | def forward(self, x): 74 | # x = torch.cat((x, mask[:,-1].reshape(x.shape)), dim=1) 75 | for conv, bn in zip(self.convs[:-1], self.bns[:-1]): 76 | x = F.relu_(bn(self.up(conv(x)))) 77 | if config.image_type == 'n-phase': 78 | out = torch.softmax(self.convs[-1](x), dim=1) 79 | else: 80 | out = torch.sigmoid(self.convs[-1](x)) 81 | # out = torch.where((mask[:,-1]==0).unsqueeze(1).repeat(1,3,1,1,1), mask[:,0:3], out) 82 | return out # bs x n x imsize x imsize x imsize 83 | 84 | class Discriminator(nn.Module): 85 | def __init__(self): 86 | super(Discriminator, self).__init__() 87 | self.convs = nn.ModuleList() 88 | for lay, (k, s, p) in enumerate(zip(dk, ds, dp)): 89 | self.convs.append( 90 | nn.Conv2d(df[lay], df[lay + 1], k, s, p, bias=False)) 91 | 92 | def forward(self, x): 93 | for conv in self.convs[:-1]: 94 | x = F.relu_(conv(x)) 95 | x = self.convs[-1](x) # bs x 1 x 1 96 | return x 97 | 98 | return Discriminator, Generator 99 | 100 | def make_nets_poly(config, training=True): 101 | """Creates Generator and Discriminator class objects from params either loaded from config object or params file. 102 | 103 | :param config: a Config class object 104 | :type config: Config 105 | :param training: if training is True, params are loaded from Config object. 
If False, params are loaded from file, defaults to True 106 | :type training: bool, optional 107 | :return: Discriminator and Generator class objects 108 | :rtype: Discriminator, Generator 109 | """ 110 | # save/load params 111 | if training: 112 | config.save() 113 | else: 114 | config.load() 115 | 116 | dk, ds, df, dp, gk, gs, gf, gp = config.get_net_params() 117 | df[0] = config.n_phases 118 | gf[-1] = config.n_phases 119 | # Make nets 120 | if config.net_type == 'gan': 121 | class Generator(nn.Module): 122 | def __init__(self): 123 | super(Generator, self).__init__() 124 | self.convs = nn.ModuleList() 125 | self.bns = nn.ModuleList() 126 | for lay, (k, s, p) in enumerate(zip(gk, gs, gp)): 127 | self.convs.append(nn.ConvTranspose2d( 128 | gf[lay], gf[lay+1], k, s, p, bias=False)) 129 | self.bns.append(nn.BatchNorm2d(gf[lay+1])) 130 | 131 | def forward(self, x): 132 | # x = torch.cat((x, mask[:,-1].reshape(x.shape)), dim=1) 133 | for conv, bn in zip(self.convs[:-1], self.bns[:-1]): 134 | x = F.relu_(bn(conv(x))) 135 | if config.image_type == 'n-phase': 136 | out = torch.softmax(self.convs[-1](x), dim=1) 137 | else: 138 | out = torch.sigmoid(self.convs[-1](x)) 139 | # out = torch.where((mask[:,-1]==0).unsqueeze(1).repeat(1,3,1,1,1), mask[:,0:3], out) 140 | return out # bs x n x imsize x imsize x imsize 141 | else: 142 | class Generator(nn.Module): 143 | def __init__(self): 144 | super(Generator, self).__init__() 145 | self.convs = nn.ModuleList() 146 | self.bns = nn.ModuleList() 147 | self.up = nn.Upsample(scale_factor=2) 148 | 149 | for lay, (k, s, p) in enumerate(zip(gk, gs, gp)): 150 | self.convs.append(nn.Conv2d( 151 | gf[lay], gf[lay+1], 3, 1, 1, bias=False, padding_mode='reflect')) 152 | self.bns.append(nn.BatchNorm2d(gf[lay+1])) 153 | 154 | def forward(self, x): 155 | # x = torch.cat((x, mask[:,-1].reshape(x.shape)), dim=1) 156 | for conv, bn in zip(self.convs[:-1], self.bns[:-1]): 157 | x = F.relu_(bn(self.up(conv(x)))) 158 | if config.image_type == 'n-phase': 159 | out = torch.softmax(self.convs[-1](x), dim=1) 160 | else: 161 | out = torch.sigmoid(self.convs[-1](x)) 162 | # out = torch.where((mask[:,-1]==0).unsqueeze(1).repeat(1,3,1,1,1), mask[:,0:3], out) 163 | return out # bs x n x imsize x imsize x imsize 164 | 165 | class Discriminator(nn.Module): 166 | def __init__(self): 167 | super(Discriminator, self).__init__() 168 | self.convs = nn.ModuleList() 169 | for lay, (k, s, p) in enumerate(zip(dk, ds, dp)): 170 | self.convs.append( 171 | nn.Conv2d(df[lay], df[lay + 1], k, s, p, bias=False)) 172 | 173 | def forward(self, x): 174 | for conv in self.convs[:-1]: 175 | x = F.relu_(conv(x)) 176 | x = self.convs[-1](x) # bs x 1 x 1 177 | return x 178 | 179 | return Discriminator, Generator 180 | -------------------------------------------------------------------------------- /src/inpainting/run_inpaint.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import json 3 | import matplotlib.pyplot as plt 4 | from src.inpainting.inpaint import train 5 | import numpy as np 6 | import os 7 | def inpaint_dataset(mode): 8 | if not os.path.exists('data/inpaint_runs'): 9 | os.mkdir('data/inpaint_runs') 10 | if not os.path.exists('data/temp'): 11 | os.mkdir('data/temp') 12 | if not os.path.exists('data/micrographs_final'): 13 | os.mkdir('data/micrographs_final') 14 | with open(f'data/anns.json', 'r') as f: 15 | data_map = json.load(f) 16 | for key in data_map.keys(): 17 | load_sample(data_map[key], mode) 18 | 19 | def 
load_sample(s, mode):
20 |     """Load one annotated micrograph and, in 'train' mode, launch inpainting on it."""
21 |     try:
22 |         pth = s['data_path']
23 |     except KeyError:
24 |         return
25 |     if not pth:
26 |         return
27 |     bar_box = s['barbox']
28 |     crop = s.get('crop')  # older annotation files such as prelabelled_anns.json may omit this key
29 | 
30 |     if not bar_box:
31 |         # nothing to inpaint: just crop (if requested) and save the final micrograph
32 |         img = plt.imread(pth)[..., 0]
33 |         if crop:
34 |             x0, y0 = crop[0]
35 |             x1, y1 = crop[1]
36 |             img = img[y0:y1, x0:x1]
37 |         plt.imsave(f"data/micrographs_final/{os.path.basename(pth)}", img)  # imsave needs a file path, not a directory
38 |         return
39 | 
40 |     phases = [ph / 255 for ph in s['phases']]
41 |     img = plt.imread(pth)[..., 0]
42 |     x0, y0 = bar_box[0]
43 |     x1, y1 = bar_box[1]
44 |     rect = [(y0, x0, y1, x1)]
45 |     mask = np.zeros_like(img)
46 |     mask[y0:y1, x0:x1] = 1
47 |     if s['barcol'] is not None:
48 |         # mask only the pixels inside the box whose value is close to the scale bar colour
49 |         bar_col = s['barcol'][0] / 255
50 |         bar_col_var = s['barcol'][1] / 2550
51 |         mask_ip = np.zeros_like(img)
52 |         mask_cropped = abs(img[y0:y1, x0:x1] - bar_col) < bar_col_var
53 |         mask_ip[y0:y1, x0:x1][mask_cropped] = 1
54 |     else:
55 |         # no scale bar colour selected: inpaint the whole box, dilated slightly
56 |         mask_ip = deepcopy(mask)
57 |         for sh in [1, -1]:
58 |             mask_ip += np.roll(mask_ip, sh, axis=0)
59 |             mask_ip += np.roll(mask_ip, sh, axis=1)
60 |         mask_ip[mask_ip > 1] = 1
61 |     if len(phases) != 0:
62 |         img = oh(phases, pth)
63 |         pth = 'data/temp.png'
64 |         plt.imsave(pth, img)
65 |     if crop:
66 |         x0, y0 = crop[0]
67 |         x1, y1 = crop[1]
68 |         img = img[y0:y1, x0:x1]
69 |     imtype = 'grayscale' if len(phases) == 0 else 'n-phase'
70 |     tag = s['data_path'][-7:-4]
71 |     tag = f'micro{tag}'
72 |     print(f'training {tag}')
73 |     if mode == 'train':
74 |         train(img, imtype, mask, mask_ip, rect, pth, tag)
75 |     return
76 | 
77 | 
78 | def oh(phases, pth):
79 |     """Segment the image by assigning each pixel to the nearest annotated phase value."""
80 |     img_oh = plt.imread(pth)[..., 0]
81 |     boundaries = [0]
82 |     for ph1, ph2 in zip(phases[:-1], phases[1:]):
83 |         boundaries.append(ph1 + (ph2 - ph1) / 2)
84 |     boundaries.append(1)
85 |     for i, (b_low, b_high) in enumerate(zip(boundaries[:-1], boundaries[1:])):
86 |         img_oh[(img_oh >= b_low) & (img_oh <= b_high)] = i
87 |     return img_oh
-------------------------------------------------------------------------------- /src/inpainting/util.py: --------------------------------------------------------------------------------
1 | from time import sleep
2 | 
3 | import numpy as np
4 | import torch
5 | from torch import autograd
6 | import wandb
7 | from dotenv import load_dotenv
8 | import os
9 | import subprocess
10 | import shutil
11 | import matplotlib.pyplot as plt
12 | from matplotlib.pyplot import cm
13 | from torch import nn
14 | import tifffile
15 | 
16 | # check for existing models and folders
17 | def check_existence(tag):
18 |     """Checks if model exists, then asks for user input. Returns True for overwrite, False for load. 
19 | 20 | :param tag: [description] 21 | :type tag: [type] 22 | :raises SystemExit: [description] 23 | :raises AssertionError: [description] 24 | :return: True for overwrite, False for load 25 | :rtype: [type] 26 | """ 27 | root = f'data/inpaint_runs/{tag}' 28 | check_D = os.path.exists(f'{root}/Disc.pt') 29 | check_G = os.path.exists(f'{root}/Gen.pt') 30 | if check_G or check_D: 31 | print(f'Models already exist for tag {tag}.') 32 | x = input("To overwrite existing model enter 'o', to load existing model enter 'l' or to cancel enter 'c'.\n") 33 | if x=='o': 34 | print("Overwriting") 35 | return True 36 | if x=='l': 37 | print("Loading previous model") 38 | return False 39 | elif x=='c': 40 | raise SystemExit 41 | else: 42 | raise AssertionError("Incorrect argument entered.") 43 | return True 44 | 45 | 46 | # set-up util 47 | def initialise_folders(tag, overwrite): 48 | """[summary] 49 | 50 | :param tag: [description] 51 | :type tag: [type] 52 | """ 53 | if overwrite: 54 | try: 55 | os.mkdir(f'data/inpaint_runs') 56 | except: 57 | pass 58 | try: 59 | os.mkdir(f'data/inpaint_runs/{tag}') 60 | except: 61 | pass 62 | 63 | def wandb_init(name, offline): 64 | """[summary] 65 | 66 | :param name: [description] 67 | :type name: [type] 68 | :param offline: [description] 69 | :type offline: [type] 70 | """ 71 | if offline: 72 | mode = 'disabled' 73 | else: 74 | mode = None 75 | load_dotenv(os.path.join(os.getcwd(), '.env')) 76 | API_KEY = os.getenv('WANDB_API_KEY') 77 | ENTITY = os.getenv('WANDB_ENTITY') 78 | PROJECT = os.getenv('WANDB_PROJECT') 79 | if API_KEY is None or ENTITY is None or PROJECT is None: 80 | raise AssertionError('.env file arguments missing. Make sure WANDB_API_KEY, WANDB_ENTITY and WANDB_PROJECT are present.') 81 | print("Logging into W and B using API key {}".format(API_KEY)) 82 | process = subprocess.run(["wandb", "login", API_KEY], capture_output=True) 83 | print("stderr:", process.stderr) 84 | 85 | 86 | print('initing') 87 | wandb.init(entity=ENTITY, name=name, project=PROJECT, mode=mode) 88 | 89 | wandb_config = { 90 | 'active': True, 91 | 'api_key': API_KEY, 92 | 'entity': ENTITY, 93 | 'project': PROJECT, 94 | # 'watch_called': False, 95 | 'no_cuda': False, 96 | # 'seed': 42, 97 | 'log_interval': 1000, 98 | 99 | } 100 | # wandb.watch_called = wandb_config['watch_called'] 101 | wandb.config.no_cuda = wandb_config['no_cuda'] 102 | # wandb.config.seed = wandb_config['seed'] 103 | wandb.config.log_interval = wandb_config['log_interval'] 104 | 105 | def wandb_save_models(fn): 106 | """[summary] 107 | 108 | :param pth: [description] 109 | :type pth: [type] 110 | :param fn: [description] 111 | :type fn: filename 112 | """ 113 | shutil.copy(fn, os.path.join(wandb.run.dir, fn)) 114 | wandb.save(fn) 115 | 116 | # training util 117 | def preprocess(data_path, imtype, load=True): 118 | """[summary] 119 | 120 | :param imgs: [description] 121 | :type imgs: [type] 122 | :return: [description] 123 | :rtype: [type] 124 | """ 125 | # img = tifffile.imread(data_path) 126 | img = plt.imread(data_path) 127 | if imtype == 'colour': 128 | img = img[:,:,:3] 129 | img = torch.tensor(img) 130 | return img.permute(2,0,1), 3 131 | else: 132 | if len(img.shape) > 2: 133 | img = img[...,0] 134 | if imtype == 'n-phase': 135 | phases = np.unique(img) 136 | if len(phases) > 10: 137 | raise AssertionError('Image not one hot encoded.') 138 | # x, y, z = img.shape 139 | x, y = img.shape 140 | # img_oh = torch.zeros(len(phases), x, y, z) 141 | img_oh = torch.zeros(len(phases), x, y) 142 | for i, ph in 
enumerate(phases): 143 | img_oh[i][img == ph] = 1 144 | return img_oh, len(phases) 145 | elif imtype == 'grayscale': 146 | img = np.expand_dims(img, 0) 147 | img = torch.tensor(img) 148 | return img, 1 149 | # x, y, z = img.shape 150 | 151 | 152 | def calculate_size_from_seed(seed, c): 153 | imsize = seed 154 | for (k, s, p) in zip(c.gk, c.gs, c.gp): 155 | imsize = (imsize-1)*s-2*p+k 156 | return imsize 157 | 158 | def calculate_seed_from_size(imsize, c): 159 | for (k, s, p) in zip(c.gk, c.gs, c.gp): 160 | imsize = ((imsize-k+2*p)/s+1).to(int) 161 | return imsize 162 | 163 | def make_mask(training_imgs, c): 164 | y1,y2,x1,x2 = c.mask_coords 165 | ydiff, xdiff = y2-y1, x2-x1 166 | seed = calculate_seed_from_size(torch.tensor([xdiff, ydiff]).to(int), c) 167 | img_seed = seed+2 168 | img_size = calculate_size_from_seed(img_seed, c) 169 | mask_size = calculate_size_from_seed(seed, c) 170 | D_size_dim = int(torch.div(mask_size.min(),32, rounding_mode='floor'))*16 171 | D_seed = calculate_seed_from_size(torch.tensor([D_size_dim, D_size_dim]).to(int), c) 172 | 173 | x2, y2 = x1+mask_size[0].item(), y1+mask_size[1].item() 174 | xmid, ymid = (x2+x1)//2, (y2+y1)//2 175 | x1_bound, x2_bound, y1_bound, y2_bound = xmid-img_size[0].item()//2, xmid+img_size[0].item()//2, ymid-img_size[1].item()//2, ymid+img_size[1].item()//2 176 | unmasked = training_imgs[:,x1_bound:x2_bound, y1_bound:y2_bound].clone() 177 | training_imgs[:, x1:x2, y1:y2] = 0 178 | mask = training_imgs[:,x1_bound:x2_bound, y1_bound:y2_bound] 179 | mask_layer = torch.zeros_like(training_imgs[0]).unsqueeze(0) 180 | unmasked = torch.cat([unmasked, torch.zeros_like(unmasked[0]).unsqueeze(0)]) 181 | mask_layer[:,x1:x2,y1:y2] = 1 182 | mask = torch.cat((mask, mask_layer[:,x1_bound:x2_bound, y1_bound:y2_bound])) 183 | 184 | # save coords to c 185 | c.mask_coords = (x1,x2,y1,y2) 186 | c.mask_size = (mask_size[0].item(), mask_size[1].item()) 187 | c.D_seed_x = D_seed[0].item() 188 | c.D_seed_y = D_seed[1].item() 189 | 190 | # plot regions where discriminated 191 | # plt.figure() 192 | # plotter = mask.permute(1,2,0).numpy().copy() 193 | # plotter[(img_size[0].item()-D_size_dim)//2:(img_size[0].item()+D_size_dim)//2,(img_size[1].item()-D_size_dim)//2:(img_size[1].item()+D_size_dim)//2,:] = 0 194 | # plt.imshow(plotter) 195 | # plt.savefig('data/mask_plot.png') 196 | # plt.close() 197 | 198 | # plt.imsave('data/mask.png',mask.permute(1,2,0).numpy()) 199 | # plt.imsave('data/unmasked.png',unmasked.permute(1,2,0).numpy()) 200 | return mask, unmasked, D_size_dim, img_size, img_seed, c 201 | 202 | def update_discriminator(c): 203 | out = c.dl 204 | layer = 0 205 | dk = [4] 206 | dp = [1] 207 | ds = [2] 208 | df = [c.n_phases] 209 | while out != 1: 210 | out_check = int(round((out+2*dp[layer]-dk[layer])/ds[layer]+1)) 211 | if out_check>1: 212 | out = out_check 213 | dk.append(4) 214 | ds.append(2) 215 | dp.append(1) 216 | df.append(int(np.min([2**(layer+6), 512]))) 217 | layer += 1 218 | elif out_check<1: 219 | dp[layer] = int(round((2+dk[layer]-out)/2)) 220 | else: 221 | out = out_check 222 | df.append(1) 223 | c.df = df 224 | c.dk = dk 225 | c.ds = ds 226 | c.dp = dp 227 | return c 228 | 229 | def update_pixmap_rect(raw, img, c): 230 | updated_pixmap = raw.clone().unsqueeze(0) 231 | x1, x2, y1, y2 = c.mask_coords 232 | lx, ly = c.mask_size 233 | x_1, x_2, y_1, y_2 = (img.shape[2]-lx)//2,(img.shape[2]+lx)//2, (img.shape[3]-ly)//2, (img.shape[3]+ly)//2 234 | updated_pixmap[:,:, x1:x2, y1:y2] = img[:,:,x_1:x_2, y_1:y_2] 235 | updated_pixmap = 
post_process(updated_pixmap, c).permute(0,2,3,1)
236 |     if c.image_type=='grayscale':
237 |         plt.imsave('data/temp/temp.png', updated_pixmap[0,...,0], cmap='gray')
238 |     else:
239 |         plt.imsave('data/temp/temp.png', updated_pixmap[0].numpy())
240 | 
241 | def calc_gradient_penalty(netD, real_data, fake_data, batch_size, l, device, gp_lambda, nc):
242 |     """Computes the WGAN-GP gradient penalty on interpolates of real and fake batches.
243 | 
244 |     :param netD: discriminator network
245 |     :type netD: nn.Module
246 |     :param real_data: batch of real image patches
247 |     :type real_data: torch.Tensor
248 |     :param fake_data: batch of generated image patches
249 |     :type fake_data: torch.Tensor
250 |     :param batch_size: number of samples in the batch
251 |     :type batch_size: int
252 |     :param l: patch side length
253 |     :type l: int
254 |     :param device: device to compute on
255 |     :type device: torch.device
256 |     :param gp_lambda: gradient penalty coefficient
257 |     :type gp_lambda: float
258 |     :param nc: number of image channels/phases
259 |     :type nc: int
260 |     :return: gradient penalty term
261 |     :rtype: torch.Tensor
262 |     """
263 |     alpha = torch.rand(batch_size, 1)
264 |     alpha = alpha.expand(batch_size, int(
265 |         real_data.nelement() / batch_size)).contiguous()
266 |     alpha = alpha.view(batch_size, nc, l, l)
267 |     alpha = alpha.to(device)
268 | 
269 |     interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
270 |     interpolates = interpolates.to(device)
271 |     interpolates.requires_grad_(True)
272 |     disc_interpolates = netD(interpolates)
273 |     gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
274 |                               grad_outputs=torch.ones(
275 |                                   disc_interpolates.size()).to(device),
276 |                               create_graph=True, only_inputs=True)[0]
277 | 
278 |     gradients = gradients.view(gradients.size(0), -1)
279 |     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
280 |     return gradient_penalty
281 | 
282 | def batch_real_poly(img, l, bs, real_seeds):
283 |     n_ph, _, _ = img.shape
284 |     max_idx = len(real_seeds[0])
285 |     idxs = torch.randint(max_idx, (bs,))
286 |     data = torch.zeros((bs, n_ph, l, l))
287 |     for i, idx in enumerate(idxs):
288 |         x, y = real_seeds[0][idx], real_seeds[1][idx]
289 |         data[i] = img[:, x:x+l, y:y+l]
290 |     return data
291 | 
292 | def batch_real(img, l, bs, mask_coords):
293 |     """Samples a batch of l x l patches that avoid the masked region.
294 |     :param img: source training image of shape (n_ph, H, W)
295 |     :type img: torch.Tensor
296 |     :return: batch of real patches of shape (bs, n_ph, l, l)
297 |     :rtype: torch.Tensor
298 |     """
299 |     x1, x2, y1, y2 = mask_coords
300 |     n_ph, x_max, y_max = img.shape
301 |     data = torch.zeros((bs, n_ph, l, l))
302 |     for i in range(bs):
303 |         x, y = torch.randint(x_max - l, (1,)), torch.randint(y_max - l, (1,))
304 |         while (x1 < x + l) and (x2 > x - l) and (y1 < y + l) and (y2 > y - l):  # resample while the patch overlaps the masked rectangle
305 |             x, y = torch.randint(x_max - l, (1,)), torch.randint(y_max - l, (1,))
306 |         data[i] = img[:, x:x+l, y:y+l]
307 |     return data
308 | 
309 | def pixel_wise_loss(fake_img, real_img, coeff=1, device=None):
310 |     mask = real_img.clone().permute(1,2,0)
311 |     mask = (mask[...,-1]==0).unsqueeze(0)
312 |     mask = mask.repeat(fake_img.shape[0], fake_img.shape[1],1,1)
313 |     fake_img = torch.where(mask==True, fake_img, torch.tensor(0).float().to(device))
314 |     real_img = real_img.unsqueeze(0).repeat(fake_img.shape[0], 1 ,1, 1)[:,0:-1]
315 |     real_img = torch.where(mask==True, real_img, torch.tensor(0).float().to(device))
316 |     return torch.nn.MSELoss(reduction='none')(fake_img, real_img)*coeff
317 | 
318 | # Evaluation util
319 | def post_process(img, c):
320 |     """Turns an n-phase image (bs, n, imsize, imsize) into a plottable RGB image (bs, 3, imsize, imsize)
321 | 
322 |     :param img: a tensor of the n phase img
323 |     :type img: torch.Tensor
324 |     :return: colour-mapped image batch
325 |     :rtype: torch.Tensor
326 |     """
327 |     img = 
img.detach().cpu() 328 | if c.image_type=='n-phase': 329 | phases = np.arange(c.n_phases) 330 | color = iter(cm.rainbow(np.linspace(0, 1, c.n_phases))) 331 | # color = iter([[0,0,0],[0.5,0.5,0.5], [1,1,1]]) 332 | img = torch.argmax(img, dim=1) 333 | if len(phases) > 10: 334 | raise AssertionError('Image not one hot encoded.') 335 | bs, x, y = img.shape 336 | out = torch.zeros((bs, 3, x, y)) 337 | for b in range(bs): 338 | for i, ph in enumerate(phases): 339 | col = next(color) 340 | col = torch.tile(torch.Tensor(col[0:3]).unsqueeze(1).unsqueeze(1), (x,y)) 341 | out[b] = torch.where((img[b] == ph), col, out[b]) 342 | out = out 343 | else: 344 | out = img 345 | return out 346 | 347 | def crop(fake_data, l): 348 | w = fake_data.shape[2] 349 | return fake_data[:,:,w//2-l//2:w//2+l//2,w//2-l//2:w//2+l//2] 350 | 351 | def make_noise(noise, seed_x, seed_y, c, device): 352 | # noise = torch.ones(bs, nz, seed_x, seed_y, device=device) 353 | mask = torch.zeros_like(noise).to(device) 354 | mask[:,:, (seed_x-c.D_seed_x)//2:(seed_x+c.D_seed_x)//2, (seed_y-c.D_seed_y)//2:(seed_y+c.D_seed_y)//2] = 1 355 | rand = torch.randn_like(noise).to(device)*mask 356 | noise = noise*(mask==0)+rand 357 | return noise -------------------------------------------------------------------------------- /src/postprocessing/mergegifs.py: -------------------------------------------------------------------------------- 1 | from imageio import get_reader, get_writer 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | from random import shuffle 6 | #Create reader object for the gif 7 | gifs_all = [get_reader(f'runs/{file}/movie.gif') for file in sorted(os.listdir('runs'))] 8 | gifs = [] 9 | for gif in gifs_all: 10 | if len(np.unique(gif.get_data(0))) < 10: 11 | print(gif.get_data(0).shape) 12 | gifs.append(gif) 13 | 14 | ngifs = len(gifs) 15 | gifs = gifs[5:-1] 16 | print(f'{ngifs} gifs loaded') 17 | max_width = 8 18 | x, y = gifs[0].get_data(0).shape[:2] 19 | n_frames = gifs[0].get_length() 20 | #If they don't have the same number of frame take the shorter 21 | 22 | #Create writer object 23 | new_gif = get_writer('output_short2.gif', duration = 0.1) 24 | print(f'max width {max_width}') 25 | for width in range(1, max_width + 1): 26 | print(width) 27 | shuffle(gifs) 28 | for frame_number in range(n_frames): 29 | frames = [] 30 | for gif in gifs[:width**2]: 31 | frame = gif.get_data(frame_number) 32 | 33 | # if len(frame.shape) > 2: 34 | # frame = frame[...,0] 35 | # frame = frame[16:-16, 16:-16] 36 | 37 | x, y = frame.shape[:2] 38 | 39 | # new_frame = np.zeros((maxdim, maxdim, frame.shape[-1]), dtype=np.uint8) 40 | ys = (y-x)//2 41 | frame = frame[:, ys:ys+x] 42 | frame =np.array(Image.fromarray(frame).resize(size=(496,496))) 43 | border_frame = np.zeros((512, 512, frame.shape[-1])) 44 | border_frame[8:-8, 8:-8] = frame 45 | frames.append(border_frame) 46 | rows = [] 47 | for i in range(width): 48 | st = i 49 | fin = i+1 50 | rows.append(np.hstack(frames[st*width:fin*width])) 51 | img = np.vstack(rows) 52 | 53 | new_gif.append_data(np.array(Image.fromarray(img.astype(np.uint8)).resize(size=(512,512)))) 54 | import moviepy.editor as mp 55 | 56 | clip = mp.VideoFileClip("output_short2.gif") 57 | clip = clip.speedx(final_duration=40) 58 | clip.write_videofile("inpaint_concat2.mp4", fps=24) 59 | # new_gif.close() -------------------------------------------------------------------------------- /src/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.import_data import *
2 | from .annotation_gui import *
--------------------------------------------------------------------------------
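The annotation GUI below saves one record per micrograph to data/anns.json (see the save step in its `keyPressEvent`). For orientation, here is a sketch of a single entry; the keys match the code, but the file name and all values are made up for illustration:

```
# illustrative anns.json entry; field values are hypothetical
example_entry = {
    "1": {                                               # stringified save count
        "data_path": "data/micrographs_raw/000042.png",  # hypothetical file
        "crop": [[10, 10], [500, 400]],     # top-left / bottom-right clicks
        "barbox": [[20, 380], [150, 395]],  # region containing the scale bar
        "voxsize": 0.5,                     # microns per pixel: sb / bar length in px
        "barcol": [230, 5],                 # bar colour and threshold, scaled to 0-255
        "phases": [60, 180],                # phase thresholds, scaled to 0-255
        "rot": 0,                           # rotation in degrees
        "data_type": "twophase",            # 'grayscale' if no phases were clicked
    }
}
```

data/prelabelled_anns.json, which can stand in for anns.json, follows the same schema.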
/src/preprocessing/annotation_gui.py:
--------------------------------------------------------------------------------
1 | 
2 | import sys
3 | from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QPushButton, QFileDialog, QAction, QComboBox, QLabel, QInputDialog
4 | from PyQt5.QtGui import QColor, QBrush, QPainter, QPixmap, QPolygonF, QPen
5 | from PyQt5.QtCore import QPoint, QRect, QPointF, QThread, QTimeLine, Qt
6 | import matplotlib.pyplot as plt
7 | 
8 | import numpy as np
9 | import os
10 | import json
11 | from scipy import ndimage
12 | # GUI for identifying scale bars, crop regions, phases and voxel sizes in raw micrographs
13 | class MainWindow(QMainWindow):
14 |     def __init__(self, app, ow):
15 |         super().__init__()
16 |         self.app = app
17 |         self.ow = ow
18 |         self.initUI()
19 | 
20 |     def initUI(self):
21 |         self.setWindowTitle('Microstructure Inpainter')
22 |         self.painter_widget = PainterWidget(self)
23 |         self.setCentralWidget(self.painter_widget)
24 |         self.setGeometry(30, 30, self.painter_widget.image.width(), self.painter_widget.image.height()+50)
25 |         self.show()
26 | 
27 |     def keyPressEvent(self, event):
28 |         self.painter_widget.keyPressEvent(event)
29 | 
30 | class PainterWidget(QWidget):
31 |     def __init__(self, parent):
32 |         super(PainterWidget, self).__init__(parent)
33 |         self.parent = parent
34 |         screen = self.parent.app.primaryScreen()
35 |         self.screensize = screen.size()
36 |         self.data_path = 'data/micrographs_raw'
37 |         self.img_list = sorted(os.listdir(self.data_path))
38 |         # with open(f'data/good_micros.json', 'r') as f:
39 |         #     self.good = json.load(f)
40 |         self.data_map = {}
41 |         self.current_img = 0
42 | 
43 |         try:
44 |             with open(f'data/anns.json', 'r') as f:
45 |                 self.data_map = json.load(f)
46 |             print('loading')
47 |             last_saved = list(self.data_map)[-1]
48 |             last_saved_pth = self.data_map[last_saved]['data_path'].split('/')[-1]
49 |             # print(last_saved_pth, self.img_list.index(last_saved_pth))
50 |             self.current_img = self.img_list.index(last_saved_pth) + 1
51 |         except:
52 |             pass
53 |         self.img_name = self.img_list[self.current_img]
54 |         self.img_path = f'{self.data_path}/{self.img_name}'
55 |         self.img = plt.imread(self.img_path)
56 |         self.setPixmap(self.img_path, loading=True)
57 |         self.boundary=0.01
58 |         self.phases = []
59 |         self.selected_phase = 1
60 |         self.opaque = 1
61 |         self.rot=0
62 |         self.stage = QLabel('scale bar col: click on the scale bar then use A and S keys to adjust thresholds, or press enter to skip')
63 |         self.modes = ['scale bar col: click on the scale bar then use A and S keys to adjust thresholds, or press enter to skip if no scalebar',
64 |                       'scale bar box: click the top left then bottom right corners of the region containing the scale bar, or press enter if no scalebar',
65 |                       'crop region: click the top left then bottom right corner to give the region to keep. Press enter to not crop',
66 |                       'Click on the different phases to segment. Use A and S to adjust threshold',
67 |                       'voxel size: click on the left of the scalebar, then the right, then enter scale bar size in microns',
68 |                       'scale bar col: click on the scale bar then use A and S keys to adjust thresholds, or press enter to skip', ]
69 |         label = parent.addToolBar('stage')
70 |         label.addWidget(self.stage)
71 |         self.parent.addToolBarBreak()
72 |         self.general = QLabel('At any time, press C to restart the microstructure, or W to remove the microstructure')
73 |         label = parent.addToolBar('general')
74 |         label.addWidget(self.general)
75 |         self.clear()
76 |         self.show()
77 | 
78 | 
79 |     def setPixmap(self, fp, loading=True):
80 |         self.image = QPixmap(fp)
81 |         if loading:
82 |             x, y = self.screensize.width(), self.screensize.height()
83 |             imgx, imgy = self.image.width(), self.image.height()
84 |             xfrac, yfrac = imgx/x, imgy/y
85 |             w = x * 0.8 if xfrac == yfrac else imgx * 0.8 * 1/yfrac
86 |             self.scaled_image_width = int(w)
87 |             self.scale_factor = imgx/w
88 | 
89 |         self.image = self.image.scaledToWidth(self.scaled_image_width)
90 |         self.parent.setGeometry(30, 30, self.image.width(), self.image.height()+50)
91 |         self.update()
92 | 
93 |     def paintEvent(self, event):
94 |         qp = QPainter(self)
95 |         self.resize(self.image.width(), self.image.height())
96 |         qp.drawPixmap(self.rect(), self.image)
97 |         br = QBrush(QColor(100, 10, 10, 10))
98 |         pen = QPen(QColor(0, 0, 0, 255), 1.5)
99 |         qp.setBrush(br)
100 |         qp.setPen(pen)
101 | 
102 |     def mousePressEvent(self, event):
103 |         x, y = event.pos().x(), event.pos().y()
104 |         x, y = self.convert_coords(x, y)
105 |         self.x, self.y = x, y
106 |         if self.mode==0:
107 |             self.select_barcol()
108 |         if self.mode==1:
109 |             self.apply_barbox(x, y)
110 |         if self.mode==2:
111 |             self.apply_crop(x, y)
112 |         if self.mode==3:
113 |             self.add_phase()
114 |         if self.mode==4:
115 |             self.set_voxsize(x, y)
116 | 
117 |     def convert_coords(self, x, y):
118 |         y0, x0 = self.img.shape[:2]
119 |         h = self.frameGeometry().width()
120 |         w = self.frameGeometry().height()
121 |         xn = int(x*x0/h)
122 |         yn = int(y*y0/w)
123 |         # print('image shape', x0, y0, 'frame shape', h, w, 'click', x, y, 'newclick', xn, yn)
124 |         return xn, yn
125 | 
126 |     def apply_crop(self, x, y):
127 |         if len(self.crop)<2:
128 |             self.crop.append((x, y))
129 |             if len(self.crop)==2:
130 |                 x0, y0 = self.crop[0]
131 |                 self.img = self.img[y0:y, x0:x]
132 |                 self.load_temp()
133 |         else:
134 |             self.crop = []
135 |             self.img = plt.imread(self.img_path)
136 |             self.setPixmap(self.img_path)
137 | 
138 |     def set_voxsize(self, x, y):
139 |         if len(self.voxsize)<2:
140 |             self.voxsize.append(x)
141 |             if len(self.voxsize)==2:
142 |                 self.get_sb()
143 | 
144 | 
145 |     def get_sb(self):
146 |         num,ok = QInputDialog.getInt(self,"enter the scalebar value in microns","enter the scalebar value in microns")
147 |         if ok:
148 |             self.sb = (int(num))
149 | 
150 |     def apply_barbox(self, x, y):
151 |         if len(self.bar_box)<2:
152 |             self.bar_box.append((x, y))
153 |             if len(self.bar_box)==2:
154 |                 x0, y0 = self.bar_box[0]
155 |                 self.img[y0:y, x0:x] -= 0.1
156 |                 self.img[self.img< 0] = 0
157 |                 self.load_temp()
158 |         else:
159 |             self.bar_box = []
160 |             self.reload_img()
161 | 
162 |     def select_barcol(self):
163 |         self.reload_img()
164 |         self.bar_col = self.img[self.y, self.x]
165 |         mask = np.where((abs(self.img - self.bar_col) < self.boundary).all(-1))
166 |         self.img[mask] += 0.3
167 |         self.img[self.img > 1] = 1
168 |         self.img[self.img < 0] = 0
169 |         self.load_temp()
170 | 
171 |     def add_phase(self):
172 |         self.phases.append(self.img[self.y, self.x, 0])
173 |         self.selected_phase = len(self.phases) - 1
174 |         self.show_phases()
175 | 
176 |     def rotateimg(self):
177 |         self.img = ndimage.rotate(plt.imread(self.img_path), self.rot, reshape=False)
178 |         self.img[self.img > 1] = 1
179 |         self.img[self.img < 0] = 0
180 |         self.load_temp()
181 | 
182 |     def show_phases(self):
183 |         self.reload_img()
184 |         if len(self.phases)==1:
185 |             self.img[...,0][self.img[...,0] < self.phases[0]] +=0.1
186 |         else:
187 | 
self.phases = sorted(self.phases) 188 | boundaries = [0] 189 | for ph1, ph2 in zip(self.phases[:-1], self.phases[1:]): 190 | boundaries.append(ph1 + (ph2 - ph1)/2) 191 | boundaries.append(1) 192 | for i, (b_low, b_high) in enumerate(zip(boundaries[:-1],boundaries[1:])): 193 | self.img[..., i][(self.img[...,i] >= b_low) & (self.img[...,i] <= b_high)] += 1 if self.opaque else 0 194 | 195 | self.img[self.img>1] = 1 196 | self.img[self.img<0] = 0 197 | self.load_temp() 198 | 199 | def reload_img(self): 200 | try: 201 | self.img = plt.imread(self.img_path) 202 | self.setPixmap(self.img_path) 203 | except: 204 | print(f'failed to load image {self.current_img}') 205 | self.update_current_img() 206 | self.reload_img() 207 | 208 | def update_current_img(self): 209 | self.current_img += 1 210 | self.img_name = self.img_list[self.current_img] 211 | self.img_path = f'{self.data_path}/{self.img_name}' 212 | # while self.current_img in self.good: 213 | # self.current_img +=1 214 | print(self.current_img) 215 | 216 | def load_temp(self): 217 | plt.imsave('data/temp.png', self.img) 218 | self.setPixmap('data/temp.png') 219 | 220 | 221 | def clear(self): 222 | self.crop = [] 223 | self.sb = 0 224 | self.voxsize = [] 225 | self.bar_box = [] 226 | self.bar_col = [None] 227 | self.phases = [] 228 | self.rot = 0 229 | self.mode = 0 230 | self.stage.setText(self.modes[self.mode]) 231 | self.reload_img() 232 | 233 | def keyPressEvent(self, event): 234 | if event.key() == Qt.Key_Return: 235 | 236 | self.mode += 1 237 | self.stage.setText(self.modes[self.mode]) 238 | 239 | if self.mode==5: 240 | info = {} 241 | pth = self.img_path 242 | info['data_path'] = pth 243 | info['crop'] = self.crop 244 | info['barbox'] = self.bar_box 245 | info['voxsize'] = self.sb / (self.voxsize[1] - self.voxsize[0]) if len(self.voxsize)==2 else 0 246 | info['barcol'] = (int(self.bar_col[0]*255), int(self.boundary*255)) if self.bar_col[0] != None else None 247 | info['phases'] = [int(ph*255) for ph in self.phases] 248 | info['rot'] = self.rot 249 | info['data_type'] = 'grayscale' if len(self.phases) == 0 else 'twophase' 250 | 251 | key = len(self.data_map.keys()) + 1 252 | print(key) 253 | self.data_map[key] = info 254 | with open(f'data/anns.json', 'w') as f: 255 | json.dump(self.data_map, f) 256 | self.clear() 257 | self.update_current_img() 258 | 259 | print(f'{self.current_img}/{len(self.img_list)} complete') 260 | self.reload_img() 261 | 262 | elif event.key() == Qt.Key_C: 263 | self.clear() 264 | self.reload_img() 265 | 266 | 267 | elif event.key() == Qt.Key_W: 268 | try: 269 | self.data_map.pop(str(self.current_img)) 270 | except: 271 | pass 272 | with open(f'data/anns.json', 'w') as f: 273 | json.dump(self.data_map, f) 274 | self.clear() 275 | self.update_current_img() 276 | 277 | self.reload_img() 278 | 279 | elif event.key() == Qt.Key_Q: 280 | self.mode = 'skip' 281 | elif event.key() == Qt.Key_A: 282 | if self.mode==0: 283 | self.boundary +=0.01 284 | self.select_barcol() 285 | 286 | if self.mode==3: 287 | if self.phases[self.selected_phase] < 1: 288 | self.phases[self.selected_phase] +=0.02 289 | self.show_phases() 290 | 291 | if self.mode==4: 292 | self.rot+=2 293 | self.rotateimg() 294 | 295 | 296 | elif event.key() == Qt.Key_S: 297 | if self.mode==0: 298 | self.boundary -=0.01 299 | self.select_barcol() 300 | 301 | if self.mode==3: 302 | if self.phases[self.selected_phase] > 0: 303 | self.phases[self.selected_phase] -=0.02 304 | self.show_phases() 305 | 306 | if self.mode==4: 307 | self.rot-=2 308 | self.rotateimg() 309 | 
310 |         elif event.text().isnumeric():
311 |             self.selected_phase = int(event.text()) - 1
312 | 
313 |         elif event.key() == Qt.Key_X:
314 |             self.opaque = 0 if self.opaque else 1
315 |             self.show_phases()
316 | 
317 | 
318 | def preprocess_gui(ow=False):
319 |     app = QApplication(sys.argv)
320 |     qss="src/style.qss"
321 |     with open(qss,"r") as fh:
322 |         app.setStyleSheet(fh.read())
323 |     window = MainWindow(app, ow)
324 |     sys.exit(app.exec_())
325 | 
--------------------------------------------------------------------------------
/src/preprocessing/import_data.py:
--------------------------------------------------------------------------------
1 | import requests  # to get image from the web
2 | import shutil  # to save it locally
3 | import os
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | ## Set up the image URL and filename
7 | def import_data():
8 |     f = 'data/micrographs_raw'
9 |     if not os.path.exists(f):
10 |         os.mkdir(f)
11 |     for i in range(1100):
12 |         tag = '000000' + str(i)
13 |         tag = tag[-6:]
14 | 
15 |         image_url = f"https://www.doitpoms.ac.uk/miclib/micrographs/large/{tag}.jpg"
16 |         filename = image_url.split("/")[-1]
17 |         filename = f'data/micrographs_raw/{filename}'
18 |         # Open the url image, set stream to True, this will return the stream content.
19 |         r = requests.get(image_url, stream = True)
20 |         # if you get certificate errors, add verify=False *at your own risk*: r = requests.get(image_url, stream = True, verify=False)
21 | 
22 |         # Check if the image was retrieved successfully
23 |         if r.status_code == 200:
24 |             # Set decode_content value to True, otherwise the downloaded image file's size will be zero.
25 |             r.raw.decode_content = True
26 | 
27 |             # Open a local file with wb ( write binary ) permission.
28 |             with open(filename,'wb') as f:
29 |                 shutil.copyfileobj(r.raw, f)
30 | 
31 |             img = plt.imread(filename)
32 |             if len(img.shape) < 3:
33 |                 img = np.dstack([np.array(img)]*3 )
34 |                 print(img.shape)
35 |             plt.imsave(filename[:-3] + 'png', img)
36 |             os.remove(filename)
37 |             print('Image successfully downloaded: ',filename)
38 |         else:
39 |             print('Image couldn\'t be retrieved')
40 | 
--------------------------------------------------------------------------------
/src/preprocessing/style.qss:
--------------------------------------------------------------------------------
1 | * {
2 |     font-size:13px;
3 |     color:#c1c1c1;
4 |     font-family:"Helvetica";
5 |     background:black;
6 | }
7 | 
8 | 
9 | QToolBar {
10 |     background-color: black;
11 |     border: none;
12 | }
13 | 
14 | QToolButton {
15 |     background:white;
16 |     color:black;
17 |     margin: 5px;
18 |     margin-top: 12px;
19 |     margin-bottom: 12px;
20 |     padding: 7px;
21 |     border-radius: 2px;
22 |     border: 2px;
23 | 
24 | }
25 | 
26 | QPushButton {
27 |     background:white;
28 |     color:black;
29 |     margin: 5px;
30 |     margin-top: 12px;
31 |     margin-bottom: 12px;
32 |     padding: 7px;
33 |     border-radius: 2px;
34 |     border: 2px;
35 | 
36 | }
37 | 
--------------------------------------------------------------------------------
/src/slicegan/__init__.py:
--------------------------------------------------------------------------------
1 | from .networks import *
2 | from .model import *
3 | from .preprocessing import *
4 | from .util import *
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
/src/slicegan/animations.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | from src.slicegan import networks, util
5 | import argparse
6 | import torch
7 | import tifffile
8 | 
import numpy as np 9 | from plotoptix.materials import make_material 10 | import plotoptix.materials as m 11 | from plotoptix import NpOptiX, TkOptiX 12 | from scipy import ndimage 13 | import moviepy.editor as mp 14 | from moviepy.editor import * 15 | import matplotlib.pyplot as plt 16 | import imageio 17 | import os 18 | from time import time 19 | from time import sleep 20 | import shutil 21 | 22 | class Animator(): 23 | def __init__(self): 24 | # Define project name 25 | res = 512 26 | min_accum = 500 27 | self.optix = NpOptiX(on_scene_compute=self.compute, 28 | on_rt_completed=self.update, 29 | width=res, height=res, 30 | start_now=False) 31 | 32 | self.optix.set_param(min_accumulation_step=min_accum, 33 | # 1 animation frame = 128 accumulation frames 34 | max_accumulation_frames=5130, 35 | light_shading="Hard") # accumulate 512 frames when paused 36 | self.optix.set_uint("path_seg_range", 5, 10) 37 | exposure = 1 38 | gamma = 2.3 39 | self.optix.set_float("tonemap_exposure", exposure) # sRGB tonning 40 | self.optix.set_float("tonemap_gamma", gamma) 41 | self.optix.set_float("denoiser_blend", 0.25) 42 | self.optix.add_postproc("Denoiser") 43 | self.optix.set_background(250) 44 | alpha = np.full((1, 1, 4), 1, dtype=np.float32) 45 | self.optix.set_texture_2d("mask", (255 * alpha).astype(np.uint8)) 46 | m_diffuse_3 = make_material("Diffuse", color_tex="mask") 47 | self.optix.setup_material("3", m_diffuse_3) 48 | self.optix.start() 49 | self.optix.pause_compute() 50 | 51 | def new_animation(self, micro): 52 | self.Project_name = micro 53 | # Specify project folder. 54 | self.Project_path = f'data/slicegan_runs/{micro}' 55 | if not os.path.exists(f'{self.Project_path}/frames'): 56 | os.mkdir(f'{self.Project_path}/frames') 57 | 58 | self.frame_path = f'{self.Project_path}/frames' 59 | self.Project_path = f'data/slicegan_runs/{micro}/{micro}' 60 | img_size, img_channels, scale_factor = 64, 1, 1 61 | z_channels = 16 62 | lays = 6 63 | dk, gk = [4] * lays, [4] * lays 64 | ds, gs = [2] * lays, [2] * lays 65 | df, gf = [img_channels, 64, 128, 256, 512, 1], [z_channels, 512, 256, 128, 64, 66 | img_channels] 67 | dp, gp = [1, 1, 1, 1, 0], [2, 2, 2, 2, 3] 68 | 69 | ## Create Networks 70 | netD, netG = networks.slicegan_nets(self.Project_path, False, 'grayscale', dk, ds, 71 | df, dp, gk, gs, gf, gp) 72 | netG = netG() 73 | netG.eval() 74 | lf = 12 75 | n = (lf - 2) * 32 76 | noise = torch.randn(1, z_channels, lf, lf, lf) 77 | netG = netG.cuda() 78 | noise = noise.cuda() 79 | nseeds = 10 80 | netG.load_state_dict(torch.load(self.Project_path + '_Gen.pt')) 81 | 82 | img = netG(noise[0].unsqueeze(0)) 83 | image_type = 'twophase' if img.shape[1] == 2 else 'grayscale' 84 | # image_type = 'colour' 85 | 86 | img = util.post_proc(img, image_type) 87 | self.img = img 88 | if image_type == 'twophase': 89 | self.ph = 0 if np.mean(img) > 0.5 else 1 90 | if image_type == 'twophase': 91 | bind = np.array(np.where(img == self.ph)).T 92 | c = 1 - (bind + 0.5) 93 | c = (0.3, 0.3, 0.3) 94 | elif image_type=='colour': 95 | bind = np.array(np.where(img[...,0] != -1)).T 96 | c = (img.reshape(-1, 3)) / 255 97 | print(c.shape, bind.shape) 98 | 99 | else: 100 | # img = ndimage.gaussian_filter(img, blur) 101 | bind = np.array(np.where(img != -1)).T 102 | # bind[:,1][img.reshape(-1) < 0] +=1000 103 | c = (img.reshape(-1)) / 255 104 | bind = bind / n - 0.5 105 | tf = 360 106 | self.f = 0 107 | self.fn=0 108 | self.s = int(360 / tf) 109 | self.e = [-3, 0, 0] 110 | self.l = [-5, 10, 0] 111 | self.bind = bind 112 | self.c = c 113 
| self.fin = False
114 |         self.rotating=False
115 |         self.n = img.shape[0]
116 |         self.image_type = image_type
117 |         s = 1 / self.n
118 |         self.optix.set_data("cubes_b", pos=self.bind, u=[s, 0, 0], v=[0, s, 0], w=[0, 0, s],
119 |                             geom="Parallelepipeds",  # cubes, actually default geometry
120 |                             mat="3",  # opaque, mat, default
121 |                             c=self.c)
122 |         self.optix.setup_camera("cam1", eye=self.e, target=[0, 0, 0], up=[0, 1, 0],
123 |                                 fov=30)
124 |         self.optix.set_ambient((0.3, 0.3, 0.3))
125 |         x = self.n / 2
126 |         self.optix.resume_compute()
127 | 
128 |     def compute(self, rt: NpOptiX,
129 |                 delta: int) -> None:  # compute scene updates in parallel to the raytracing
130 |         self.fn+=1
131 |         if not self.rotating:
132 |             self.f += self.s
133 |             img = self.img[-self.f:]
134 |             # print(img.shape)
135 |             # print(np.array(np.where(img == self.ph)).T.shape)
136 | 
137 |             if self.image_type == 'twophase':
138 |                 bind = np.array(np.where(img == self.ph)).T
139 |                 c = 1 - (bind + 0.5)
140 |                 c = (0.3, 0.3, 0.3)
141 |             elif self.image_type=='colour':
142 |                 bind = np.array(np.where(img[...,0] != -1)).T
143 |                 self.c = (img.reshape(-1, 3)) / 255
144 | 
145 |             else:
146 |                 # img = ndimage.gaussian_filter(img, blur)
147 |                 bind = np.array(np.where(img != -1)).T
148 |                 # bind[:,1][img.reshape(-1) < 0] +=1000
149 |                 self.c = (img.reshape(-1)) / 255
150 |             self.bind = bind / self.n - 0.5
151 |             if self.f==self.n:
152 |                 self.f=0
153 |                 self.rotating = True
154 |             self.e = [-3, min(1.2 * (self.f/self.n), 1.2), 0]
155 |             self.l = [-5, 10,0]
156 |             # self.bind = bind/n - 0.5
157 |         else:
158 |             self.f += self.s
159 |             f_step = self.f * np.pi * 2 / 360
160 |             # self.e = [0.5*np.cos(self.f/360), 12, 20*np.sin(self.f/360)]
161 |             x, y = np.cos(f_step), np.sin(f_step)
162 | 
163 |             self.e = [-3 * x, 1.2, -3 * y]
164 |             self.l = [-5 * x, 5, -5 * y]
165 | 
166 | 
167 |     # optionally, save every frame to a separate file using save_image() method
168 |     def update(self, rt: NpOptiX) -> None:
169 |         rt.update_camera('cam1', eye=self.e)
170 |         # rt.update_light('light1', pos=self.l)
171 |         rt.set_data("cubes_b", pos=self.bind, u=[1 / self.n, 0, 0], v=[0, 1 / self.n, 0],
172 |                     w=[0, 0, 1 / self.n],
173 |                     geom="Parallelepipeds",  # cubes, actually default geometry
174 |                     mat="diffuse",  # opaque, mat, default
175 |                     c=self.c)
176 |         # rt.update_light('light1', pos=self.l)
177 |         # print("frames/frame_{:05d}.png".format(self.f))
178 | 
179 |         # self.optix.close()
180 |         # raise Error
181 |         rt.save_image(self.frame_path + '/frame_{:05d}.png'.format(self.fn))
182 |         if self.f == 360:
183 |             self.optix.pause_compute()
184 |             self.save_animation()
185 |             # rt.close()
186 |             self.fin = True
187 | 
188 | 
189 |     def save_animation(self):
190 | 
191 |         frames = sorted(os.listdir(self.frame_path))[1:]
192 |         frames = frames[:319] + frames[320:]
193 | 
194 |         end_frames = frames[:319]
195 |         end_frames.reverse()
196 | 
197 |         frames = frames + end_frames
198 |         fps = 45
199 |         frame_duration = 1 / fps
200 | 
201 |         clips = [
202 |             ImageClip(f'{self.frame_path}/{m}').set_duration(frame_duration)
203 |             for m in frames]
204 |         clips[0] = clips[0].set_duration(1)
205 |         clip = concatenate_videoclips(clips, method="compose")
206 |         clip.write_videofile(f'{self.Project_path}_long.mp4', fps=fps, verbose=False, logger=None, ffmpeg_params=['-movflags', 'faststart'])
207 | def animate_dataset():
208 |     dir = f'data/slicegan_runs'
209 |     micros = sorted(os.listdir(dir))
210 |     print(len(micros))
211 |     a = Animator()
212 |     for micro in micros:
213 | 
214 |         a.new_animation(micro)
215 |         while not a.fin:
216 |             sleep(1)
217 |             pass
218 |         print(f'{micro} finished')
219 |         a.fin = False
220 |     a.optix.close()
221 | 
222 | 
223 | 
--------------------------------------------------------------------------------
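model.py below trains a single 3D generator against three 2D discriminators by slicing each generated volume along x, y and z. The central trick is the `permute` + `reshape` in the training loop, which folds one spatial axis into the batch dimension so that a volume becomes a batch of 2D images. A minimal, self-contained sketch of that step (shapes chosen for illustration, not taken from a real run):

```
import torch

bs, nc, l = 2, 3, 64                 # batch size, phase channels, image size
fake = torch.randn(bs, nc, l, l, l)  # stand-in for a generator output volume

# fold the first spatial axis (d1 = 2) into the batch: every x-slice of every
# volume becomes one 2D image for the x-direction discriminator
slices = fake.permute(0, 2, 1, 3, 4).reshape(l * bs, nc, l, l)
print(slices.shape)                  # torch.Size([128, 3, 64, 64])
```

The (d1, d2, d3) triples in the loop cycle this operation over all three axes.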
/src/slicegan/model.py:
--------------------------------------------------------------------------------
1 | from src.slicegan import preprocessing, util
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | import torch.optim as optim
7 | import time
8 | import matplotlib
9 | 
10 | def train(pth, imtype, datatype, real_data, Disc, Gen, nc, l, nz, sf):
11 |     """
12 |     train the generator
13 |     :param pth: path to save all files, imgs and data
14 |     :param imtype: image type e.g. nphase, colour or gray
15 |     :param datatype: training data format e.g. tif, jpg etc.
16 |     :param real_data: path to training data
17 |     :param Disc: Discriminator class
18 |     :param Gen: Generator class
19 |     :param nc: channels
20 |     :param l: image size
21 |     :param nz: latent vector size
22 |     :param sf: scale factor for training data
23 |     :return:
24 |     """
25 |     if len(real_data) == 1:
26 |         real_data *= 3
27 |         isotropic = True
28 |     else:
29 |         isotropic = False
30 | 
31 |     print('Loading Dataset...')
32 |     dataset_xyz = preprocessing.batch(real_data, datatype, l, sf)
33 | 
34 |     ## Constants for NNs
35 |     matplotlib.use('Agg')
36 |     ngpu = 1
37 |     num_epochs = 1
38 | 
39 |     # batch sizes
40 |     batch_size = 32
41 |     D_batch_size = 8
42 |     # optimiser params for G and D
43 |     lrg = 0.0001
44 |     lrd = 0.0001
45 |     beta1 = 0
46 |     beta2 = 0.9
47 |     Lambda = 10
48 |     critic_iters = 5
49 |     cudnn.benchmark = True
50 |     workers = 0
51 |     lz = 4
52 |     ##Dataloaders for each orientation
53 |     device = torch.device("cuda:0" if(torch.cuda.is_available() and ngpu > 0) else "cpu")
54 |     print(device, " will be used.\n")
55 | 
56 |     # D trained using different data for x, y and z directions
57 |     dataloaderx = torch.utils.data.DataLoader(dataset_xyz[0], batch_size=batch_size,
58 |                                               shuffle=True, num_workers=workers)
59 |     dataloadery = torch.utils.data.DataLoader(dataset_xyz[1], batch_size=batch_size,
60 |                                               shuffle=True, num_workers=workers)
61 |     dataloaderz = torch.utils.data.DataLoader(dataset_xyz[2], batch_size=batch_size,
62 |                                               shuffle=True, num_workers=workers)
63 | 
64 |     # Create the Generator network
65 |     netG = Gen().to(device)
66 |     if ('cuda' in str(device)) and (ngpu > 1):
67 |         netG = nn.DataParallel(netG, list(range(ngpu)))
68 |     optG = optim.Adam(netG.parameters(), lr=lrg, betas=(beta1, beta2))
69 | 
70 |     # Define 1 Discriminator and optimizer for each plane in each dimension
71 |     netDs = []
72 |     optDs = []
73 |     for i in range(3):
74 |         netD = Disc()
75 |         netD = (nn.DataParallel(netD, list(range(ngpu)))).to(device)
76 |         netDs.append(netD)
77 |         optDs.append(optim.Adam(netDs[i].parameters(), lr=lrd, betas=(beta1, beta2)))
78 | 
79 |     disc_real_log = []
80 |     disc_fake_log = []
81 |     gp_log = []
82 |     Wass_log = []
83 | 
84 |     print("Starting Training Loop...")
85 |     # For each epoch
86 |     start = time.time()
87 |     for epoch in range(num_epochs):
88 |         # sample data for each direction
89 |         for i, (datax, datay, dataz) in enumerate(zip(dataloaderx, dataloadery, dataloaderz), 1):
90 |             dataset = [datax, datay, dataz]
91 |             ### Initialise
92 |             ### Discriminator
93 |             ## Generate fake image batch with G
94 |             noise = torch.randn(D_batch_size, nz, lz,lz,lz, device=device)
95 |             fake_data = netG(noise).detach()
96 |             # for each dim (d1, d2 and d3 are used as permutations to make 3D volume into a batch of 2D images)
97 |             for dim, (netD, optimizer, data, d1, d2, d3) in enumerate(
98 |                     zip(netDs, 
optDs, dataset, [2, 3, 4], [3, 2, 2], [4, 4, 3])): 99 | if isotropic: 100 | netD = netDs[0] 101 | optimizer = optDs[0] 102 | netD.zero_grad() 103 | ##train on real images 104 | real_data = data[0].to(device) 105 | out_real = netD(real_data).view(-1).mean() 106 | ## train on fake images 107 | # perform permutation + reshape to turn volume into batch of 2D images to pass to D 108 | fake_data_perm = fake_data.permute(0, d1, 1, d2, d3).reshape(l * D_batch_size, nc, l, l) 109 | out_fake = netD(fake_data_perm).mean() 110 | gradient_penalty = util.calc_gradient_penalty(netD, real_data, fake_data_perm[:batch_size], 111 | batch_size, l, 112 | device, Lambda, nc) 113 | disc_cost = out_fake - out_real + gradient_penalty 114 | disc_cost.backward() 115 | optimizer.step() 116 | #logs for plotting 117 | disc_real_log.append(out_real.item()) 118 | disc_fake_log.append(out_fake.item()) 119 | Wass_log.append(out_real.item() - out_fake.item()) 120 | gp_log.append(gradient_penalty.item()) 121 | ### Generator Training 122 | if i % int(critic_iters) == 0: 123 | netG.zero_grad() 124 | errG = 0 125 | noise = torch.randn(batch_size, nz, lz,lz,lz, device=device) 126 | fake = netG(noise) 127 | 128 | for dim, (netD, d1, d2, d3) in enumerate( 129 | zip(netDs, [2, 3, 4], [3, 2, 2], [4, 4, 3])): 130 | if isotropic: 131 | #only need one D 132 | netD = netDs[0] 133 | # permute and reshape to feed to disc 134 | fake_data_perm = fake.permute(0, d1, 1, d2, d3).reshape(l * batch_size, nc, l, l) 135 | output = netD(fake_data_perm) 136 | errG -= output.mean() 137 | # Calculate gradients for G 138 | errG.backward() 139 | optG.step() 140 | 141 | # Output training stats & show imgs 142 | if i % 25 == 0: 143 | netG.eval() 144 | with torch.no_grad(): 145 | torch.save(netG.state_dict(), pth + '_Gen.pt') 146 | torch.save(netD.state_dict(), pth + '_Disc.pt') 147 | noise = torch.randn(1, nz,lz,lz,lz, device=device) 148 | img = netG(noise) 149 | ###Print progress 150 | ## calc ETA 151 | steps = len(dataloaderx) 152 | util.calc_eta(steps, time.time(), start, i, epoch, num_epochs) 153 | ###save example slices 154 | util.test_plotter(img, 5, imtype, pth) 155 | # plotting graphs 156 | util.graph_plot([disc_real_log, disc_fake_log], ['real', 'perp'], pth, 'LossGraph') 157 | util.graph_plot([Wass_log], ['Wass Distance'], pth, 'WassGraph') 158 | util.graph_plot([gp_log], ['Gradient Penalty'], pth, 'GpGraph') 159 | netG.train() 160 | -------------------------------------------------------------------------------- /src/slicegan/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import pickle 5 | def slicegan_nets(pth, Training, imtype, dk,ds,df,dp,gk,gs,gf,gp): 6 | """ 7 | Define a generator and Discriminator 8 | :param Training: If training, we save params, if not, we load params from previous. 
9 | This keeps the parameters consistent for older models 10 | :return: 11 | """ 12 | #save params 13 | params = [dk, ds, df, dp, gk, gs, gf, gp] 14 | # if fresh training, save params 15 | if Training: 16 | with open(pth + '_params.data', 'wb') as filehandle: 17 | # store the data as binary data stream 18 | pickle.dump(params, filehandle) 19 | # if loading model, load the associated params file 20 | else: 21 | with open(pth + '_params.data', 'rb') as filehandle: 22 | # read the data as binary data stream 23 | dk, ds, df, dp, gk, gs, gf, gp = pickle.load(filehandle) 24 | 25 | 26 | # Make nets 27 | class Generator(nn.Module): 28 | def __init__(self): 29 | super(Generator, self).__init__() 30 | self.convs = nn.ModuleList() 31 | self.bns = nn.ModuleList() 32 | for lay, (k,s,p) in enumerate(zip(gk,gs,gp)): 33 | self.convs.append(nn.ConvTranspose3d(gf[lay], gf[lay+1], k, s, p, bias=False)) 34 | self.bns.append(nn.BatchNorm3d(gf[lay+1])) 35 | 36 | def forward(self, x): 37 | for conv,bn in zip(self.convs[:-1],self.bns[:-1]): 38 | x = F.relu_(bn(conv(x))) 39 | #use tanh if colour or grayscale, otherwise softmax for one hot encoded 40 | if imtype in ['grayscale', 'colour']: 41 | out = 0.5*(torch.tanh(self.convs[-1](x))+1) 42 | else: 43 | out = torch.softmax(self.convs[-1](x),1) 44 | return out 45 | 46 | class Discriminator(nn.Module): 47 | def __init__(self): 48 | super(Discriminator, self).__init__() 49 | self.convs = nn.ModuleList() 50 | for lay, (k, s, p) in enumerate(zip(dk, ds, dp)): 51 | self.convs.append(nn.Conv2d(df[lay], df[lay + 1], k, s, p, bias=False)) 52 | 53 | def forward(self, x): 54 | for conv in self.convs[:-1]: 55 | x = F.relu_(conv(x)) 56 | x = self.convs[-1](x) 57 | return x 58 | 59 | return Discriminator, Generator 60 | 61 | -------------------------------------------------------------------------------- /src/slicegan/preprocessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import tifffile 5 | def batch(data,type,l, sf): 6 | """ 7 | Generate a batch of images randomly sampled from a training microstructure 8 | :param data: data path 9 | :param type: data type 10 | :param l: image size 11 | :param sf: scale factor 12 | :return: 13 | """ 14 | Testing = False 15 | if type == 'png' or type == 'jpg': 16 | datasetxyz = [] 17 | for img in data: 18 | img = plt.imread(img) 19 | if len(img.shape)>2: 20 | img = img[:,:,0] 21 | 22 | img = img[::sf,::sf] 23 | x_max, y_max= img.shape[:] 24 | 25 | phases = np.unique(img) 26 | data = np.zeros([32 * 900, len(phases), l, l]) 27 | for i in range(32 * 900): 28 | x = np.random.randint(1, x_max - l-1) 29 | y = np.random.randint(1, y_max - l-1) 30 | # create one channel per phase for one hot encoding 31 | for cnt, phs in enumerate(phases): 32 | img1 = np.zeros([l, l]) 33 | img1[img[x:x + l, y:y + l] == phs] = 1 34 | data[i, cnt, :, :] = img1 35 | if Testing: 36 | for j in range(7): 37 | plt.imshow(data[j, 0, :, :]+2*data[j, 1, :, :]) 38 | plt.pause(0.3) 39 | plt.show() 40 | plt.clf() 41 | plt.close() 42 | data = torch.FloatTensor(data) 43 | dataset = torch.utils.data.TensorDataset(data) 44 | datasetxyz.append(dataset) 45 | 46 | elif type=='tif': 47 | datasetxyz=[] 48 | img = np.array(tifffile.imread(data[0])) 49 | img = img[::sf,::sf,::sf] 50 | ## Create a data store and add random samples from the full image 51 | x_max, y_max, z_max = img.shape[:] 52 | print('training image shape: ', img.shape) 53 | vals = np.unique(img) 54 | for 
dim in range(3): 55 | data = np.empty([32 * 900, len(vals), l, l]) 56 | print('dataset ', dim) 57 | for i in range(32*900): 58 | x = np.random.randint(0, x_max - l) 59 | y = np.random.randint(0, y_max - l) 60 | z = np.random.randint(0, z_max - l) 61 | # create one channel per phase for one hot encoding 62 | lay = np.random.randint(img.shape[dim]-1) 63 | for cnt,phs in enumerate(list(vals)): 64 | img1 = np.zeros([l,l]) 65 | if dim==0: 66 | img1[img[lay, y:y + l, z:z + l] == phs] = 1 67 | elif dim==1: 68 | img1[img[x:x + l,lay, z:z + l] == phs] = 1 69 | else: 70 | img1[img[x:x + l, y:y + l,lay] == phs] = 1 71 | data[i, cnt, :, :] = img1[:,:] 72 | # data[i, (cnt+1)%3, :, :] = img1[:,:] 73 | 74 | if Testing: 75 | for j in range(2): 76 | plt.imshow(data[j, 0, :, :] + 2 * data[j, 1, :, :]) 77 | plt.pause(1) 78 | plt.show() 79 | plt.clf() 80 | plt.close() 81 | data = torch.FloatTensor(data) 82 | dataset = torch.utils.data.TensorDataset(data) 83 | datasetxyz.append(dataset) 84 | 85 | elif type=='colour': 86 | ## Create a data store and add random samples from the full image 87 | datasetxyz = [] 88 | for img in data: 89 | img = plt.imread(img) 90 | img = img[::sf,::sf,:] 91 | ep_sz = 32 * 900 92 | data = np.empty([ep_sz, 3, l, l]) 93 | x_max, y_max = img.shape[:2] 94 | for i in range(ep_sz): 95 | x = np.random.randint(0, x_max - l) 96 | y = np.random.randint(0, y_max - l) 97 | # create one channel per phase for one hot encoding 98 | data[i, 0, :, :] = img[x:x + l, y:y + l,0] 99 | data[i, 1, :, :] = img[x:x + l, y:y + l,1] 100 | data[i, 2, :, :] = img[x:x + l, y:y + l,2] 101 | print('converting') 102 | if Testing: 103 | datatest = np.swapaxes(data,1,3) 104 | datatest = np.swapaxes(datatest,1,2) 105 | for j in range(5): 106 | plt.imshow(datatest[j, :, :, :]) 107 | plt.pause(0.5) 108 | plt.show() 109 | plt.clf() 110 | plt.close() 111 | data = torch.FloatTensor(data) 112 | dataset = torch.utils.data.TensorDataset(data) 113 | datasetxyz.append(dataset) 114 | elif type=='grayscale': 115 | datasetxyz = [] 116 | for img in data: 117 | img = plt.imread(img) 118 | if len(img.shape) > 2: 119 | img = img[:, :, 0] 120 | img = img/img.max() 121 | img = img[::sf, ::sf] 122 | x_max, y_max = img.shape[:] 123 | data = np.empty([32 * 900, 1, l, l]) 124 | for i in range(32 * 900): 125 | x = np.random.randint(1, x_max - l - 1) 126 | y = np.random.randint(1, y_max - l - 1) 127 | subim = img[x:x + l, y:y + l] 128 | data[i, 0, :, :] = subim 129 | if Testing: 130 | for j in range(7): 131 | plt.imshow(data[j, 0, :, :]) 132 | plt.pause(0.3) 133 | plt.show() 134 | plt.clf() 135 | plt.close() 136 | data = torch.FloatTensor(data) 137 | dataset = torch.utils.data.TensorDataset(data) 138 | datasetxyz.append(dataset) 139 | return datasetxyz 140 | 141 | 142 | -------------------------------------------------------------------------------- /src/slicegan/run_slicegan.py: -------------------------------------------------------------------------------- 1 | ### Welcome to SliceGAN ### 2 | ####### Steve Kench ####### 3 | ''' 4 | Use this file to define your settings for a training run, or 5 | to generate a synthetic image using a trained generator. 
6 | '''
7 | 
8 | from src.slicegan import model, networks, util
9 | import argparse
10 | import os
11 | import numpy as np
12 | import matplotlib.pyplot as plt
13 | 
14 | def slicegan_dataset():
15 |     if not os.path.exists('data/slicegan_runs'):
16 |         os.mkdir('data/slicegan_runs')
17 |     tags = sorted(os.listdir('data/micrographs_final'))
18 |     for tag in tags:
19 |         slicegan_tag('train', tag)
20 | 
21 | def slicegan_gen_dataset():
22 |     tags = sorted(os.listdir('data/slicegan_runs'))
23 |     tags = [tag +'.png' for tag in tags]
24 |     for tag in tags:
25 |         try:
26 |             slicegan_tag('gen', tag)
27 |         except:
28 |             print(f'couldn\'t generate {tag}')
29 | 
30 | def slicegan_tag(mode, tag):
31 |     # Define project name
32 |     Project_name = tag[:-4]
33 |     # Specify project folder.
34 |     Project_dir = 'data/slicegan_runs/'
35 |     if not os.path.exists(Project_dir):
36 |         os.mkdir(Project_dir)
37 |     # Run with False to show an image during or after training
38 |     Training = True if mode == 'train' else False
39 |     Project_path = util.mkdr(Project_name, Project_dir, Training,)
40 |     if not Project_path:
41 |         return
42 |     ## Data Processing
43 |     # Define image type (colour, grayscale, three-phase or two-phase.
44 |     # n-phase materials must be segmented)
45 |     img = plt.imread(f'data/micrographs_final/{tag}')
46 | 
47 |     image_type = 'twophase' if len(np.unique(img)) == 2 else 'grayscale'
48 |     # define data type (for colour/grayscale images, must be 'colour' /
49 |     # 'grayscale'. nphase can be 'tif', 'png', 'jpg', 'array')
50 |     data_type = 'png' if image_type == 'twophase' else 'grayscale'
51 |     # Path to your data. One string for isotropic, 3 for anisotropic
52 | 
53 |     data_path = [f'data/micrographs_final/{tag}']
54 | 
55 |     ## Network Architectures
56 |     # Training image size, no. channels and scale factor vs raw data
57 |     img_size, img_channels, scale_factor = 64, 2 if image_type=='twophase' else 1, 1
58 |     # z vector depth
59 |     z_channels = 16
60 |     # Layers in G and D
61 |     lays = 5
62 |     # kernels for each layer
63 |     dk, gk = [4]*lays, [4]*lays
64 |     # strides
65 |     ds, gs = [2]*lays, [2]*lays
66 |     # no. filters
67 |     df, gf = [img_channels,64,128,256,512,1], [z_channels,512,256,128,64,img_channels]
68 |     # paddings
69 |     dp, gp = [1,1,1,1,0],[2,2,2,2,3]
70 | 
71 |     ## Create Networks
72 |     netD, netG = networks.slicegan_nets(Project_path, Training, image_type, dk, ds, df,dp, gk ,gs, gf, gp)
73 |     print('training')
74 |     # Train
75 |     if Training:
76 |         model.train(Project_path, image_type, data_type, data_path, netD, netG, img_channels, img_size, z_channels, scale_factor)
77 |     else:
78 |         img, raw, netG = util.test_img(Project_path, image_type, netG(), z_channels, lf=12, periodic=False)
--------------------------------------------------------------------------------
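Before util.py, a quick sanity check on the generator geometry defined above: a transpose convolution maps an input of size n to (n − 1)·s − 2p + k. With kernel 4, stride 2 and paddings gp = [2, 2, 2, 2, 3], a 4³ latent cube (lz = 4 in model.py) grows to the 64³ training size. A small sketch verifying the arithmetic:

```
# transpose-convolution output size: out = (in - 1) * s - 2 * p + k
def out_size(n, k=4, s=2, p=0):
    return (n - 1) * s - 2 * p + k

n = 4                        # lz, the latent cube edge used in model.py
for p in [2, 2, 2, 2, 3]:    # gp from run_slicegan.py
    n = out_size(n, p=p)
    print(n)                 # 6, 10, 18, 34, 64
```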
/src/slicegan/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch import nn
3 | import torch
4 | from torch import autograd
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import tifffile
8 | import sys
9 | ## Training Utils
10 | 
11 | def mkdr(proj,proj_dir,Training):
12 |     """
13 |     When training, creates a new project directory or overwrites an existing directory according to user input. When testing, returns the full project path
14 |     :param proj: project name
15 |     :param proj_dir: project directory
16 |     :param Training: whether new training run or testing image
17 |     :return: full project path
18 |     """
19 |     pth = proj_dir + '/' + proj
20 |     if Training:
21 |         try:
22 |             os.mkdir(pth)
23 |             return pth + '/' + proj
24 |         except FileExistsError:
25 |             return False
26 |         except FileNotFoundError:
27 |             print('The specified project directory ' + proj_dir + ' does not exist. Please change to a directory that does exist and try again')
28 |             sys.exit()
29 |     else:
30 |         return pth + '/' + proj
31 | 
32 | 
33 | def weights_init(m):
34 |     """
35 |     Initialises training weights
36 |     :param m: Convolution to be initialised
37 |     :return:
38 |     """
39 |     classname = m.__class__.__name__
40 |     if classname.find('Conv') != -1:
41 |         nn.init.normal_(m.weight.data, 0.0, 0.02)
42 |     elif classname.find('BatchNorm') != -1:
43 |         nn.init.normal_(m.weight.data, 1.0, 0.02)
44 |         nn.init.constant_(m.bias.data, 0)
45 | 
46 | def calc_gradient_penalty(netD, real_data, fake_data, batch_size, l, device, gp_lambda,nc):
47 |     """
48 |     calculate gradient penalty for a batch of real and fake data
49 |     :param netD: Discriminator network
50 |     :param real_data: batch of real images
51 |     :param fake_data: batch of generated images
52 |     :param batch_size: batch size
53 |     :param l: image size
54 |     :param device: torch device
55 |     :param gp_lambda: learning parameter for GP
56 |     :param nc: channels
57 |     :return: gradient penalty
58 |     """
59 |     #sample and reshape random numbers
60 |     alpha = torch.rand(batch_size, 1, device = device)
61 |     alpha = alpha.expand(batch_size, int(real_data.nelement() / batch_size)).contiguous()
62 |     alpha = alpha.view(batch_size, nc, l, l)
63 | 
64 |     # create interpolate dataset
65 |     interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
66 |     interpolates.requires_grad_(True)
67 | 
68 |     #pass interpolates through netD
69 |     disc_interpolates = netD(interpolates)
70 |     gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
71 |                               grad_outputs=torch.ones(disc_interpolates.size(), device = device),
72 |                               create_graph=True, only_inputs=True)[0]
73 |     # extract the grads and calculate gp
74 |     gradients = gradients.view(gradients.size(0), -1)
75 |     gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
76 |     return gradient_penalty
77 | 
78 | 
79 | def calc_eta(steps, time, start, i, epoch, num_epochs):
80 |     """
81 |     Estimates the time remaining based on the elapsed time and epochs
82 |     :param steps: number of iterations per epoch
83 |     :param time: current time
84 |     :param start: start time
85 |     :param i: iteration through this epoch
86 |     :param epoch: current epoch
87 |     :param num_epochs: total no. of epochs
88 |     """
89 |     elap = time - start
90 |     progress = epoch * steps + i + 1
91 |     rem = num_epochs * steps - progress
92 |     ETA = rem / progress * elap
93 |     hrs = int(ETA / 3600)
94 |     mins = int((ETA / 3600 % 1) * 60)
95 |     print('[%d/%d][%d/%d]\tETA: %d hrs %d mins'
96 |           % (epoch, num_epochs, i, steps,
97 |              hrs, mins))
98 | 
99 | ## Plotting Utils
100 | def post_proc(img,imtype):
101 |     """
102 |     turns one hot image back into grayscale
103 |     :param img: input image
104 |     :param imtype: image type
105 |     :return: plottable image in the same form as the training data
106 |     """
107 |     try:
108 |         #make sure it's on the cpu and detached from grads for plotting purposes
109 |         img = img.detach().cpu()
110 |     except:
111 |         pass
112 |     # for n phase materials, separate out the channels and take the max
113 |     if imtype == 'twophase':
114 |         img_pp = np.zeros(img.shape[2:])
115 |         p1 = np.array(img[0][0])
116 |         p2 = np.array(img[0][1])
117 |         img_pp[(p1 < p2)] = 1  # background, yellow
118 |         return img_pp
119 |     if imtype == 'threephase':
120 |         img_pp = np.zeros(img.shape[2:])
121 |         p1 = np.array(img[0][0])
122 |         p2 = np.array(img[0][1])
123 |         p3 = np.array(img[0][2])
124 |         img_pp[(p1 > p2) & (p1 > p3)] = 0  # background, yellow
125 |         img_pp[(p2 > p1) & (p2 > p3)] = 1  # spheres, green
126 |         img_pp[(p3 > p2) & (p3 > p1)] = 2  # binder, purple
127 |         return img_pp
128 |     # colour and grayscale don't require post proc, just a shift
129 |     if imtype == 'colour':
130 |         return np.int_(255 * (np.swapaxes(img[0], 0, -1)))
131 |     if imtype == 'grayscale':
132 |         return 255*img[0][0]
133 | 
134 | def test_plotter(img,slcs,imtype,pth):
135 |     """
136 |     creates a fig with 3*slc subplots showing example slices along the three axes
137 |     :param img: raw input image
138 |     :param slcs: number of slices to take in each dir
139 |     :param imtype: image type
140 |     :param pth: where to save plot
141 |     """
142 |     img = post_proc(img,imtype)
143 |     fig, axs = plt.subplots(slcs, 3)
144 |     if imtype == 'colour':
145 |         for j in range(slcs):
146 |             axs[j, 0].imshow(img[j, :, :, :], vmin = 0, vmax = 255)
147 |             axs[j, 1].imshow(img[:, j, :, :], vmin = 0, vmax = 255)
148 |             axs[j, 2].imshow(img[:, :, j, :], vmin = 0, vmax = 255)
149 |     elif imtype == 'grayscale':
150 |         for j in range(slcs):
151 |             axs[j, 0].imshow(img[j, :, :], cmap = 'gray')
152 |             axs[j, 1].imshow(img[:, j, :], cmap = 'gray')
153 |             axs[j, 2].imshow(img[:, :, j], cmap = 'gray')
154 |     else:
155 |         for j in range(slcs):
156 |             axs[j, 0].imshow(img[j, :, :])
157 |             axs[j, 1].imshow(img[:, j, :])
158 |             axs[j, 2].imshow(img[:, :, j])
159 |     plt.savefig(pth + '_slices.png')
160 |     plt.close()
161 | 
162 | def graph_plot(data,labels,pth,name):
163 |     """
164 |     simple plotter for all the different graphs
165 |     :param data: a list of data arrays
166 |     :param labels: a list of plot labels
167 |     :param pth: where to save plots
168 |     :param name: name of the plot figure
169 |     :return:
170 |     """
171 | 
172 |     for datum,lbl in zip(data,labels):
173 |         plt.plot(datum, label = lbl)
174 |     plt.legend()
175 |     plt.savefig(pth + '_' + name)
176 |     plt.close()
177 | 
178 | 
179 | def test_img(pth, imtype, netG, nz = 64, lf = 4, periodic=False):
180 |     """
181 |     saves a test volume for a trained generator, or one in progress of training
182 |     :param pth: where to save image and also where to find the generator
183 |     :param imtype: image type
184 |     :param netG: Loaded generator class
185 |     :param nz: latent z dimension
186 |     :param lf: length factor
187 |     :param show:
188 |     :param periodic: list of periodicity in axis 1 through 
n 189 | :return: 190 | """ 191 | netG.load_state_dict(torch.load(pth + '_Gen.pt')) 192 | netG.eval() 193 | netG.cuda() 194 | 195 | noise = torch.randn(1, nz, lf, lf, lf,).cuda() 196 | if periodic: 197 | if periodic[0]: 198 | noise[:, :, :2] = noise[:, :, -2:] 199 | if periodic[1]: 200 | noise[:, :, :, :2] = noise[:, :, :, -2:] 201 | if periodic[2]: 202 | noise[:, :, :, :, :2] = noise[:, :, :, :, -2:] 203 | raw = netG(noise) 204 | print('Postprocessing') 205 | gb = post_proc(raw,imtype) 206 | if periodic: 207 | if periodic[0]: 208 | gb = gb[:-1] 209 | if periodic[1]: 210 | gb = gb[:,:-1] 211 | if periodic[2]: 212 | gb = gb[:,:,:-1] 213 | tif = np.int_(gb) 214 | tifffile.imwrite(pth + 'test.tif', tif) 215 | 216 | return tif, raw, netG 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | -------------------------------------------------------------------------------- /src/style.qss: -------------------------------------------------------------------------------- 1 | * { 2 | font-size:13px; 3 | color:#c1c1c1; 4 | font-family:"Helvetica"; 5 | background:black; 6 | } 7 | 8 | 9 | QToolBar { 10 | background-color: black; 11 | border: none; 12 | } 13 | 14 | QToolButton { 15 | background:white; 16 | color:black; 17 | margin: 5px; 18 | margin-top: 12px; 19 | margin-bottom: 12px; 20 | padding: 7px; 21 | border-radius: 2px; 22 | border: 2px; 23 | 24 | } 25 | 26 | QPushButton { 27 | background:white; 28 | color:black; 29 | margin: 5px; 30 | margin-top: 12px; 31 | margin-bottom: 12px; 32 | padding: 7px; 33 | border-radius: 2px; 34 | border: 2px; 35 | 36 | } 37 | --------------------------------------------------------------------------------
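A closing note on the `periodic` option in `test_img` above: neighbouring output slabs are decoded from overlapping latent planes, so copying the last two latent planes onto the first two along an axis makes the generated volume wrap around in that axis, after which one duplicated slice is cropped (`gb = gb[:-1]`). A minimal sketch of just the noise manipulation, assuming the two-plane overlap used in the code:

```
import torch

nz, lf = 64, 12                         # latent channels and latent cube edge
noise = torch.randn(1, nz, lf, lf, lf)

# make axis 0 periodic: the first two latent planes copy the last two, so the
# start and end of the generated volume decode from the same latent values
noise[:, :, :2] = noise[:, :, -2:]
```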