├── .cargo └── config.toml ├── .gitignore ├── LICENSE ├── README.md ├── aethon.py ├── cfgs ├── PG.yaml ├── PW.yaml ├── jobarray.py └── slurmconfig.yaml ├── core ├── __init__.py ├── defaults.yaml └── random_seeds.npy ├── datasets ├── DLRDatasetLoader.py ├── IPIDatasetLoader.py ├── ISAIDDatasetLoader.py ├── ISPRSDatasetLoader.py ├── LGNDatasetLoader.py ├── MultiSegmentationDataset.py ├── SegmentationDataset.py ├── SegmentationDatasetPatchProviders.py ├── SemcityToulouseDatasetLoader.py ├── SparseInstanceImage.py ├── SynthinelDatasetLoader.py ├── ToyDataset.py ├── __init__.py └── isaid_cache.npz ├── hpo ├── aethonproxy.py └── smac_example.py ├── models ├── DeepLabv3p.py ├── FCN.py ├── FarSeg.py ├── PFNet.py ├── RAFCN.py ├── SegForestComponents.py ├── SegForestNet.py ├── SegForestTree.py ├── SegForestTreeDecoder.py ├── UNet.py ├── UNetpp.py ├── Xception.py ├── __init__.py └── xception.json.bz2 ├── rust ├── Cargo.toml ├── __init__.py └── src │ ├── drawcircle.rs │ ├── lib.rs │ ├── metrics.rs │ ├── pixelstats.rs │ ├── rasterize.rs │ ├── sparseimage.rs │ └── utils.rs ├── samples.png ├── tasks ├── __init__.py ├── archive.py ├── array.py ├── evalmodel.py ├── monitor.py ├── remotesync.py ├── semanticsegmentation.py └── slurm.py ├── user.yaml └── utils ├── __init__.py ├── confusionmatrix.py ├── lrscheduler.py ├── modelfitter.py ├── preprocess ├── dlr_landcover.py ├── model_weights.py └── toulouse.py ├── sam.py └── vectorquantization.py /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = ["-C", "target-cpu=native"] 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # Rust 133 | rust/src/exports.rs 134 | rust/Cargo.lock 135 | rust/target/ 136 | 137 | 138 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, Daniel Gritzner 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SegForestNet 2 | Reference implementation of SegForestNet, a model which predicts binary space partitioning trees to compute a semantic segmentation of aerial images. [The associated paper titled "SegForestNet: Spatial-Partitioning-Based Aerial Image Segmentation" is available on arXiv](https://arxiv.org/abs/2302.01585). Please cite our paper if you use anything from this repository. 
3 | 
4 | ```bibtex
5 | @misc{gritzner2024segforestnet,
6 | title = {SegForestNet: Spatial-Partitioning-Based Aerial Image Segmentation},
7 | author = {Gritzner, Daniel and Ostermann, Jörn},
8 | publisher = {arXiv},
9 | year = {2024},
10 | eprint = {2302.01585},
11 | archivePrefix = {arXiv},
12 | primaryClass = {cs.CV},
13 | doi = {10.48550/ARXIV.2302.01585},
14 | url = {https://arxiv.org/abs/2302.01585},
15 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences, I.5.4},
16 | }
17 | ```
18 | # Results
19 | Our model delivers state-of-the-art performance, even under non-optimal training conditions (see paper for details). While other models, e.g., DeepLab v3+, deliver performance on a similar level, SegForestNet is better at properly predicting small objects such as cars. It predicts proper rectangles rather than round-ish shapes. Also, car segments which should be disconnected may merge into one larger region when using other models. Model weights are available in [this release](https://github.com/gritzner/SegForestNet/releases/tag/2024-04).
20 | 
21 | Mean $F_1$ scores:
22 | 
23 | | | Hannover | Buxtehude | Nienburg | Schleswig | Hameln | Vaihingen | Potsdam | Toulouse |
24 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
25 | | FCN | 84.9% | 87.7% | 85.5% | 82.6% | 87.8% | 86.6% | 91.3% | 75.8% |
26 | | DeepLab v3+ | __85.7%__ | 88.7% | 86.7% | __83.6%__ | 88.6% | __86.9%__ | __91.5%__ | __77.6%__ |
27 | | __SegForestNet__ | 85.5% | __88.8%__ | 86.2% | 83.0% | __88.7%__ | 86.8% | 91.3% | 74.8% |
28 | | PFNet | 85.4% | 88.4% | 86.3% | 83.2% | 88.4% | 86.8% | __91.5%__ | 75.8% |
29 | | FarSeg | __85.7%__ | 88.5% | __86.8%__ | 82.8% | 88.4% | 86.7% | 91.4% | 75.0% |
30 | | U-Net | 84.3% | 86.7% | 85.5% | 78.5% | 86.8% | 84.2% | 88.6% | 75.2% |
31 | | RA-FCN | 78.5% | 83.1% | 80.0% | 74.6% | 83.9% | 82.6% | 86.6% | 66.9% |
32 | 
33 | ![](samples.png)
34 | 
35 | # How to run
36 | ### Dependencies
37 | The code has been tested on openSUSE Leap 15.4 running the following software:
38 | * cargo 1.74.1 (1.67.1 may also be sufficient)
39 | * cuda 11.6.1
40 | * libtiff 4.5.0
41 | * matplotlib 3.7.0
42 | * numpy 1.23.5
43 | * opencv 4.6.0
44 | * python 3.10.10
45 | * pytorch 1.13.1
46 | * pyyaml 6.0
47 | * rustc 1.74.1 (1.67.1 may also be sufficient)
48 | * scikit-learn 1.2.1
49 | * scipy 1.10.0
50 | * torchvision 0.14.1
51 | 
52 | Optional dependencies:
53 | * geotiff 1.7.0
54 | * tifffile 2021.7.2
55 | * timm 0.9.2
56 | 
57 | ### Preparing the training environment (optional)
58 | Using pretrained encoder weights requires executing ```utils/preprocess/model_weights.py``` once to download the necessary model weights (for legacy reasons this will also download weights for another encoder which is no longer used in this codebase). Two of the datasets (DLR Multi-Sensor Land-Cover Classification (MSLCC) and SemCity Toulouse) also require executing the appropriate Python script in ```utils/preprocess/``` once. This is necessary to convert some ```.tif``` files into a format that OpenCV can read. The scripts in ```utils/preprocess/``` need the optional dependencies.
59 | 
60 | ### Running the code
61 | Open a terminal in the directory you cloned this repository into and execute the following command:
62 | 
63 | ```shell
64 | python aethon.py PW potsdam SegForestNet
65 | ```
66 | 
67 | This will use the configuration file ```cfgs/PW.yaml``` to run our framework.
Furthermore, you will need a user configuration file called ```~/.aethon/user.yaml```. An example user configuration can be found in ```user.yaml```. The full configuration our framework will parse is the concatenation of ```core/defaults.yaml``` and ```cfgs/PW.yaml```. Additionally, all occurrences of ```$N``` in ```cfgs/PW.yaml``` will be replaced by the parameters given on the command line, e.g., ```$0``` will become ```potsdam``` and ```$1``` will become ```SegForestNet```. The example above will run our framework to train our model on the Potsdam dataset using the first random seed from the array in ```core/random_seeds.npy``` for data augmentation. This is the same random seed we used for the experiments in our paper.
68 | 
69 | Even though we cannot provide some of the datasets used in the paper for legal reasons, we still provide their data loaders as a reference. The data loaders can be found in ```datasets/```.
70 | 
71 | The training results, including an evaluation of the trained model on the validation and test subsets, can be found in the appropriate subfolder in ```tmp/PW/``` once training is complete.
72 | 
73 | ### Running within a Jupyter notebook
74 | You can run our code in Jupyter by simply copying the content of ```aethon.py``` to a notebook and adding the command line parameters to the second line. Example for the second line:
75 | 
76 | ```python
77 | core.init("PW potsdam SegForestNet")
78 | ```
79 | 
80 | # Model code
81 | If you are only interested in the code of our model, take a look at ```models/SegForest*.py```. The class ```SegForestNet``` implements our model. It uses several helper classes to give our already complicated code some additional structure. The constructor of our model has two parameters in addition to ```self```:
82 | * ```params``` is an object with the two attributes ```input_shape``` and ```num_classes``` so that the model knows what kind of data to expect. See line 29 of ```tasks/semanticsegmentation.py``` for an example.
83 | * ```config``` is an object which is a parsed version of the relevant subset of the configuration file used to run our framework, in particular the section ```SegForestNet_params``` in ```cfgs/PW.yaml``` in the example above. The parsing is done by the ```parse_dict``` function in ```core/__init__.py```.
84 | 
85 | The ```trees``` subsection of the configuration is of particular interest. It defines the number of trees to predict per block. Each entry of the list ```trees``` will later become an instance of ```models/SegForestTree.py```, with each tree object consisting of a pair of decoders and representing a different tree. The attribute ```graph``` defines the tree structure in terms of components (found in ```models/SegForestComponents.py```). ```eval``` is used to turn ```graph``` into an actual tree object, which is technically a security problem. However, our framework is only intended for use cases in which the person triggering its execution already has full system access, or at least enough access to execute arbitrary Python or Rust code anyway. **Note:** this is not the only instance of insecure code in our framework.
Examples of valid tree graphs are:
86 | * ```BSPTree(2, Line)```: a BSP tree of depth two, i.e., a total of three inner nodes and four leaf nodes, using $f_1$ from our paper as the signed distance function
87 | * ```BSPTree(2, Circle)```: same as above but using $f_3$ instead of $f_1$
88 | * ```BSPNode(BSPTree(1, Line), Leaf, Line)```: a BSP tree with two inner nodes (the left child of the root node is a BSP tree of depth one while the right child is already a leaf node) and three leaf nodes, using $f_1$ in all inner nodes
89 | 
90 | The different signed distance functions are defined in the appendix of the paper.
91 | 
92 | The attribute ```one_tree_per_class``` causes the list of ```trees``` to be expanded automatically such that there is exactly one tree for each class. All trees will use the same configuration, e.g., the same ```graph```. If multiple trees are defined manually, an attribute called ```outputs``` must be defined for each tree. It is a list of integers defining which tree is responsible for predicting the logits of which class. Examples:
93 | * ```[0]```: predict logits for the first class only
94 | * ```[1, 2, 4]```: predict logits for classes two, three and five
95 | 
96 | The union of all ```outputs``` must be the set of all classes and the intersection of the ```outputs``` of any two different trees must be empty (a small validation sketch is given at the end of this README).
97 | 
98 | If you want to use SegForestNet outside our framework you need these files:
99 | * models/SegForestNet.py
100 | * models/SegForestTree.py
101 | * models/SegForestTreeDecoder.py
102 | * models/SegForestComponents.py
103 | * models/Xception.py
104 | * models/xception.json.bz2
105 | * utils/\_\_init\_\_.py
106 | * utils/vectorquantization.py
107 | 
108 | You need to fix several dependencies. To remove the dependency on the core module, replace all instances of ```core.device``` with the appropriate device, usually ```torch.device("cuda:0")```. Add ```import gzip``` to ```Xception.py``` and in line 142 use ```gzip.open(...)``` instead of ```core.open(...)```. Also, in the same line, change the first argument of ```open``` to the path of your downloaded Xception model weights. In ```utils/__init__.py``` comment out lines one to three as well as line five. A minimal instantiation sketch is given below.
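
For standalone use, the following sketch shows how the two constructor arguments could be assembled. It is a rough example rather than part of the repository: the argument order mirrors ```core.create_object```, which instantiates models as ```Model(config, params)```; loading and substituting the YAML files mirrors ```core.init```; and the ```input_shape```, ```num_classes``` and helper module names are assumptions you will likely need to adapt.

```python
import types
import yaml
import torch
# assumes the files listed above were copied next to this script and the
# dependency fixes described above have been applied; 'confighelpers' is a
# hypothetical module containing a copy of parse_dict() from core/__init__.py
from SegForestNet import SegForestNet
from confighelpers import parse_dict

# Mimic core.init(): defaults.yaml and the cfg file are parsed as a single
# YAML document so that anchors such as *SegForestNet-defaults resolve, and
# the $N placeholders are substituted with the command line parameters.
with open("core/defaults.yaml") as f:
    full_yaml = f.read()
with open("cfgs/PW.yaml") as f:
    full_yaml += "\n\n" + f.read().replace("$0", "potsdam").replace("$1", "SegForestNet")
config = parse_dict(yaml.full_load(full_yaml))

# input_shape (assumed channels-first) and num_classes are example values for
# Potsdam with the default channel configuration; adapt them to your data.
params = types.SimpleNamespace(input_shape=(6, 256, 256), num_classes=6)
model = SegForestNet(config.SegForestNet_params, params).to(torch.device("cuda:0"))
```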
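
Finally, the partitioning constraint on ```outputs``` described earlier can be checked with a few lines of Python. The sketch below is illustrative only and not part of the codebase; the three-tree configuration shown is hypothetical.

```python
# Hypothetical three-tree configuration (not taken from any shipped cfg file).
outputs = [[0], [1, 2, 4], [3, 5]]
num_classes = 6

# Every class must be predicted by exactly one tree: the union of all
# 'outputs' is the set of all classes and no class appears twice.
flat = [c for tree_outputs in outputs for c in tree_outputs]
assert sorted(flat) == list(range(num_classes)), "outputs must partition the set of classes"
```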
-------------------------------------------------------------------------------- /aethon.py: -------------------------------------------------------------------------------- 1 | import core 2 | core.init() 3 | 4 | import tasks 5 | if core.args.configuration[0] != "@": 6 | task = core.create_object(tasks, core.task) 7 | task.run() 8 | core.call(f"touch {core.output_path}/done") 9 | else: 10 | import os 11 | if "LD_LIBRARY_PATH" in os.environ: 12 | del os.environ["LD_LIBRARY_PATH"] # fixes an issue with calling ssh after cv2 has been imported 13 | getattr(tasks, core.args.configuration[1:])() 14 | -------------------------------------------------------------------------------- /cfgs/PG.yaml: -------------------------------------------------------------------------------- 1 | task: SemanticSegmentation 2 | output_path: tmp/PG 3 | clear_output_path: True 4 | 5 | SegmentationDataset_params: 6 | <<: *$0-defaults 7 | random_seed: 0 8 | patch_size: [224, 224] 9 | training_samples: 8000 10 | augmentation: 11 | scaling: .3 12 | x_shear: 6 13 | y_shear: 6 14 | rotation: 30 15 | h_flip: 0 16 | v_flip: 0 17 | contrast: 0 18 | brightness: 0 19 | noise: .1 20 | at_test_time: False 21 | 22 | SegForestNet_params: 23 | <<: *SegForestNet-defaults 24 | downsampling: 3 25 | pretrained_encoder: False 26 | trees: 27 | - num_features: 28 | shape: 8 29 | content: 24 30 | shape_to_content: 0 31 | graph: BSPTree(2, Line) 32 | classifier: [] 33 | classifier_skip_from: 0 34 | classifier_context: 2 35 | one_tree_per_class: True 36 | decoder: 37 | type: TreeFeatureDecoder 38 | num_blocks: 8 39 | context: 1 40 | intermediate_features: 96 41 | use_residual_blocks: True 42 | vq: 43 | type: [0, 0] 44 | region_map: 45 | accumulation: add 46 | node_weight: 1 47 | softmax_temperature: 48 | parameters: [epoch, epochs] 49 | func: 0 50 | value_range: [1, 1] 51 | loss: 52 | cross_entropy: pixels 53 | ce_constant: 10 54 | distribution_metric: gini 55 | min_region_size: 4 56 | weights: [.8625,.0475,.035,.055,0] 57 | 58 | FCN_params: 59 | downsampling: 5 60 | pretrained_encoder: False 61 | 62 | RAFCN_params: 63 | downsampling: 5 64 | pretrained_encoder: False 65 | 66 | FarSeg_params: 67 | downsampling: 5 68 | pretrained_encoder: False 69 | 70 | PFNet_params: 71 | downsampling: 5 72 | pretrained_encoder: False 73 | 74 | UNet_params: 75 | 76 | DeepLabv3p_params: 77 | downsampling: 3 78 | pretrained_encoder: False 79 | aspp_dilation_rates: [] # default for MobileNetv2 backbone (performs best, even with Xception) 80 | #aspp_dilation_rates: [12, 24, 36] # default for Xception backbone with downsampling == 3 according to DeepLab paper 81 | #aspp_dilation_rates: [6, 12, 18] # default for Xception backbone with downsampling == 4 according to DeepLab paper 82 | 83 | SemanticSegmentation_params: 84 | <<: *SemanticSegmentation-defaults 85 | dataset: SegmentationDataset 86 | model: $1 87 | epochs: $2 88 | mini_batch_size: 18 89 | shuffle_seed: -1 90 | num_samples_per_epoch: 8000 91 | unique_iterations: False 92 | optimizer: 93 | type: AdamW 94 | arguments: 95 | betas: [0.9, 0.999] 96 | weight_decay: 0.01 97 | learning_rate: 98 | max_value: $3 99 | min_value: 0 100 | num_cycles: 1 101 | cycle_length_factor: 2 102 | num_iterations_factor: 1 103 | gradient_clipping: 0 104 | class_weights: 105 | ignore_dataset: False 106 | ignored_class_weight: 0 107 | dynamic_exponent: 0 108 | -------------------------------------------------------------------------------- /cfgs/PW.yaml: 
-------------------------------------------------------------------------------- 1 | task: SemanticSegmentation 2 | output_path: tmp/PW 3 | clear_output_path: True 4 | 5 | SegmentationDataset_params: 6 | <<: *$0-defaults 7 | random_seed: 0 8 | patch_size: [256, 256] 9 | 10 | SegForestNet_params: 11 | <<: *SegForestNet-defaults 12 | downsampling: 3 13 | pretrained_encoder: True 14 | trees: 15 | - num_features: 16 | shape: 8 17 | content: 24 18 | shape_to_content: 0 19 | graph: BSPTree(2, Line) 20 | classifier: [] 21 | classifier_skip_from: 0 22 | classifier_context: 2 23 | one_tree_per_class: True 24 | decoder: 25 | type: TreeFeatureDecoder 26 | num_blocks: 8 27 | context: 1 28 | intermediate_features: 96 29 | use_residual_blocks: True 30 | vq: 31 | type: [0, 0] 32 | region_map: 33 | accumulation: add 34 | node_weight: 1 35 | softmax_temperature: 36 | parameters: [epoch, epochs] 37 | func: 0 38 | value_range: [1, 1] 39 | loss: 40 | cross_entropy: pixels 41 | ce_constant: 10 42 | distribution_metric: gini 43 | min_region_size: 4 44 | weights: [.8625,.0475,.035,.055,0] 45 | 46 | FCN_params: 47 | downsampling: 5 48 | pretrained_encoder: True 49 | 50 | RAFCN_params: 51 | downsampling: 5 52 | pretrained_encoder: True 53 | 54 | FarSeg_params: 55 | downsampling: 5 56 | pretrained_encoder: True 57 | 58 | PFNet_params: 59 | downsampling: 5 60 | pretrained_encoder: True 61 | 62 | UNet_params: 63 | 64 | DeepLabv3p_params: 65 | downsampling: 3 66 | pretrained_encoder: True 67 | aspp_dilation_rates: [] # default for MobileNetv2 backbone (performs best, even with Xception) 68 | #aspp_dilation_rates: [12, 24, 36] # default for Xception backbone with downsampling == 3 according to DeepLab paper 69 | #aspp_dilation_rates: [6, 12, 18] # default for Xception backbone with downsampling == 4 according to DeepLab paper 70 | 71 | SemanticSegmentation_params: 72 | <<: *SemanticSegmentation-defaults 73 | dataset: SegmentationDataset 74 | model: $1 75 | shuffle_seed: -1 76 | -------------------------------------------------------------------------------- /cfgs/jobarray.py: -------------------------------------------------------------------------------- 1 | # use 'python aethon.py @archive' to create PG.tar.bz2 and copy it to the proper location (see slurmconfig.yaml) first 2 | # also, make sure 'python aethon.py @monitor slurmconfig' is running concurrently 3 | # 4 | # example call (runs @monitor concurrently): 'python aethon.py @array jobarray 0-55 @monitor slurmconfig @threads 16' 5 | # 6 | 7 | datasets = "hannover", "buxtehude", "nienburg", "potsdam", "vaihingen", "hameln_DA", "schleswig_DA", "toulouse_full" 8 | n = len(datasets) 9 | 10 | i = int(input()) 11 | i, j = i//n, i%n 12 | 13 | dataset = datasets[j] 14 | mem = (12, 12, 12, 44, 18, 12, 12, 25)[j] 15 | 16 | models = "PFNet", "FCN", "DeepLabv3p", "RAFCN", "FarSeg", "UNet", "SegForestNet" 17 | n = len(models) 18 | i, j = i//n, i%n 19 | 20 | model = models[j] 21 | epochs, learning_rate = { 22 | "PFNet": (100, 0.005), 23 | "FCN": (120, 0.006), 24 | "DeepLabv3p": (80, 0.0035), 25 | "RAFCN": (120, 0.0002), 26 | "FarSeg": (120, 0.001), 27 | "UNet": (80, 0.00015), 28 | "SegForestNet": (200, 0.003) 29 | }[model] 30 | 31 | slurm_prefix = f"@slurm slurmconfig PG 30:00:00 {mem}G PG.tar.bz2 tmp/PG " 32 | print(f"{slurm_prefix}PG {dataset} {model} {epochs} {learning_rate}") 33 | -------------------------------------------------------------------------------- /cfgs/slurmconfig.yaml: -------------------------------------------------------------------------------- 1 | 
servers: 2 | login: your_login_server # login via ssh must be possible without entering a password 3 | transfer: your_transfer_server # for file copying 4 | primary_path: /path/with/a/lot/of/space/to/store/lots/of/datasets 5 | secondary_path: /path/where/your/current/datasets/you/are/working/with/are/stored 6 | slurm: 7 | monitor: 8 | port: 7777 9 | timeout: 5 10 | remote_timeout: 60 11 | timer: 300 12 | user: your_username 13 | verbose: False 14 | options:map: 15 | - [mail-user, your_address@example.com] 16 | - [mail-type, ALL] 17 | - [partition, gpu_partition] 18 | - [nodes, 1] 19 | - [cpus-per-task, 12] 20 | - [gres, "gpu:1"] 21 | script: | 22 | #!/bin/bash 23 | {slurm_options} 24 | mkdir -p {jobs_path}/$SLURM_JOB_ID 25 | cd {jobs_path}/$SLURM_JOB_ID 26 | source /path/to/your/conda/installation/bin/activate torch 27 | export PROJ_LIB=/path/to/your/conda/installation/envs/torch/share/proj 28 | export PATH=/path/to/your/cargo/installation/usually/your/home/.cargo/bin:$PATH 29 | export LD_LIBRARY_PATH=/path/to/your/conda/installation/envs/torch/lib:$LD_LIBRARY_PATH 30 | tar -xf {jobs_path}/{code} 31 | python -u aethon.py {parameters} --git-log git_log.txt 32 | return_value=$? 33 | rm -rf tmp/cache 34 | exit $return_value 35 | jobs_path: /path/where/slurm/jobs/store/their/files 36 | -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import bz2 3 | import builtins 4 | import types 5 | import argparse 6 | from pathlib import Path 7 | import re 8 | import yaml 9 | import os 10 | import sys 11 | import subprocess 12 | call = lambda cmd: subprocess.call(cmd, shell=True) 13 | 14 | 15 | def str2bool(s): 16 | return not s.lower() in ("", "0", "f", "false", "n", "no", "off") 17 | 18 | 19 | def open(filename, mode): 20 | if filename[-3:] == ".gz": 21 | return gzip.open(filename, mode) 22 | elif filename[-4:] == ".bz2": 23 | return bz2.open(filename, mode) 24 | else: 25 | return builtins.open(filename, mode) 26 | 27 | 28 | def list_to_tuple(v): 29 | if type(v) == list: 30 | v = tuple([list_to_tuple(x) for x in v]) 31 | return v 32 | 33 | 34 | def parse_map(v): 35 | d = {} 36 | for x in v: 37 | if len(x) != 2: 38 | raise RuntimeError(f"cannot parse '{x}' into map entry: length is not 2.") 39 | d[list_to_tuple(x[0])] = list_to_tuple(x[1]) 40 | return d 41 | 42 | 43 | def parse_dict(d): 44 | result = types.SimpleNamespace() 45 | for k, v in d.items(): 46 | if type(v) == dict: 47 | v = parse_dict(v) 48 | elif type(v) == list: 49 | if k[-4:].lower() == ":map": 50 | k = k[:-4] 51 | v = parse_map(v) 52 | else: 53 | v = [parse_dict(x) if type(x)==dict else x for x in v] 54 | setattr(result, k, v) 55 | return result 56 | 57 | 58 | def get_object_meta_info(name): 59 | if type(name) == list: 60 | if len(name) > 1: 61 | return name[0], globals()[f"{name[1]}_params"] 62 | if len(name) > 0: 63 | name = name[0] 64 | else: 65 | raise RuntimeError("requested meta-information of empty list") 66 | return name, globals()[f"{name}_params"] 67 | 68 | 69 | def create_object(module, name, **params): 70 | name, object_config = get_object_meta_info(name) 71 | params = types.SimpleNamespace(**params) 72 | return getattr(module, name)(object_config, params) 73 | 74 | 75 | def init(args=None): 76 | parser = argparse.ArgumentParser(description="Deep Learning for Remote Sensing") 77 | parser.add_argument("configuration", nargs="?", default="", type=str, help="configuration file") 78 | 
parser.add_argument("parameters", nargs="*", type=str, help="configuration parameters") 79 | parser.add_argument("--cpu", action="store_true", help="use CPU instead of GPU") 80 | parser.add_argument("--compile", action="store_true", help="force Rust code compilation and then terminate") 81 | parser.add_argument("--git-log", default="", type=str, help="echo text file instead of actually running git log") 82 | args = parser.parse_args([arg for arg in args.split(" ") if len(arg) > 0] if type(args)==str else args) 83 | 84 | base_path = Path(__file__).absolute().parent.parent 85 | cache_path = f"{base_path}/tmp/cache" 86 | if not os.path.exists(cache_path): 87 | os.makedirs(cache_path) 88 | home = os.environ["HOME"] 89 | with open(f"{home}/.aethon/user.yaml", "r") as f: 90 | user = parse_dict(yaml.full_load("".join(f.readlines()))) 91 | assert user.git_log_num_lines >= 0 92 | 93 | if args.compile: 94 | import rust 95 | rust.init(True, cache_path) 96 | sys.exit(0) 97 | 98 | if len(args.configuration) == 0: 99 | raise RuntimeError(f"configuration file argument is undefined, try running with --help for additional information") 100 | config_file = base_path.joinpath("cfgs", args.configuration + ".yaml") 101 | 102 | print("%s/python %s"%(base_path, " ".join(sys.argv))) 103 | 104 | if args.configuration[0] != "@" and not config_file.is_file(): 105 | raise RuntimeError(f"unknown configuration file '{config_file}'") 106 | 107 | with open(f"{base_path}/core/defaults.yaml", "r") as f: 108 | defaults_yaml = "".join(f.readlines()) 109 | raw_yaml = "" 110 | if args.configuration[0] != "@": 111 | print(f"parsing configuration file '{config_file}'") 112 | with config_file.open() as f: 113 | raw_yaml = f.read() 114 | config_yaml = f"{raw_yaml}" 115 | pattern = re.compile("\$(\d)") 116 | while True: 117 | match = pattern.search(config_yaml) 118 | if not match: 119 | break 120 | span = match.span() 121 | index = int(match.group(1)) 122 | if index >= len(args.parameters): 123 | raise RuntimeError(f"insufficient number of parameters for configuration file; needs at least {index+1} parameter(s)") 124 | config_yaml = config_yaml[:span[0]] + str(args.parameters[index]) + config_yaml[span[1]:] 125 | full_config_yaml = f"{defaults_yaml}\n\n{config_yaml}" 126 | config = parse_dict(yaml.full_load(full_config_yaml)) 127 | 128 | config.base_path = base_path 129 | config.cache_path = cache_path 130 | if args.configuration[0] != "@": 131 | if not (hasattr(config, "output_path") and config.output_path): 132 | config.output_path = f"{base_path}/tmp" 133 | if not os.path.exists(config.output_path): 134 | os.makedirs(config.output_path) 135 | if getattr(config, "clear_output_path", False): 136 | call(f"rm -rf {config.output_path}/*") 137 | call(f"hostname > {config.output_path}/job_details.txt") 138 | with open(f"{config.output_path}/job_details.txt", "a") as f: 139 | f.write(" ".join(sys.argv) + "\n\n") 140 | call(f"nvidia-smi >> {config.output_path}/job_details.txt") 141 | call(f"echo >> {config.output_path}/job_details.txt") 142 | if len(args.git_log) == 0: 143 | call(f"git log -{user.git_log_num_lines} >> {config.output_path}/job_details.txt") 144 | else: 145 | call(f"cat {args.git_log} >> {config.output_path}/job_details.txt") 146 | if "SLURM_JOB_ID" in os.environ: 147 | slurm_job_id = os.environ["SLURM_JOB_ID"] 148 | call(f"scontrol -d show job={slurm_job_id} > {config.output_path}/slurm_job_details.txt") 149 | with open(f"{config.output_path}/{args.configuration}.yaml", "w") as f: 150 | f.write(config_yaml) 151 | with 
open(f"{config.output_path}/{args.configuration}.raw.yaml", "w") as f: 152 | f.write(raw_yaml) 153 | with open(f"{config.output_path}/{args.configuration}.full.yaml", "w") as f: 154 | f.write(full_config_yaml) 155 | 156 | config.num_threads = len(os.sched_getaffinity(0)) 157 | 158 | import torch 159 | config.device = torch.device("cpu" if args.cpu else "cuda:0") 160 | 161 | from concurrent.futures import ThreadPoolExecutor 162 | config.thread_pool = ThreadPoolExecutor() 163 | 164 | import rust 165 | rust.init(False, cache_path) 166 | 167 | import numpy as np 168 | config.random_seeds = np.load(f"{base_path}/core/random_seeds.npy") 169 | 170 | for k, v in config.__dict__.items(): 171 | if k[:2] == "__": 172 | continue 173 | globals()[k] = v 174 | globals()["args"] = args 175 | globals()["user"] = user 176 | -------------------------------------------------------------------------------- /core/defaults.yaml: -------------------------------------------------------------------------------- 1 | __defaults: &-defaults 2 | empty: 3 | 4 | __augmentation_defaults: &augmentation-defaults 5 | scaling: .1 6 | x_shear: 0 7 | y_shear: 0 8 | rotation: 180 9 | h_flip: .5 10 | v_flip: .5 11 | contrast: .1 12 | brightness: .1 13 | noise: .1 14 | at_test_time: True 15 | 16 | __SegmentationDataset_defaults: &SegmentationDataset-defaults 17 | channels: 18 | input: [ndvi, ir, red, green, blue, depth] 19 | training_samples: 300000 20 | split_weights: 21 | training: 70 22 | validation: 15 23 | test: 15 24 | min_sample_entropy: 25 | threshold: 0 26 | training_histogram: False 27 | apply_to_validation: False 28 | apply_to_test: False 29 | augmentation: 30 | <<: *augmentation-defaults 31 | 32 | __LGNDataset_defaults: &LGNDataset-defaults 33 | <<: *SegmentationDataset-defaults 34 | ground_truth_mapping: 35 | classes: [3, 1, 4, 5, 0, 0, 0, 5, 0, 2, 2, 2, 2, 2, 5] # map to ISPRS classes 36 | lut: # copy of 'lut' from datasets/ISPRSDatasetLoader.py 37 | - [ 0, 0, 0] 38 | - [ 0, 0,255] 39 | - [ 0,255, 0] 40 | - [255, 0, 0] 41 | - [255,255, 0] 42 | - [255,255,255] 43 | ignore_class: 5 44 | 45 | __hannover_defaults: &hannover-defaults 46 | <<: *LGNDataset-defaults 47 | domain: hannover 48 | 49 | __buxtehude_defaults: &buxtehude-defaults 50 | <<: *LGNDataset-defaults 51 | domain: buxtehude 52 | 53 | __nienburg_defaults: &nienburg-defaults 54 | <<: *LGNDataset-defaults 55 | domain: nienburg 56 | 57 | __vaihingen_defaults: &vaihingen-defaults 58 | <<: *SegmentationDataset-defaults 59 | domain: vaihingen 60 | channels: 61 | input: [ndvi, ir, red, green, depth] 62 | split_weights: 63 | training: 90 64 | validation: 10 65 | ignore_class: 5 66 | 67 | __potsdam_defaults: &potsdam-defaults 68 | <<: *SegmentationDataset-defaults 69 | domain: potsdam 70 | split_weights: 71 | training: 90 72 | validation: 10 73 | ignore_class: 5 74 | 75 | __toulouse_defaults: &toulouse-defaults 76 | <<: *SegmentationDataset-defaults 77 | domain: toulouse 78 | channels: 79 | input: [ndvi, blue2, blue, green, yellow, red, red2, ir, ir2] 80 | ignore_class: 0 81 | 82 | __toulouse_full_defaults: &toulouse_full-defaults 83 | <<: *SegmentationDataset-defaults 84 | domain: toulouse_full 85 | channels: 86 | input: [ndvi, blue2, blue, green, yellow, red, red2, ir, ir2] 87 | ignore_class: 0 88 | 89 | __toulouse_pan_defaults: &toulouse_pan-defaults 90 | <<: *SegmentationDataset-defaults 91 | domain: toulouse_pan 92 | channels: 93 | input: [ndvi, blue2, blue, green, yellow, red, red2, ir, ir2] 94 | split_weights: 95 | training: 62.5 96 | validation: 37.5 97 | 
ignore_class: 0 98 | 99 | __synthinel_defaults: &synthinel-defaults 100 | <<: *SegmentationDataset-defaults 101 | domain: synthinel 102 | channels: 103 | input: [red, green, blue] 104 | 105 | __synthinel_redroof_defaults: &synthinel_redroof-defaults 106 | <<: *synthinel-defaults 107 | domain: synthinel_redroof 108 | 109 | __synthinel_paris_defaults: &synthinel_paris-defaults 110 | <<: *synthinel-defaults 111 | domain: synthinel_paris 112 | 113 | __synthinel_ancient_defaults: &synthinel_ancient-defaults 114 | <<: *synthinel-defaults 115 | domain: synthinel_ancient 116 | 117 | __synthinel_scifi_defaults: &synthinel_scifi-defaults 118 | <<: *synthinel-defaults 119 | domain: synthinel_scifi 120 | 121 | __synthinel_palace_defaults: &synthinel_palace-defaults 122 | <<: *synthinel-defaults 123 | domain: synthinel_palace 124 | 125 | __synthinel_austin_defaults: &synthinel_austin-defaults 126 | <<: *synthinel-defaults 127 | domain: synthinel_austin 128 | 129 | __synthinel_venice_defaults: &synthinel_venice-defaults 130 | <<: *synthinel-defaults 131 | domain: synthinel_venice 132 | 133 | __synthinel_modern_defaults: &synthinel_modern-defaults 134 | <<: *synthinel-defaults 135 | domain: synthinel_modern 136 | 137 | __ipi_dataset_defaults: &ipi_dataset-defaults 138 | <<: *SegmentationDataset-defaults 139 | channels: 140 | input: [ndvi, ir, red, green, blue] 141 | 142 | __hameln_defaults: &hameln-defaults 143 | <<: *ipi_dataset-defaults 144 | domain: hameln 145 | 146 | __schleswig_defaults: &schleswig-defaults 147 | <<: *ipi_dataset-defaults 148 | domain: schleswig 149 | 150 | __mecklenburg_vorpommern_defaults: &mecklenburg_vorpommern-defaults 151 | <<: *ipi_dataset-defaults 152 | domain: mecklenburg_vorpommern 153 | 154 | __ipi_DA_dataset_defaults: &ipi_DA_dataset-defaults 155 | <<: *SegmentationDataset-defaults 156 | split_weights: 157 | training: 50 158 | validation: 25 159 | test: 25 160 | ignore_class: 5 161 | 162 | __hameln_DA_defaults: &hameln_DA-defaults 163 | <<: *ipi_DA_dataset-defaults 164 | domain: hameln_DA 165 | 166 | __schleswig_DA_defaults: &schleswig_DA-defaults 167 | <<: *ipi_DA_dataset-defaults 168 | domain: schleswig_DA 169 | 170 | __dlr_landcover_defaults: &dlr_landcover-defaults 171 | <<: *SegmentationDataset-defaults 172 | domain: dlr_landcover 173 | channels: 174 | input: [ndvi, ir, red, green, blue, sar] 175 | 176 | __dlr_roadmaps_defaults: &dlr_roadmaps-defaults 177 | <<: *SegmentationDataset-defaults 178 | domain: dlr_roadmaps 179 | load_full_dlr_roadmaps_annotations: True 180 | channels: 181 | input: [red, green, blue] 182 | 183 | __dlr_roadsegmentation_defaults: &dlr_roadsegmentation-defaults 184 | <<: *SegmentationDataset-defaults 185 | domain: dlr_roadsegmentation 186 | channels: 187 | input: [red, green, blue] 188 | split_weights: 189 | training: 80 190 | validation: 20 191 | 192 | __isaid_defaults: &isaid-defaults 193 | <<: *SegmentationDataset-defaults 194 | domain: isaid 195 | channels: 196 | input: [red, green, blue] 197 | split_weights: 198 | training: 95 199 | validation: 5 200 | min_sample_entropy: 201 | ignore_cache: False 202 | create_cache: False 203 | threshold: 0.04 # iSAID has many background pixels, avoid empty training samples by setting a minimum entropy level for ground truth images 204 | training_histogram: False 205 | apply_to_validation: False 206 | apply_to_test: False 207 | 208 | __SegForestNet_defaults: &SegForestNet-defaults 209 | features: 210 | context: 0 211 | decoder: 212 | num_blocks: 8 213 | context: 1 214 | intermediate_features: 96 215 | 
use_residual_blocks: True 216 | vq: 217 | type: [0, 0] 218 | codebook_size: 512 219 | normalized_length: 0 220 | loss_weights: [1, 1] 221 | hard: True 222 | temperature: 223 | parameters: [epoch, epochs] 224 | func: min(epoch,round(0.8*epochs))/round(0.8*epochs) 225 | #func: 1-np.cos(0.5*np.pi*epoch/(epochs-1)) 226 | value_range: [2, 0.1] 227 | loss_weight: 228 | parameters: [epoch, epochs] 229 | func: 0 230 | value_range: [.05, .05] 231 | region_map: 232 | accumulation: add 233 | node_weight: 1 234 | softmax_temperature: 235 | parameters: [epoch, epochs] 236 | func: 0 237 | value_range: [1, 1] 238 | loss: 239 | cross_entropy: pixels 240 | ce_constant: 10 241 | distribution_metric: gini 242 | min_region_size: 4 243 | weights: [.8625, .0475, .035, .055, 0] 244 | 245 | __SemanticSegmentation_defaults: &SemanticSegmentation-defaults 246 | autoencoder: False 247 | num_samples_per_epoch: 2500 248 | unique_iterations: True 249 | alt_loss: False 250 | optimizer: 251 | type: AdamW 252 | arguments: 253 | betas: [0.75, 0.999] 254 | weight_decay: 0.0078 255 | learning_rate: 256 | min_value: 0 257 | num_cycles: 1 258 | cycle_length_factor: 2 259 | num_iterations_factor: 1 260 | gradient_clipping: 2 261 | class_weights: 262 | ignore_dataset: False 263 | ignored_class_weight: 0.4 264 | dynamic_exponent: 4 265 | terminate_early: -1 266 | model_filename_extension: .pt.gz 267 | delete_irrelevant_models: True 268 | visualize_regions: False 269 | smoothing: 0.6 270 | model_specific_defaults:map: # epochs, max. learning rate, mini-batch size 271 | - [SegForestNet, [105, 0.0025, 12]] 272 | - [FCN, [105, 0.001, 12]] 273 | - [RAFCN, [105, 0.0005, 12]] 274 | - [FarSeg, [105, 0.001, 12]] 275 | - [PFNet, [105, 0.001, 12]] 276 | - [DeepLabv3p, [105, 0.001, 12]] 277 | - [UNet, [105, 0.0005, 12]] 278 | - [UNetpp, [105, 0.0005, 12]] 279 | -------------------------------------------------------------------------------- /core/random_seeds.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gritzner/SegForestNet/b7b88ab86677b1b0f2455ff3d41975bc79212831/core/random_seeds.npy -------------------------------------------------------------------------------- /datasets/DLRDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 as cv 3 | import numpy as np 4 | import os 5 | 6 | 7 | class DatasetLoader_dlr_landcover(): 8 | def __init__(self, config): 9 | self.images = [] 10 | for city in ("Berlin", "Munich"): 11 | full_img = cv.imread( 12 | f"{config.dataset_path}/images/{city}_converted.tif", 13 | cv.IMREAD_UNCHANGED 14 | ) 15 | assert full_img.shape[0]%5 == 0 16 | full_img = full_img.reshape((5, -1, full_img.shape[1])) 17 | full_img, full_sar = full_img[:4], full_img[-1] 18 | full_img = np.moveaxis(full_img, 0, 2) 19 | full_gt = cv.imread( 20 | f"{config.dataset_path}/annotations/{city}_converted.tif", 21 | cv.IMREAD_UNCHANGED 22 | ) 23 | assert full_gt.shape == full_sar.shape 24 | 25 | ys = (full_sar.shape[0]//2, full_sar.shape[0]) 26 | min_ys = (0, 0, ys[0], ys[0]) 27 | max_ys = (ys[0], ys[0], ys[1], ys[1]) 28 | 29 | xs = (full_sar.shape[1]//2, full_sar.shape[1]) 30 | min_xs = (0, xs[0], 0, xs[0]) 31 | max_xs = (xs[0], xs[1], xs[0], xs[1]) 32 | 33 | for y0,y1,x0,x1 in zip(min_ys, max_ys, min_xs, max_xs): 34 | sar = full_sar[y0:y1,x0:x1] 35 | img = full_img[y0:y1,x0:x1] 36 | gt = full_gt[y0:y1,x0:x1] 37 | 38 | img_set = [] 39 | img_set.append(sar) 40 | img_set.append(img) 41 | 42 | rgb = 
np.empty((*sar.shape, 3), dtype=np.uint8) 43 | for c in range(3): 44 | temp = img[:,:,c] 45 | temp = (temp - np.min(temp)) / 3000 46 | temp[temp>1] = 1 47 | rgb[:,:,2-c] = 255 * temp 48 | img_set.append(rgb) 49 | 50 | img_set.append(gt) 51 | self.images.append(img_set) 52 | 53 | self.channels = { 54 | "sar": (0, 0), 55 | "blue": (1, 0), 56 | "green": (1, 1), 57 | "red": (1, 2), 58 | "ir": (1, 3), 59 | "vis_red": (2, 0), 60 | "vis_green": (2, 1), 61 | "vis_blue": (2, 2), 62 | "gt": (3, 0) 63 | } 64 | self.num_classes = 5 65 | self.gsd = 1000 66 | self.lut = ( 67 | ( 0, 0, 0), # ??? might be roads/paths 68 | (255,255, 0), # agriculture 69 | ( 0,255, 0), # forest 70 | (255, 0, 0), # built-up 71 | ( 0, 0,255) # water 72 | ) 73 | 74 | class DatasetLoader_dlr_roadmaps(): 75 | def __init__(self, config): 76 | self.images = [] 77 | for subset in ("AerialKITTI", "Bavaria"): 78 | for fn in glob.iglob(f"{config.dataset_path}/{subset}/images/*.jpg"): 79 | img_set = [] 80 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 81 | fn = fn[:-4] 82 | mask = cv.imread(fn + "_mask.tif", cv.IMREAD_UNCHANGED) 83 | fn = fn.split("/") 84 | fn[-2] = "road_annotation_full" 85 | fn = "/".join(fn) 86 | if not(config.load_full_dlr_roadmaps_annotations and os.path.exists(fn + "_gt.tif")): 87 | fn = fn.split("/") 88 | fn[-2] = "road_annotation" 89 | fn = "/".join(fn) 90 | raw_gt = cv.imread(fn + "_gt.tif", cv.IMREAD_UNCHANGED) 91 | if len(raw_gt.shape) == 3: 92 | raw_gt = raw_gt[:,:,0] 93 | gt = np.ones(raw_gt.shape, dtype=np.uint8) 94 | gt[raw_gt==255] = 2 95 | gt[mask!=255] = 0 96 | img_set.append(gt) 97 | self.images.append(img_set) 98 | 99 | self.channels = { 100 | "red": (0, 2), 101 | "green": (0, 1), 102 | "blue": (0, 0), 103 | "gt": (1, 0) 104 | } 105 | self.num_classes = 3 106 | self.gsd = 13 107 | self.lut = ( 108 | (127,127,127), # ignore 109 | ( 0, 0, 0), # other 110 | (255,255,255) # road 111 | ) 112 | 113 | rgb2class_map = { 114 | (127,127,127): 0, 115 | ( 0, 0, 0): 1, 116 | (255, 0, 0): 2, 117 | (255,105,180): 3, 118 | ( 0, 0,255): 4, 119 | (255,255, 0): 5 120 | } 121 | 122 | def rgb2class(fn): 123 | gt = cv.imread(fn, cv.IMREAD_UNCHANGED) 124 | new_gt = np.empty(gt.shape[:2], dtype=np.uint8) 125 | for k, v in rgb2class_map.items(): 126 | i = np.logical_and(gt[:,:,2]==k[0], gt[:,:,1]==k[1]) 127 | i = np.logical_and(i, gt[:,:,0]==k[2]) 128 | i = np.where(i) 129 | new_gt[i[0],i[1]] = v 130 | return new_gt 131 | 132 | class DatasetLoader_dlr_roadsegmentation(): 133 | def __init__(self, config): 134 | self.images = [] 135 | self.image_subsets = [] 136 | for subset in ("train", "test"): 137 | for fn in glob.iglob(f"{config.dataset_path}/Aerial/{subset}_images/*.jpg"): 138 | img_set = [] 139 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 140 | fn = fn[:-4] 141 | mask = cv.imread(fn + "_mask.tif", cv.IMREAD_UNCHANGED) 142 | fn = fn.split("/") 143 | fn[-2] = "colorAnnotation" 144 | fn = "/".join(fn) 145 | gt = rgb2class(fn + ".png") 146 | gt[mask!=255] = 0 147 | img_set.append(gt) 148 | self.images.append(img_set) 149 | self.image_subsets.extend( 150 | [0 if subset=="train" else 2] * (len(self.images) - len(self.image_subsets)) 151 | ) 152 | 153 | self.channels = { 154 | "red": (0, 2), 155 | "green": (0, 1), 156 | "blue": (0, 0), 157 | "gt": (1, 0) 158 | } 159 | self.num_classes = 6 160 | self.gsd = 9 161 | self.lut = ( 162 | (127,127,127), # ignore 163 | ( 0, 0, 0), # background 164 | (255, 0, 0), # building 165 | (255,105,180), # road 166 | ( 0, 0,255), # sidewalk 167 | (255,255, 0) # parking 168 | ) 169 
| -------------------------------------------------------------------------------- /datasets/IPIDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 as cv 3 | import numpy as np 4 | 5 | 6 | class IPILoader(): 7 | def __init__(self, config): 8 | self.images = [] 9 | for fn in glob.iglob(f"{config.dataset_path}/gt/*.tif"): 10 | img_set = [] 11 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 12 | tokens = fn.split("/") 13 | tokens[-2] = "RGB" 14 | fn = "/".join(tokens) 15 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 16 | tokens[-2] = "IR" 17 | fn = "/".join(tokens) 18 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 19 | self.images.append(img_set) 20 | 21 | self.channels = { 22 | "gt": (0, 0), 23 | "red": (1, 2), 24 | "green": (1, 1), 25 | "blue": (1, 0), 26 | "ir": (2, 0) 27 | } 28 | self.num_classes = 8 29 | self.gsd = 20 30 | self.lut = ( 31 | (255,128, 0), # building 32 | (128,128,128), # sealed surface 33 | (200,135, 70), # soil 34 | ( 0,255, 0), # grass 35 | ( 64,128, 0), # tree 36 | ( 0, 0,255), # water 37 | (255, 0, 0), # car 38 | (128, 0, 25) # other 39 | ) 40 | 41 | class DatasetLoader_hameln(IPILoader): 42 | def __init__(self, config): 43 | super().__init__(config) 44 | 45 | class DatasetLoader_schleswig(IPILoader): 46 | def __init__(self, config): 47 | super().__init__(config) 48 | 49 | class DatasetLoader_mecklenburg_vorpommern(IPILoader): 50 | def __init__(self, config): 51 | super().__init__(config) 52 | 53 | self.num_classes = 10 54 | self.gsd = 10 55 | self.lut = ( 56 | (255,128, 0), # building 57 | (128,128,128), # sealed surface 58 | (210,210,200), # unpaved road 59 | (200,135, 70), # soil 60 | (255,255, 0), # crops 61 | ( 0,255, 0), # grass 62 | ( 64,128, 0), # tree 63 | ( 0, 0,255), # water 64 | ( 64, 64, 64), # railway 65 | (128, 0, 25) # other 66 | ) 67 | 68 | class IPIDALoader(): 69 | def __init__(self, config): 70 | self.images = [] 71 | for fn in sorted(glob.iglob(f"{config.dataset_path}/**/L_DA.png")): 72 | img_set = [] 73 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 74 | fn = fn[:fn.rfind("/")] 75 | img_set.append(cv.imread(f"{fn}/R_G_B_IR.png", cv.IMREAD_UNCHANGED)) 76 | img = cv.imread(f"{fn}/NDSM.tif", cv.IMREAD_UNCHANGED) 77 | img_set.append(np.asarray(img, dtype=np.float32)) 78 | self.images.append(img_set) 79 | 80 | self.channels = { 81 | "gt": (0, 0), 82 | "ir": (1, 3), 83 | "red": (1, 2), 84 | "green": (1, 1), 85 | "blue": (1, 0), 86 | "depth": (2, 0) 87 | } 88 | self.num_classes = 6 89 | self.gsd = 20 90 | self.lut = ( # for ground truth visualization 91 | ( 0, 0, 0), # sealed surface 92 | ( 0, 0,255), # building 93 | ( 0,255, 0), # low vegetation 94 | (255, 0, 0), # tree 95 | (255,255, 0), # car 96 | (255,255,255) # clutter 97 | ) 98 | 99 | class DatasetLoader_hameln_DA(IPIDALoader): 100 | def __init__(self, config): 101 | super().__init__(config) 102 | 103 | class DatasetLoader_schleswig_DA(IPIDALoader): 104 | def __init__(self, config): 105 | super().__init__(config) 106 | 107 | -------------------------------------------------------------------------------- /datasets/ISAIDDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import core 3 | import glob 4 | import cv2 as cv 5 | from .SparseInstanceImage import SparseInstanceImage 6 | 7 | 8 | lut = ( # for ground truth visualization 9 | ( 0, 0, 0), # background 10 | ( 0, 63, 63), # storange tank 11 | ( 0,127,127), # large vehicle 12 
| ( 0, 0,127), # small vehicle 13 | ( 0,127,255), # plane 14 | ( 0, 0, 63), # ship 15 | ( 0, 0,255), # swimming pool 16 | ( 0,100,155), # harbor 17 | ( 0, 63,127), # tennis court 18 | ( 0, 63,255), # ground track field 19 | ( 0,127,191), # soccer ball field 20 | ( 0, 63, 0), # baseball diamond 21 | ( 0,127, 63), # bridge 22 | ( 0, 63,191), # basketball court 23 | ( 0,191,127), # roundabout 24 | ( 0, 0,191) # helicopter 25 | ) 26 | 27 | 28 | class DatasetLoader_isaid(): 29 | def __init__(self, config): 30 | dota_root_path, isaid_root_path = config.dataset_path 31 | 32 | class_map = np.empty((len(lut), 3), dtype=np.uint8) 33 | for i, rgb in enumerate(lut): 34 | class_map[i] = rgb 35 | class_map = np.ascontiguousarray(np.flip(class_map,axis=1)) 36 | 37 | self.images = [] 38 | self.image_subsets = [] 39 | for subset, subset_id in (("training", 0), ("validation", 2)): 40 | path_prefix = f"{isaid_root_path}/{subset}" 41 | for img_set in core.thread_pool.map(lambda fn: DatasetLoader_isaid.load_image(fn,path_prefix,class_map), sorted(glob.iglob(f"{dota_root_path}/{subset}/tmp/images/*.png"))): 42 | self.images.append(img_set) 43 | self.image_subsets.append(subset_id) 44 | 45 | self.channels = { 46 | "red": (0, 2), 47 | "green": (0, 1), 48 | "blue": (0, 0), 49 | "gt": (1, 0) 50 | } 51 | self.num_classes = 16 52 | self.gsd = 26 # GSD varies wildy in DOTA, median of all valid GSDs: 26, mean: 39 53 | self.lut = lut 54 | 55 | @staticmethod 56 | def load_image(fn, path_prefix, class_map): 57 | img_set = [cv.imread(fn)] 58 | 59 | img_id = fn.split("/")[-1].split(".")[0] 60 | sem_img = cv.imread( 61 | f"{path_prefix}/Semantic_masks/tmp/images/{img_id}_instance_color_RGB.png", 62 | cv.IMREAD_UNCHANGED 63 | ) 64 | inst_img = cv.imread( 65 | f"{path_prefix}/Instance_masks/tmp/images/{img_id}_instance_id_RGB.png", 66 | cv.IMREAD_UNCHANGED 67 | ) 68 | img_set.append(SparseInstanceImage(sem_img, inst_img, class_map)) 69 | 70 | return img_set 71 | -------------------------------------------------------------------------------- /datasets/ISPRSDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 as cv 3 | import numpy as np 4 | 5 | 6 | # ISPRS classes: 7 | # (255,255,255), 0: impervious surfaces 8 | # ( 0, 0,255), 1: buildings 9 | # ( 0,255,255), 2: low vegetation 10 | # ( 0,255, 0), 3: trees 11 | # (255,255, 0), 4: cars 12 | # (255, 0, 0), 5: clutter/background 13 | 14 | rgb2class_map = { 15 | (255,255,255): 0, 16 | ( 0, 0,255): 1, 17 | ( 0,255,255): 2, 18 | ( 0,255, 0): 3, 19 | (255,255, 0): 4, 20 | (255, 0, 0): 5 21 | } 22 | 23 | def rgb2class(fn): 24 | gt = cv.imread(fn, cv.IMREAD_UNCHANGED) 25 | new_gt = np.empty(gt.shape[:2], dtype=np.uint8) 26 | for k, v in rgb2class_map.items(): 27 | i = np.logical_and(gt[:,:,0]==k[2], gt[:,:,1]==k[1]) 28 | i = np.logical_and(i, gt[:,:,2]==k[0]) 29 | i = np.where(i) 30 | new_gt[i[0],i[1]] = v 31 | return new_gt 32 | 33 | 34 | lut = ( # for ground truth visualization 35 | ( 0, 0, 0), # sealed surface 36 | ( 0, 0,255), # building 37 | ( 0,255, 0), # low vegetation 38 | (255, 0, 0), # tree 39 | (255,255, 0), # car 40 | (255,255,255) # clutter 41 | ) 42 | 43 | 44 | class DatasetLoader_vaihingen(): 45 | def __init__(self, config): 46 | filenames = [fn.split("/")[-1] for fn in glob.iglob(f"{config.dataset_path}/ground_truth_COMPLETE/*.tif")] 47 | 48 | self.images = [] 49 | self.image_subsets = [] 50 | for fn in sorted(filenames): 51 | i = int(fn.split(".")[0].split("a")[-1]) 52 | if i in (1, 3, 5, 7, 
11, 13, 15, 17, 21, 23, 26, 28, 30, 32, 34, 37): 53 | self.image_subsets.append(0) 54 | else: 55 | self.image_subsets.append(2) 56 | 57 | dsm_fn = fn.replace("top_mosaic_09cm", "dsm_09cm_matching") 58 | img_set = [] 59 | img_set.append(cv.imread( 60 | f"{config.dataset_path}/semantic_labeling/top/{fn}", 61 | cv.IMREAD_UNCHANGED 62 | )) 63 | img_set.append(cv.imread( 64 | f"{config.dataset_path}/semantic_labeling/dsm/{dsm_fn}", 65 | cv.IMREAD_UNCHANGED 66 | )) 67 | img_set.append(rgb2class(f"{config.dataset_path}/ground_truth_COMPLETE/{fn}")) 68 | self.images.append(img_set) 69 | 70 | self.channels = { 71 | "ir": (0, 2), 72 | "red": (0, 1), 73 | "green": (0, 0), 74 | "depth": (1, 0), 75 | "gt": (2, 0) 76 | } 77 | self.num_classes = 6 78 | self.gsd = 8 # filenames say 9cm, but all the accompanying documents say 8cm 79 | self.lut = lut 80 | 81 | class DatasetLoader_potsdam(): 82 | def __init__(self, config): 83 | filenames = [fn.split("/")[-1] for fn in glob.iglob(f"{config.dataset_path}/5_Labels_all/*.tif")] 84 | filenames = list(map(lambda fn: fn[4:fn.rfind("_")], filenames)) 85 | 86 | self.images = [] 87 | self.image_subsets = [] 88 | for fn in sorted(filenames): 89 | i = 100*int(fn[8]) + int(fn[10:]) 90 | if i in (210, 211, 212, 310, 311, 312, 410, 411, 412, 510, 511, 512, 607, 608, 609, 610, 611, 612, 707, 708, 709, 710, 711, 712): 91 | self.image_subsets.append(0) 92 | else: 93 | self.image_subsets.append(2) 94 | 95 | dsm_fn = fn.replace("potsdam_", "dsm_potsdam_0") 96 | dsm_fn = dsm_fn.replace("_7", "_07") 97 | dsm_fn = dsm_fn.replace("_8", "_08") 98 | dsm_fn = dsm_fn.replace("_9", "_09") 99 | img_set = [] 100 | img_set.append(cv.imread( 101 | f"{config.dataset_path}/2_Ortho_RGB/top_{fn}_RGB.tif", 102 | cv.IMREAD_UNCHANGED 103 | )) 104 | img_set.append(cv.imread( 105 | f"{config.dataset_path}/3_Ortho_IRRG/top_{fn}_IRRG.tif", 106 | cv.IMREAD_UNCHANGED 107 | )) 108 | img_set.append(cv.imread( 109 | f"{config.dataset_path}/1_DSM/{dsm_fn}.tif", 110 | cv.IMREAD_UNCHANGED 111 | )) 112 | if "3_13" in fn: 113 | img_set[-1] = cv.resize(img_set[-1], (6000,6000), interpolation=cv.INTER_CUBIC) 114 | if "_4_12" in fn or "6_7" in fn: 115 | img_set.append(rgb2class(f"{config.dataset_path}/5_Labels_for_participants/top_{fn}_label.tif")) 116 | else: 117 | img_set.append(rgb2class(f"{config.dataset_path}/5_Labels_all/top_{fn}_label.tif")) 118 | 119 | self.images.append(img_set) 120 | 121 | self.channels = { 122 | "ir": (1, 2), 123 | "red": (0, 2), 124 | "green": (0, 1), 125 | "blue": (0, 0), 126 | "depth": (2, 0), 127 | "gt": (3, 0) 128 | } 129 | self.num_classes = 6 130 | self.gsd = 5 131 | self.lut = lut 132 | -------------------------------------------------------------------------------- /datasets/LGNDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import core 2 | import cv2 as cv 3 | import glob 4 | import json 5 | import bz2 6 | import numpy as np 7 | import rust 8 | import itertools 9 | 10 | 11 | # classes: 12 | # 1: trees 13 | # 2: buildings 14 | # 3: vehicles 15 | # 4: objects 16 | # 5: tracks 17 | # 6: streets 18 | # 7: traffic routes 19 | # 8: water 20 | # 9: sealed surfaces 21 | # 10: fields with vegetation 22 | # 11: fields without vegetation 23 | # 12: fields 24 | # 13: low vegetation 25 | # 14: unsealed surfaces 26 | # 15: surfaces 27 | 28 | 29 | class LGNDatasetLoader(): 30 | def __init__(self, left, bottom, is_hannover, config): 31 | images = [] 32 | images.append(cv.imread( 33 | next(glob.iglob(f"{config.dataset_path}/*_col.tif")), 
34 | cv.IMREAD_UNCHANGED 35 | )) 36 | images.append(cv.imread( 37 | next(glob.iglob(f"{config.dataset_path}/*_ir.tif")), 38 | cv.IMREAD_UNCHANGED 39 | )) 40 | images.append(cv.imread( 41 | next(glob.iglob(f"{config.dataset_path}/dom/dom_cleaned.tif")), 42 | cv.IMREAD_UNCHANGED 43 | )) 44 | 45 | images.append(self.rasterize_ground_truth(config.dataset_path, left, bottom, images[1].shape, is_hannover)) 46 | self.split_images(images, is_hannover, config) 47 | 48 | self.channels = { 49 | "ir": (1, 0), 50 | "red": (0, 2), 51 | "green": (0, 1), 52 | "blue": (0, 0), 53 | "depth": (2, 0), 54 | "gt": (3, 0) 55 | } 56 | self.num_classes = 15 57 | self.gsd = 20 58 | self.lut = ( # for ground truth visualization, IDs get decreased by one during rasterization 59 | (255, 0, 0), # 1: trees 60 | ( 0, 0, 255), # 2: buildings 61 | (255, 255, 0), # 3: vehicles 62 | (127, 127, 0), # 4: objects 63 | (255, 0, 255), # 5: tracks 64 | (170, 0, 170), # 6: streets 65 | ( 85, 0, 85), # 7: traffic routes 66 | ( 0, 255, 255), # 8: water 67 | (127, 127, 127), # 9: sealed surfaces 68 | ( 0, 255, 0), # 10: fields with vegetation 69 | ( 0, 191, 0), # 11: fields without vegetation 70 | ( 0, 127, 0), # 12: fields 71 | ( 0, 63, 0), # 13: low vegetation 72 | ( 0, 0, 0), # 14: unsealed surfaces 73 | (255, 255, 255), # 15: surfaces 74 | ) 75 | 76 | def rasterize_ground_truth(self, path, left, bottom, img_size, is_hannover): 77 | data = {} 78 | with bz2.open(f"{path}/ground_truth.geojson.bz2", "r") as f: 79 | shapes = json.load(f) 80 | 81 | object_ids = set() 82 | for feature in shapes["features"]: 83 | object_id = feature["properties"]["id"] 84 | assert not object_id in object_ids 85 | object_ids.add(object_id) 86 | 87 | geom = feature["geometry"] 88 | assert geom["type"] in ("Polygon", "MultiPolygon") 89 | if len(geom["coordinates"]) == 0: 90 | continue 91 | if geom["type"] == "Polygon": 92 | geom = [geom["coordinates"]] 93 | else: 94 | geom = geom["coordinates"] 95 | geom = [[np.asarray(coords) for coords in poly] for poly in geom] 96 | 97 | object_class = feature["properties"]["klasse"] 98 | if not object_class in data: 99 | data[object_class] = [] 100 | data[object_class].extend(geom) 101 | 102 | img = np.zeros(img_size, dtype=np.uint8) + 255 103 | for key in reversed(sorted(data.keys())): 104 | geometries = data[key] 105 | if len(geometries) == 0: 106 | continue 107 | geometry_lengths = np.empty(len(geometries), dtype=np.int32) 108 | max_length = np.max([len(geometry) for geometry in geometries]) 109 | polygon_lengths = np.empty((len(geometries), max_length), dtype=np.int32) 110 | max_length = [np.max([polygon.shape[0] for polygon in geometry]) for geometry in geometries] 111 | max_length = np.max(max_length) 112 | polygons = np.empty((len(geometries), polygon_lengths.shape[1], max_length, 2), dtype=np.float64) 113 | for i, geometry in enumerate(geometries): 114 | geometry_lengths[i] = len(geometry) 115 | for j, polygon in enumerate(geometry): 116 | polygon_lengths[i,j] = polygon.shape[0] 117 | polygon[:,0] = (polygon[:,0] - left) / 2000 # x coordinate 118 | polygon[:,1] = 1.0 - ((polygon[:,1] - bottom) / 2000) # y coordinate 119 | polygons[i,j,:polygon.shape[0]] = polygon 120 | rust.rasterize_objects(img, geometry_lengths, polygon_lengths, polygons, key-1, core.num_threads) 121 | rust.flood_fill_holes(img) 122 | 123 | return img 124 | 125 | def split_images(self, images, is_hannover, config): 126 | self.images = [] 127 | for i, j in itertools.product(range(4), range(4)): 128 | if i == 0 and j == 3 and is_hannover and 
getattr(config, "skip_Eilenriede", True): 129 | continue 130 | self.images.append([img[i*2500:(i+1)*2500,j*2500:(j+1)*2500] for img in images]) 131 | 132 | class DatasetLoader_hannover(LGNDatasetLoader): 133 | def __init__(self, config): 134 | super().__init__(550000, 5802000, True, config) 135 | 136 | class DatasetLoader_buxtehude(LGNDatasetLoader): 137 | def __init__(self, config): 138 | super().__init__(546000, 5924000, False, config) 139 | 140 | class DatasetLoader_nienburg(LGNDatasetLoader): 141 | def __init__(self, config): 142 | super().__init__(514000, 5830000, False, config) 143 | -------------------------------------------------------------------------------- /datasets/MultiSegmentationDataset.py: -------------------------------------------------------------------------------- 1 | import core 2 | import datasets 3 | import types 4 | import numpy as np 5 | import torch 6 | from .SegmentationDatasetPatchProviders import AbstractPatchProvider 7 | 8 | 9 | class ConcatenatedPatchProvider(AbstractPatchProvider): 10 | def __init__(self, parent, sources, update_image_ids): 11 | self.sources = sources 12 | self.shape = [0, *sources[0].shape[1:]] 13 | self.dtype = sources[0].dtype 14 | 15 | for source, offset in zip(sources, parent.offsets): 16 | source.no_threading = True 17 | source.images = parent.base.images 18 | if update_image_ids: 19 | if hasattr(source, "augmentation"): 20 | source.augmentation.image += offset 21 | else: 22 | source.patch_info[:,0] += offset 23 | self.shape[0] += source.shape[0] 24 | assert len(self.shape) == len(source.shape) 25 | for a, b in zip(self.shape[1:], source.shape[1:]): 26 | assert a == b 27 | assert self.dtype == source.dtype 28 | 29 | if hasattr(sources[0], "patch_info"): 30 | self.patch_info = np.concatenate([source.patch_info for source in sources], axis=0) 31 | 32 | self.source_ids = [], [] 33 | offset = 0 34 | for i, source in enumerate(sources): 35 | self.source_ids[0].extend([i] * source.shape[0]) 36 | self.source_ids[1].extend([offset] * source.shape[0]) 37 | offset += source.shape[0] 38 | assert len(self.source_ids[0]) == self.shape[0] 39 | assert len(self.source_ids[0]) == len(self.source_ids[1]) 40 | 41 | def extract_patch(self, result, index): 42 | i, j = self.source_ids[0][index], self.source_ids[1][index] 43 | result[:] = self.sources[i][index-j] 44 | 45 | 46 | class InterleavedPatchProvider(AbstractPatchProvider): 47 | def __init__(self, parent, sources, update_image_ids): 48 | self.sources = sources 49 | self.shape = [0, *sources[0].shape[1:]] 50 | self.dtype = sources[0].dtype 51 | 52 | for source, offset in zip(sources, parent.offsets): 53 | source.no_threading = True 54 | source.images = parent.base.images 55 | if update_image_ids: 56 | if hasattr(source, "augmentation"): 57 | source.augmentation.image += offset 58 | else: 59 | source.patch_info[:,0] += offset 60 | self.shape[0] += source.shape[0] 61 | assert len(self.shape) == len(source.shape) 62 | for a, b in zip(self.shape[1:], source.shape[1:]): 63 | assert a == b 64 | assert self.dtype == source.dtype 65 | 66 | if hasattr(sources[0], "patch_info"): 67 | self.patch_info = np.concatenate([source.patch_info for source in sources], axis=0) 68 | 69 | self.source_ids = [] 70 | for i, source in enumerate(sources): 71 | self.source_ids.extend(zip( 72 | [i]*source.shape[0], range(source.shape[0]) 73 | )) 74 | self.source_ids.sort(key=lambda i: i[1]) 75 | assert len(self.source_ids) == self.shape[0] 76 | 77 | def extract_patch(self, result, index): 78 | i, j = self.source_ids[index] 79 | 
result[:] = self.sources[i][j] 80 | 81 | 82 | class MultiSegmentationDataset(): 83 | def __init__(self, config, params): 84 | assert len(config.datasets) > 1 85 | all_datasets = [core.get_object_meta_info(dataset) for dataset in config.datasets] 86 | first = all_datasets[0][1] 87 | for _, other in all_datasets[1:]: 88 | assert len(first.channels.input) == len(other.channels.input) 89 | for a, b in zip(first.channels.input, other.channels.input): 90 | assert a == b 91 | assert len(first.patch_size) == len(other.patch_size) 92 | for a, b in zip(first.patch_size, other.patch_size): 93 | assert a == b 94 | 95 | all_datasets = [core.create_object(datasets, dataset) for dataset in config.datasets] 96 | first = all_datasets[0] 97 | assert not first.has_instances 98 | self.gsd = first.gsd 99 | self.base = types.SimpleNamespace( 100 | images = first.base.images, 101 | visualization_channels = first.base.visualization_channels 102 | ) 103 | self.offsets = [0, len(self.base.images)] 104 | self.class_counts = first.class_counts 105 | for other in all_datasets[1:]: 106 | assert np.all(first.base.visualization_channels == other.base.visualization_channels) 107 | assert first.num_classes == other.num_classes 108 | assert first.ignore_class == other.ignore_class 109 | assert len(first.lut) == len(other.lut) 110 | for a, b in zip(first.lut, other.lut): 111 | assert len(a) == len(b) 112 | for c, d in zip(a, b): 113 | assert c == d 114 | assert not other.has_instances 115 | self.gsd += other.gsd 116 | self.base.images.extend(other.base.images) 117 | self.offsets.append(len(self.base.images)) 118 | self.class_counts += other.class_counts 119 | self.gsd /= len(all_datasets) 120 | 121 | self.num_classes = first.num_classes 122 | self.ignore_class = first.ignore_class 123 | self.lut = first.lut 124 | 125 | num_pixels = np.sum(self.class_counts) 126 | weights = np.empty(self.num_classes, dtype=np.float64) 127 | for c in range(self.num_classes): 128 | if self.class_counts[c] == num_pixels: 129 | weights[c] = 1 130 | elif c == self.ignore_class or self.class_counts[c] == 0: 131 | weights[c] = 0 132 | else: 133 | weights[c] = 1 - self.class_counts[c]/num_pixels 134 | weights = weights / np.max(weights) 135 | self.class_weights = torch.from_numpy(weights).float().to(core.device) 136 | 137 | patch_provider = InterleavedPatchProvider if config.interleave else ConcatenatedPatchProvider 138 | self.training = types.SimpleNamespace( 139 | x = patch_provider(self, [dataset.training.x for dataset in all_datasets], True), 140 | x_gt = patch_provider(self, [dataset.training.x_gt for dataset in all_datasets], False), 141 | x_vis = patch_provider(self, [dataset.training.x_vis for dataset in all_datasets], False), 142 | y = patch_provider(self, [dataset.training.y for dataset in all_datasets], False) 143 | ) 144 | self.validation = types.SimpleNamespace( 145 | x = patch_provider(self, [dataset.validation.x for dataset in all_datasets], True), 146 | x_gt = patch_provider(self, [dataset.validation.x_gt for dataset in all_datasets], False), 147 | x_vis = patch_provider(self, [dataset.validation.x_vis for dataset in all_datasets], False), 148 | y = patch_provider(self, [dataset.validation.y for dataset in all_datasets], False), 149 | index_map = patch_provider(self, [dataset.validation.index_map for dataset in all_datasets], False) 150 | ) 151 | self.test = types.SimpleNamespace( 152 | x = patch_provider(self, [dataset.test.x for dataset in all_datasets], True), 153 | x_gt = patch_provider(self, [dataset.test.x_gt for dataset in 
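# Worked example, illustration only, of the class weighting computed above: every class
# receives 1 minus its pixel share, classes that never occur (or the ignore class) receive
# 0, and the vector is rescaled so the rarest class ends up with weight 1. The counts are
# made up.
import numpy as np

class_counts = np.array([50, 30, 20], dtype=np.float64)   # pixels per class
num_pixels = class_counts.sum()
weights = 1.0 - class_counts / num_pixels                  # [0.5, 0.7, 0.8]
weights /= weights.max()                                   # [0.625, 0.875, 1.0]
assert np.allclose(weights, [0.625, 0.875, 1.0])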
all_datasets], False), 154 | x_vis = patch_provider(self, [dataset.test.x_vis for dataset in all_datasets], False), 155 | y = patch_provider(self, [dataset.test.y for dataset in all_datasets], False), 156 | index_map = patch_provider(self, [dataset.test.index_map for dataset in all_datasets], False) 157 | ) 158 | 159 | del self.offsets 160 | -------------------------------------------------------------------------------- /datasets/SemcityToulouseDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import glob 3 | 4 | 5 | lut = ( # for ground truth visualization 6 | (255,255,255), # void 7 | ( 38, 38, 38), # impervious surface 8 | (238,118, 33), # building 9 | ( 34,139, 34), # pervious surface 10 | ( 0,222,137), # high vegetation 11 | (255, 0, 0), # car 12 | ( 0, 0,238), # water 13 | (160, 30,230) # sports venues 14 | ) 15 | 16 | panoptic_training_set = "04", "08" 17 | test_set = "03", "07" 18 | 19 | 20 | class DatasetLoader_toulouse(): 21 | def __init__(self, config): 22 | self.images = [] 23 | self.image_subsets = [] 24 | 25 | annotators = {} 26 | for fn in glob.iglob(f"{config.dataset_path}/semantic_05/TLS_indMap_noGeo/*.tif"): 27 | file_id, annotator = fn.split(".")[0].split("_")[-2:] 28 | if not file_id in annotators or annotator != "1": 29 | annotators[file_id] = annotator 30 | 31 | for file_id in sorted([file_id for file_id, annotator in annotators.items() if annotator != "1"]): 32 | img_set = [] 33 | rgb = cv.imread( 34 | f"{config.dataset_path}/img_multispec_05/TLS_BDSD_RGB_noGeo/TLS_BDSD_RGB_noGeo_{file_id}.tif", 35 | cv.IMREAD_UNCHANGED 36 | ) 37 | img = cv.imread( 38 | f"{config.dataset_path}/img_multispec_05/TLS_BDSD_M/bands_to_rows/TLS_BDSD_M_{file_id}.tif", 39 | cv.IMREAD_UNCHANGED 40 | ) 41 | assert rgb.shape[0]*8 == img.shape[0] and img.shape[1] == rgb.shape[1] 42 | for i in range(8): 43 | offset = rgb.shape[0] * i 44 | img_set.append(img[offset:offset+rgb.shape[0]]) 45 | assert img_set[-1].shape == rgb.shape[:2] 46 | img_set.append(rgb) 47 | img_set.append(cv.imread( 48 | f"{config.dataset_path}/semantic_05/TLS_indMap_noGeo/TLS_indMap_noGeo_{file_id}_{annotators[file_id]}.tif", 49 | cv.IMREAD_UNCHANGED 50 | )) 51 | 52 | self.images.append(img_set) 53 | if file_id in panoptic_training_set: 54 | self.image_subsets.append(1) 55 | elif file_id in test_set: 56 | self.image_subsets.append(2) 57 | else: 58 | self.image_subsets.append(0) 59 | 60 | self.channels = { 61 | "blue2": (0, 0), 62 | "blue": (1, 0), 63 | "green": (2, 0), 64 | "yellow": (3, 0), 65 | "red": (4, 0), 66 | "red2": (5, 0), 67 | "ir": (6, 0), 68 | "ir2": (7, 0), 69 | "vis_red": (8, 2), 70 | "vis_green": (8, 1), 71 | "vis_blue": (8, 0), 72 | "gt": (9, 0) 73 | } 74 | self.num_classes = 8 75 | self.gsd = 50 76 | self.lut = lut 77 | 78 | class DatasetLoader_toulouse_full(): 79 | def __init__(self, config): 80 | self.images = [] 81 | self.image_subsets = [] 82 | 83 | annotators = {} 84 | for fn in glob.iglob(f"{config.dataset_path}/semantic_05/TLS_indMap_noGeo/*.tif"): 85 | file_id, annotator = fn.split(".")[0].split("_")[-2:] 86 | if not file_id in annotators or annotator != "1": 87 | annotators[file_id] = annotator 88 | 89 | for file_id in sorted(annotators.keys()): 90 | img_set = [] 91 | rgb = cv.imread( 92 | f"{config.dataset_path}/img_multispec_05/TLS_BDSD_RGB_noGeo/TLS_BDSD_RGB_noGeo_{file_id}.tif", 93 | cv.IMREAD_UNCHANGED 94 | ) 95 | img = cv.imread( 96 | 
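# Illustrative sketch of the "bands_to_rows" layout unpacked by the Toulouse loaders above:
# the multispectral tile stores its 8 bands stacked vertically in one 2D image, so band i
# is the row block [i*H:(i+1)*H]. The shapes below are made up but respect the assertion in
# the loader (stacked height == 8 * RGB height).
import numpy as np

H, W = 4, 5                                    # tiny stand-in for the real tile size
stacked = np.arange(8 * H * W).reshape(8 * H, W)
bands = [stacked[i * H:(i + 1) * H] for i in range(8)]
assert all(band.shape == (H, W) for band in bands)
assert np.array_equal(bands[0], stacked[:H])   # the first band is the topmost block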
f"{config.dataset_path}/img_multispec_05/TLS_BDSD_M/bands_to_rows/TLS_BDSD_M_{file_id}.tif", 97 | cv.IMREAD_UNCHANGED 98 | ) 99 | assert rgb.shape[0]*8 == img.shape[0] and img.shape[1] == rgb.shape[1] 100 | for i in range(8): 101 | offset = rgb.shape[0] * i 102 | img_set.append(img[offset:offset+rgb.shape[0]]) 103 | assert img_set[-1].shape == rgb.shape[:2] 104 | img_set.append(rgb) 105 | img_set.append(cv.imread( 106 | f"{config.dataset_path}/semantic_05/TLS_indMap_noGeo/TLS_indMap_noGeo_{file_id}_{annotators[file_id]}.tif", 107 | cv.IMREAD_UNCHANGED 108 | )) 109 | 110 | self.images.append(img_set) 111 | if file_id in panoptic_training_set: 112 | self.image_subsets.append(1) 113 | elif file_id in test_set: 114 | self.image_subsets.append(2) 115 | else: 116 | self.image_subsets.append(0) 117 | 118 | self.channels = { 119 | "blue2": (0, 0), 120 | "blue": (1, 0), 121 | "green": (2, 0), 122 | "yellow": (3, 0), 123 | "red": (4, 0), 124 | "red2": (5, 0), 125 | "ir": (6, 0), 126 | "ir2": (7, 0), 127 | "vis_red": (8, 2), 128 | "vis_green": (8, 1), 129 | "vis_blue": (8, 0), 130 | "gt": (9, 0) 131 | } 132 | self.num_classes = 8 133 | self.gsd = 50 134 | self.lut = lut 135 | -------------------------------------------------------------------------------- /datasets/SparseInstanceImage.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import rust 3 | 4 | 5 | class SparseInstanceImage(): 6 | def __init__(self, gt_rgb, inst_rgb, class_map): 7 | assert class_map.shape[0] <= 256 8 | 9 | gt = np.empty(gt_rgb.shape[:2], dtype=np.uint8) 10 | inst = np.empty(gt.shape, dtype=np.int32) 11 | num_inst = rust.rgb2ids(gt, gt_rgb, inst, inst_rgb, class_map) 12 | 13 | self.instances = np.empty((num_inst, 6), dtype=np.uint64) 14 | rust.extract_instances(self.instances, gt, inst) 15 | 16 | masks = [np.where(inst[y0:y1,x0:x1]==i+1, 1, 0) for i, (c, y0, y1, x0, x1, offset) in enumerate(self.instances)] 17 | masks = [np.asarray(mask.flatten(), dtype=np.uint8) for mask in masks] 18 | self.instances[0,-1] = 0 19 | self.instances[1:,-1] = np.cumsum([mask.shape[0] for mask in masks[:-1]]) 20 | self.masks = np.concatenate(masks) 21 | 22 | self.shape = tuple(gt.shape) 23 | self.dtype = np.dtype(np.int32) 24 | self.nbytes = self.instances.nbytes + self.masks.nbytes, gt.nbytes + inst.nbytes 25 | 26 | assert np.all(gt == self.get_semantic_image()) 27 | assert np.all(inst == self.get_instance_image()) 28 | 29 | def get_semantic_image(self): 30 | result = np.zeros(self.shape, dtype=self.dtype) 31 | for c, y0, y1, x0, x1, begin in self.instances: 32 | height, width = y1 - y0, x1 - x0 33 | end = begin + height*width 34 | mask = self.masks[begin:end].reshape(height, width) 35 | result[y0:y1,x0:x1][mask!=0] = c 36 | return result 37 | 38 | def get_instance_image(self): 39 | result = np.zeros(self.shape, dtype=self.dtype) 40 | for i, (c, y0, y1, x0, x1, begin) in enumerate(self.instances, 1): 41 | height, width = y1 - y0, x1 - x0 42 | end = begin + height*width 43 | mask = self.masks[begin:end].reshape(height, width) 44 | result[y0:y1,x0:x1][mask!=0] = i 45 | return result 46 | -------------------------------------------------------------------------------- /datasets/SynthinelDatasetLoader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import cv2 as cv 3 | import numpy as np 4 | 5 | 6 | lut = ( # for ground truth visualization 7 | ( 0, 0, 0), # other 8 | (255,255,255) # building 9 | ) 10 | 11 | 12 | class 
SynthinelLoader(): 13 | def __init__(self, config, subsets): 14 | self.images = [] 15 | for subset in subsets: 16 | for fn in glob.iglob(f"{config.dataset_path}/random_city_s_{subset}_*_RGB.tif"): 17 | img_set = [] 18 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 19 | fn = fn.split("_") 20 | fn[-1] = "GT.tif" 21 | fn = "_".join(fn) 22 | img_set.append(cv.imread(fn, cv.IMREAD_UNCHANGED)) 23 | self.images.append(img_set) 24 | 25 | self.channels = { 26 | "red": (0, 2), 27 | "green": (0, 1), 28 | "blue": (0, 0), 29 | "gt": (1, 0) 30 | } 31 | self.num_classes = 2 32 | self.gsd = 30 33 | self.lut = lut 34 | 35 | class DatasetLoader_synthinel(SynthinelLoader): 36 | def __init__(self, config): 37 | super().__init__(config, "abcghi") 38 | 39 | class DatasetLoader_synthinel_redroof(SynthinelLoader): 40 | def __init__(self, config): 41 | super().__init__(config, "a") 42 | 43 | class DatasetLoader_synthinel_paris(SynthinelLoader): 44 | def __init__(self, config): 45 | super().__init__(config, "b") 46 | 47 | class DatasetLoader_synthinel_ancient(SynthinelLoader): 48 | def __init__(self, config): 49 | super().__init__(config, "c") 50 | 51 | class DatasetLoader_synthinel_scifi(SynthinelLoader): 52 | def __init__(self, config): 53 | super().__init__(config, "d") 54 | 55 | class DatasetLoader_synthinel_palace(SynthinelLoader): 56 | def __init__(self, config): 57 | super().__init__(config, "e") 58 | 59 | class DatasetLoader_synthinel_austin(SynthinelLoader): 60 | def __init__(self, config): 61 | super().__init__(config, "g") 62 | 63 | class DatasetLoader_synthinel_venice(SynthinelLoader): 64 | def __init__(self, config): 65 | super().__init__(config, "h") 66 | 67 | class DatasetLoader_synthinel_modern(SynthinelLoader): 68 | def __init__(self, config): 69 | super().__init__(config, "i") 70 | -------------------------------------------------------------------------------- /datasets/ToyDataset.py: -------------------------------------------------------------------------------- 1 | import core 2 | import numpy as np 3 | import torch 4 | import types 5 | import rust 6 | 7 | 8 | class ToyDataset(): 9 | def __init__(self, config, params): 10 | print(f"generating toy dataset ...") 11 | 12 | generator = getattr(ToyDataset, f"{config.func.name}_generator") 13 | rng = np.random.RandomState(core.random_seeds[config.random_seed]) 14 | colors = np.asarray(( 15 | (0, 0, 0), 16 | (0, 0, 255), 17 | (0, 255, 0), 18 | (255, 0, 0), 19 | (0, 255, 255), 20 | (255, 0, 255), 21 | (255, 255, 0), 22 | (255, 255, 255) 23 | ), dtype=np.uint8) 24 | colors = np.ascontiguousarray(colors[rng.permutation(colors.shape[0])]) 25 | 26 | self.lut = getattr(ToyDataset, f"{config.func.name}_generator_lut")(colors) 27 | self.num_classes = len(self.lut) 28 | self.class_weights = torch.ones(self.num_classes, dtype=torch.float32, device=core.device) 29 | self.ignore_class = -100 30 | 31 | self.training = self.generate_subset(generator, rng, colors, config.num_samples.training, config) 32 | self.validation = self.generate_subset(generator, rng, colors, config.num_samples.validation, config) 33 | self.test = self.generate_subset(generator, rng, colors, config.num_samples.test, config) 34 | 35 | def generate_subset(self, generator, rng, colors, num_samples, config): 36 | subset = types.SimpleNamespace( 37 | x_vis = np.empty((num_samples, *config.patch_size, 3), dtype=np.uint8), 38 | y = np.empty((num_samples, *config.patch_size), dtype=np.int32) 39 | ) 40 | 41 | for i in range(num_samples): 42 | generator(subset.x_vis[i], subset.y[i], rng, 
colors, **config.func.params.__dict__) 43 | 44 | subset.x_vis = np.moveaxis(subset.x_vis, -1, 1) 45 | subset.x = (np.asarray(subset.x_vis, dtype=np.float32) / 127.5) - 1 46 | divider = self.num_classes / 2.0 47 | subset.x_gt = (np.asarray(subset.y, dtype=np.float32) / divider) - 1 48 | 49 | return subset 50 | 51 | @staticmethod 52 | def circles_generator(img, gt, rng, colors, num_circles, out_margin, min_radius, max_radius): 53 | img[:,:] = colors[0] 54 | gt[:] = 0 55 | for _ in range(num_circles): 56 | i = rng.randint(1, colors.shape[0]) 57 | 58 | center = np.asarray(( 59 | rng.randint(-out_margin, img.shape[0]+out_margin), 60 | rng.randint(-out_margin, img.shape[1]+out_margin), 61 | ), dtype=np.int32) 62 | r = min_radius + rng.rand() * (max_radius - min_radius) 63 | r = r, int(np.ceil(r)) + 1 64 | 65 | yrange = np.asarray((max(center[0]-r[1],0), min(center[0]+r[1],img.shape[0])), dtype=np.int32) 66 | xrange = np.asarray((max(center[1]-r[1],0), min(center[1]+r[1],img.shape[1])), dtype=np.int32) 67 | 68 | rust.draw_circle(img, gt, center, r[0], colors[i], i, yrange, xrange) 69 | 70 | @staticmethod 71 | def circles_generator_lut(colors): 72 | return tuple([tuple(color) for color in colors]) 73 | 74 | @staticmethod 75 | def quadtree_generator(img, gt, rng, colors, margin): 76 | i = rng.permutation(4) 77 | j = 2*i + rng.randint(2, size=4) 78 | 79 | y = rng.randint(margin, img.shape[0]-margin) 80 | x = rng.randint(margin, img.shape[1]-margin) 81 | 82 | img[:y,:x] = colors[j[0]] 83 | gt[:y,:x] = i[0] 84 | img[:y,x:] = colors[j[1]] 85 | gt[:y,x:] = i[1] 86 | img[y:,:x] = colors[j[2]] 87 | gt[y:,:x] = i[2] 88 | img[y:,x:] = colors[j[3]] 89 | gt[y:,x:] = i[3] 90 | 91 | @staticmethod 92 | def quadtree_generator_lut(colors): 93 | assert colors.shape[0] == 8 94 | return tuple([colors[2*i] for i in range(4)]) 95 | 96 | @staticmethod 97 | def kdtree_generator(img, gt, rng, colors, margin): 98 | i = rng.permutation(4) 99 | j = 2*i + rng.randint(2, size=4) 100 | y = rng.randint(margin, img.shape[0]-margin) 101 | 102 | x = rng.randint(margin, img.shape[1]-margin) 103 | img[:y,:x] = colors[j[0]] 104 | gt[:y,:x] = i[0] 105 | img[:y,x:] = colors[j[1]] 106 | gt[:y,x:] = i[1] 107 | 108 | x = rng.randint(margin, img.shape[1]-margin) 109 | img[y:,:x] = colors[j[2]] 110 | gt[y:,:x] = i[2] 111 | img[y:,x:] = colors[j[3]] 112 | gt[y:,x:] = i[3] 113 | 114 | @staticmethod 115 | def kdtree_generator_lut(colors): 116 | assert colors.shape[0] == 8 117 | return tuple([colors[2*i] for i in range(4)]) 118 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .ISPRSDatasetLoader import DatasetLoader_vaihingen, DatasetLoader_potsdam 2 | from .LGNDatasetLoader import DatasetLoader_hannover, DatasetLoader_buxtehude, DatasetLoader_nienburg 3 | from .SemcityToulouseDatasetLoader import DatasetLoader_toulouse, DatasetLoader_toulouse_full 4 | from .SynthinelDatasetLoader import DatasetLoader_synthinel, DatasetLoader_synthinel_redroof, DatasetLoader_synthinel_paris, DatasetLoader_synthinel_ancient, DatasetLoader_synthinel_scifi, DatasetLoader_synthinel_palace, DatasetLoader_synthinel_austin, DatasetLoader_synthinel_venice, DatasetLoader_synthinel_modern 5 | from .IPIDatasetLoader import DatasetLoader_hameln, DatasetLoader_schleswig, DatasetLoader_mecklenburg_vorpommern, DatasetLoader_hameln_DA, DatasetLoader_schleswig_DA 6 | from .DLRDatasetLoader import DatasetLoader_dlr_landcover, 
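# Condensed, self-contained re-statement (illustration only) of the quadtree toy samples
# produced by quadtree_generator above: one random split point divides the patch into four
# axis-aligned quadrants, each filled with its own colour and class id. Patch size, margin
# and colours are made up.
import numpy as np

rng = np.random.RandomState(0)
H = W = 64
margin = 8
colors = np.array([(0, 0, 255), (0, 255, 0), (255, 0, 0), (255, 255, 0)], dtype=np.uint8)

img = np.empty((H, W, 3), dtype=np.uint8)
gt = np.empty((H, W), dtype=np.int32)
y, x = rng.randint(margin, H - margin), rng.randint(margin, W - margin)
quadrants = [(slice(None, y), slice(None, x)), (slice(None, y), slice(x, None)),
             (slice(y, None), slice(None, x)), (slice(y, None), slice(x, None))]
for class_id, (ys, xs) in enumerate(quadrants):
    img[ys, xs] = colors[class_id]
    gt[ys, xs] = class_id
assert set(np.unique(gt)) == {0, 1, 2, 3}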
DatasetLoader_dlr_roadmaps, DatasetLoader_dlr_roadsegmentation 7 | from .ISAIDDatasetLoader import DatasetLoader_isaid 8 | from .SegmentationDataset import SegmentationDataset 9 | from .MultiSegmentationDataset import MultiSegmentationDataset 10 | from .ToyDataset import ToyDataset 11 | -------------------------------------------------------------------------------- /datasets/isaid_cache.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gritzner/SegForestNet/b7b88ab86677b1b0f2455ff3d41975bc79212831/datasets/isaid_cache.npz -------------------------------------------------------------------------------- /hpo/aethonproxy.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import signal 3 | import types 4 | import threading 5 | import time 6 | import re 7 | 8 | 9 | def monitor(shutdown): 10 | process = subprocess.Popen((python_bin, "aethon.py", "@monitor", slurm_config[0]), cwd="..") 11 | while not shutdown.wait(5): 12 | pass 13 | process.send_signal(signal.SIGINT) 14 | process.wait() 15 | 16 | 17 | class init(): 18 | def __init__(self): 19 | assert "python_bin" in globals() 20 | assert "slurm_config" in globals() 21 | self.shutdown = threading.Event() 22 | self.thread = threading.Thread(target=monitor, args=(self.shutdown,)) 23 | self.thread.start() 24 | print("waiting for SLURM monitoring to start ...") 25 | time.sleep(15) 26 | 27 | def close(self): 28 | self.shutdown.set() 29 | self.thread.join() 30 | 31 | 32 | def call(params): 33 | process = subprocess.run( 34 | (python_bin, "aethon.py", "@slurm", *slurm_config, *[str(p) for p in params]), 35 | cwd="..", text=True, capture_output=True 36 | ) 37 | if process.returncode != 0: 38 | raise RuntimeError() 39 | r = re.compile(r"job ID: (?P<job_id>\d+)") 40 | for line in process.stdout.split("\n"): 41 | m = r.match(line.strip()) 42 | if not m is None: 43 | return f"../{slurm_config[-2]}/" + m["job_id"] 44 | raise RuntimeError() 45 | -------------------------------------------------------------------------------- /hpo/smac_example.py: -------------------------------------------------------------------------------- 1 | from ConfigSpace import Configuration, ConfigurationSpace, Integer, Float 2 | from smac import HyperparameterOptimizationFacade, Scenario 3 | from pathlib import Path 4 | import glob 5 | import bz2 6 | import json 7 | import numpy as np 8 | import aethonproxy 9 | 10 | aethonproxy.python_bin = "/home/gritzner/tmp/miniconda3/envs/torch2/bin/python" 11 | aethonproxy.slurm_config = "luis", "HPO", "24:00:00", "9G", "hpo.tar.bz2", "hpo/hpo", "hpo" 12 | 13 | def train(config: Configuration, seed: int = 0) -> float: 14 | path = aethonproxy.call(( 15 | config["learning_rate"], 16 | config["momentum"], 17 | config["weight_decay"] 18 | )) 19 | history = next(glob.iglob(f"{path}/**/history.json.bz2", recursive=True)) 20 | with bz2.open(history, "r") as f: 21 | data = json.load(f) 22 | return 1 - np.max(data["val_miou"]) 23 | 24 | if __name__ == "__main__": 25 | proxy = aethonproxy.init() 26 | configspace = ConfigurationSpace() 27 | configspace.add_hyperparameters([ 28 | Float("learning_rate", bounds=(0.001, 0.1), log=True), 29 | Float("momentum", bounds=(0., .99)), 30 | Float("weight_decay", bounds=(0.00001, .00003)) 31 | ]) 32 | scenario = Scenario( 33 | configspace, 34 | name = "hpo", 35 | output_directory = Path("tmp/smac"), 36 | deterministic = True, 37 | n_trials = 20, 38 | n_workers = 5 39 | ) 40 | smac = 
HyperparameterOptimizationFacade(scenario, train) 41 | print(smac.optimize()) 42 | proxy.close() 43 | -------------------------------------------------------------------------------- /models/DeepLabv3p.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/1802.02611 2 | import core 3 | import torch 4 | import torch.nn as nn 5 | import models 6 | import types 7 | 8 | 9 | class AtrousSpatialPyramidPooling(nn.ModuleList): 10 | def __init__(self, config, input_features, feature_space_size, backbone): 11 | super().__init__() 12 | # global average pooling 13 | self.append(nn.Sequential( 14 | nn.AdaptiveAvgPool2d(1), 15 | nn.Conv2d(input_features, 256, 1), 16 | backbone.normalization(256), 17 | backbone.activation(256), 18 | nn.Upsample(feature_space_size) 19 | )) 20 | # 1x1 conv 21 | self.append(nn.Sequential( 22 | nn.Conv2d(input_features, 256, 1), 23 | backbone.normalization(256), 24 | backbone.activation(256) 25 | )) 26 | # atrous conv 27 | for d in config.aspp_dilation_rates: 28 | self.append(nn.Sequential( 29 | backbone.padding(d), 30 | nn.Conv2d(input_features, input_features, 3, dilation=d, groups=input_features), 31 | nn.Conv2d(input_features, 256, 1), 32 | backbone.normalization(256), 33 | backbone.activation(256) 34 | )) 35 | 36 | def forward(self, x): 37 | results = [] 38 | for layer in self: 39 | results.append(layer(x)) 40 | return torch.cat(results, 1) 41 | 42 | 43 | class DeepLabv3pEncoder(nn.Module): 44 | def __init__(self, config, params): 45 | super().__init__() 46 | 47 | self.backbone = models.XceptionFeatureExtractor(config, params.input_shape[0]) 48 | 49 | self.skip_connection = nn.Sequential( 50 | nn.Conv2d(self.backbone.num_features[1], 48, 1), 51 | self.backbone.normalization(48), 52 | self.backbone.activation(48), 53 | ).to(core.device) 54 | 55 | self.aspp = nn.Sequential( 56 | AtrousSpatialPyramidPooling(config, self.backbone.num_features[4], params.input_shape[1] // 2**self.backbone.downsampling, self.backbone), 57 | nn.Conv2d(256*(2+len(config.aspp_dilation_rates)), 256, 1), 58 | self.backbone.normalization(256), 59 | self.backbone.activation(256), 60 | ).to(core.device) 61 | 62 | def forward(self, x): 63 | y = self.backbone(x) 64 | return self.skip_connection(y[1]), self.aspp(y[4]) 65 | 66 | 67 | class DeepLabv3p(nn.Module): 68 | def __init__(self, config, params): 69 | super().__init__() 70 | if params.input_shape[1] != params.input_shape[2]: 71 | raise RuntimeError("input shape for DeepLabv3p must be square") 72 | 73 | self.encoder = DeepLabv3pEncoder(config, params) 74 | self.model_name = "DeepLabv3p" 75 | self.upsample = nn.Upsample(scale_factor=2**(self.encoder.backbone.downsampling-2), mode="bilinear", align_corners=False) 76 | 77 | self.final = nn.Sequential( 78 | self.encoder.backbone.padding(1), 79 | nn.Conv2d(304, 304, 3, groups=304), 80 | nn.Conv2d(304, 256, 1), 81 | self.encoder.backbone.normalization(256), 82 | self.encoder.backbone.activation(256), 83 | self.encoder.backbone.padding(1), 84 | nn.Conv2d(256, 256, 3, groups=256), 85 | nn.Conv2d(256, 256, 1), 86 | self.encoder.backbone.normalization(256), 87 | self.encoder.backbone.activation(256), 88 | nn.Conv2d(256, params.num_classes, 1), 89 | nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False), 90 | ).to(core.device) 91 | 92 | def forward(self, x, yt, ce_loss_func, weight, ignore_index): 93 | yp = self.encoder(x) 94 | yp = torch.cat((yp[0], self.upsample(yp[1])), 1) 95 | yp = self.final(yp) 96 | return yp, ce_loss_func(yp, yt, 
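# Worked channel arithmetic, illustration only, for the DeepLabv3+ head above. The three
# dilation rates are an assumption; the repository reads them from config.aspp_dilation_rates.
aspp_dilation_rates = (6, 12, 18)

aspp_branches = 2 + len(aspp_dilation_rates)   # global pooling + 1x1 conv + atrous branches
aspp_concat_channels = 256 * aspp_branches     # input to the 1x1 projection after concatenation
assert aspp_concat_channels == 1280

decoder_in_channels = 48 + 256                 # low-level skip (48) + upsampled ASPP output (256)
assert decoder_in_channels == 304              # matches the 304-channel convolutions in `final`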
weight=weight, ignore_index=ignore_index) 97 | -------------------------------------------------------------------------------- /models/FCN.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/1411.4038 2 | import core 3 | import torch.nn as nn 4 | import models 5 | 6 | 7 | def FCNEncoder(config, params): 8 | encoder = models.XceptionFeatureExtractor(config, params.input_shape[0]) 9 | assert encoder.downsampling == 5 10 | return encoder.to(core.device) 11 | 12 | 13 | class FCN(nn.Module): 14 | def __init__(self, config, params): 15 | super().__init__() 16 | 17 | if params.input_shape[1] != params.input_shape[2]: 18 | raise RuntimeError("input shape for FCN must be square") 19 | 20 | self.encoder = FCNEncoder(config, params) 21 | self.model_name = "FCN" 22 | 23 | self.score32 = nn.Conv2d(self.encoder.num_features[-1], params.num_classes, 1).to(core.device) 24 | self.upsample32 = nn.ConvTranspose2d(params.num_classes, params.num_classes, 4, stride=2, bias=False).to(core.device) 25 | self.score16 = nn.Conv2d(self.encoder.num_features[-2], params.num_classes, 1).to(core.device) 26 | self.upsample16 = nn.ConvTranspose2d(params.num_classes, params.num_classes, 4, stride=2, bias=False).to(core.device) 27 | self.score8 = nn.Conv2d(self.encoder.num_features[-3], params.num_classes, 1).to(core.device) 28 | self.upsample8 = nn.ConvTranspose2d(params.num_classes, params.num_classes, 16, stride=8, bias=False).to(core.device) 29 | 30 | def forward(self, x, yt, ce_loss_func, weight, ignore_index): 31 | y = self.encoder(x) 32 | y0 = self.upsample32(self.score32(y[-1])) 33 | y1 = self.upsample16(self.score16(y[-2]) + y0[:,:,1:-1,1:-1]) 34 | y2 = self.upsample8(self.score8(y[-3]) + y1[:,:,1:-1,1:-1]) 35 | yp = y2[:,:,4:-4,4:-4] 36 | return yp, ce_loss_func(yp, yt, weight=weight, ignore_index=ignore_index) 37 | -------------------------------------------------------------------------------- /models/FarSeg.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/2011.09766 2 | # https://openaccess.thecvf.com/content_CVPR_2020/papers/Zheng_Foreground-Aware_Relation_Network_for_Geospatial_Object_Segmentation_in_High_Spatial_CVPR_2020_paper.pdf 3 | import core 4 | import torch 5 | import torch.nn as nn 6 | import functools 7 | from .FCN import FCNEncoder 8 | 9 | 10 | class FarSeg(nn.Module): 11 | def __init__(self, config, params): 12 | super().__init__() 13 | 14 | if params.input_shape[1] != params.input_shape[2]: 15 | raise RuntimeError("input shape for FarSeg must be square") 16 | config.num_features = getattr(config, "num_features", 256) 17 | 18 | self.encoder = FCNEncoder(config, params) 19 | self.model_name = "FarSeg" 20 | 21 | for i, input_features in enumerate(self.encoder.num_features[-4:]): 22 | setattr(self, f"p{i}a", nn.Conv2d(input_features, config.num_features, 1).to(core.device)) 23 | setattr( 24 | self, 25 | f"p{i}b", 26 | nn.Sequential( 27 | self.encoder.padding(1), 28 | nn.Conv2d(config.num_features, config.num_features, 3), 29 | ).to(core.device) 30 | ) 31 | setattr( 32 | self, 33 | f"p{i}c", 34 | nn.Sequential( 35 | nn.Conv2d(config.num_features, config.num_features, 1), 36 | self.encoder.normalization(config.num_features), 37 | self.encoder.activation(config.num_features) 38 | ).to(core.device) 39 | ) 40 | setattr( 41 | self, 42 | f"p{i}d", 43 | nn.Sequential( 44 | nn.Conv2d(config.num_features, config.num_features, 1), 45 | self.encoder.normalization(config.num_features), 46 | 
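# Worked size arithmetic, illustration only, for the FCN decoder above: each
# ConvTranspose2d(kernel_size=4, stride=2) maps N -> 2N + 2, the [1:-1] crop removes the two
# extra border pixels before fusion, and the final ConvTranspose2d(kernel_size=16, stride=8)
# followed by the [4:-4] crop restores the input resolution. The 256 px input is an assumption.
def conv_transpose_size(n, kernel, stride):
    return (n - 1) * stride + kernel

H = 256                                     # input resolution
h32 = H // 32                               # deepest feature map (downsampling == 5)
h16 = conv_transpose_size(h32, 4, 2) - 2    # upsample32 output after the [1:-1] crop
assert h16 == H // 16
h8 = conv_transpose_size(h16, 4, 2) - 2     # upsample16 output after the [1:-1] crop
assert h8 == H // 8
out = conv_transpose_size(h8, 16, 8) - 8    # upsample8 output after the [4:-4] crop
assert out == H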
self.encoder.activation(config.num_features) 47 | ).to(core.device) 48 | ) 49 | decoder = [ 50 | self.encoder.padding(1), 51 | nn.Conv2d(config.num_features, config.num_features, 3), 52 | self.encoder.normalization(config.num_features), 53 | self.encoder.activation(config.num_features) 54 | ] 55 | for j in range(i): 56 | decoder.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)) 57 | if j+1 < i: 58 | decoder.extend([ 59 | self.encoder.padding(1), 60 | nn.Conv2d(config.num_features, config.num_features, 3), 61 | self.encoder.normalization(config.num_features), 62 | self.encoder.activation(config.num_features) 63 | ]) 64 | setattr(self, f"p{i}e", nn.Sequential(*decoder).to(core.device)) 65 | 66 | self.scene_embedding = nn.Sequential( 67 | nn.AdaptiveAvgPool2d(1), 68 | nn.Conv2d(self.encoder.num_features[-1], config.num_features, 1) 69 | ).to(core.device) 70 | 71 | self.classifier = nn.Conv2d(config.num_features, params.num_classes, 1).to(core.device) 72 | 73 | def forward(self, x, yt, ce_loss_func, weight, ignore_index): 74 | ys = self.encoder(x)[-4:] 75 | scene_embedding = self.scene_embedding(ys[-1]) 76 | 77 | ys = [getattr(self,f"p{i}a")(y) for i, y in enumerate(ys)] 78 | ys = [y+nn.functional.interpolate(ys[i+1],scale_factor=2,mode="nearest") if i<3 else y for i, y in enumerate(ys)] 79 | ys = [getattr(self,f"p{i}b")(y) for i, y in enumerate(ys)] 80 | 81 | r = [getattr(self,f"p{i}c")(y)*scene_embedding for i, y in enumerate(ys)] 82 | r = [torch.sigmoid(r.sum(1).unsqueeze(1)) for r in r] 83 | ys = [getattr(self,f"p{i}d")(y)*r[i] for i, y in enumerate(ys)] 84 | 85 | ys = [getattr(self,f"p{i}e")(y) for i, y in enumerate(ys)] 86 | y = functools.reduce(lambda a,b: a+b, ys) 87 | 88 | y = self.classifier(y) 89 | y = nn.functional.interpolate(y, scale_factor=4, mode="bilinear", align_corners=False) 90 | 91 | return y, ce_loss_func(y, yt, weight=weight, ignore_index=ignore_index) 92 | -------------------------------------------------------------------------------- /models/RAFCN.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/1904.05730 2 | # https://openaccess.thecvf.com/content_CVPR_2019/papers/Mou_A_Relation-Augmented_Fully_Convolutional_Network_for_Semantic_Segmentation_in_Aerial_CVPR_2019_paper.pdf 3 | import core 4 | import torch 5 | import torch.nn as nn 6 | from .FCN import FCNEncoder 7 | 8 | 9 | class RelationModule(nn.Module): 10 | def __init__(self, in_features, out_features, hw, upsampling_factor): 11 | super().__init__() 12 | 13 | self.channel_u = nn.Conv2d(in_features, in_features, 1) 14 | self.channel_v = nn.Conv2d(in_features, in_features, 1) 15 | 16 | self.spatial_u = nn.Conv2d(in_features, in_features, 1) 17 | self.spatial_v = nn.Conv2d(in_features, in_features, 1) 18 | 19 | self.score = nn.Conv2d(in_features+hw, out_features, 1) 20 | self.upsample = nn.ConvTranspose2d(out_features, out_features, 2*upsampling_factor, stride=upsampling_factor, bias=False) 21 | 22 | def forward(self, x, y=None): 23 | 24 | # channel relations 25 | c = nn.functional.adaptive_avg_pool2d(x, 1) 26 | u = self.channel_u(c).squeeze().unsqueeze(2) 27 | v = self.channel_v(c).squeeze().unsqueeze(1) 28 | c = torch.bmm(u, v) 29 | c = nn.functional.softmax(c, dim=1) 30 | 31 | z = x.reshape(x.shape[0], x.shape[1], -1) 32 | x = torch.bmm(c, z).reshape(*x.shape) 33 | 34 | # spatial relations 35 | u = self.channel_u(x).reshape(x.shape[0], x.shape[1], -1).permute(0, 2, 1) 36 | v = self.channel_v(x).reshape(x.shape[0], x.shape[1], 
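# Illustrative sketch, not from the repository, of the foreground-relation gating used by
# FarSeg above: a pyramid level is multiplied channel-wise with the global scene embedding,
# summed over channels and squashed into a single-channel attention map that rescales the
# level's features (the separate 1x1 projections p{i}c / p{i}d are omitted here). Shapes are
# made up.
import torch

B, C, H, W = 2, 8, 4, 4
level = torch.randn(B, C, H, W)           # one projected pyramid feature map
scene = torch.randn(B, C, 1, 1)           # global scene embedding

relation = torch.sigmoid((level * scene).sum(1, keepdim=True))   # (B, 1, H, W)
gated = level * relation                  # features rescaled by the relation map
assert gated.shape == (B, C, H, W)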
-1) 37 | z = torch.bmm(u, v).reshape(x.shape[0], -1, x.shape[2], x.shape[3]) 38 | x = torch.cat((x, z), dim=1) 39 | 40 | # FCN decoder 41 | return self.upsample(self.score(x) + y[:,:,1:-1,1:-1]) if isinstance(y, torch.Tensor) else self.upsample(self.score(x)) 42 | 43 | 44 | class RAFCN(nn.Module): 45 | def __init__(self, config, params): 46 | super().__init__() 47 | 48 | if params.input_shape[1] != params.input_shape[2]: 49 | raise RuntimeError("input shape for RAFCN must be square") 50 | 51 | self.encoder = FCNEncoder(config, params) 52 | self.model_name = "RAFCN" 53 | 54 | hw = (params.input_shape[1]//32)**2 55 | self.output32 = RelationModule(self.encoder.num_features[-1], params.num_classes, hw, 2).to(core.device) 56 | hw = (params.input_shape[1]//16)**2 57 | self.output16 = RelationModule(self.encoder.num_features[-2], params.num_classes, hw, 2).to(core.device) 58 | hw = (params.input_shape[1]//8)**2 59 | self.output8 = RelationModule(self.encoder.num_features[-3], params.num_classes, hw, 8).to(core.device) 60 | 61 | def forward(self, x, yt, ce_loss_func, weight, ignore_index): 62 | y = self.encoder(x) 63 | y = self.output8( 64 | y[-3], 65 | self.output16( 66 | y[-2], 67 | self.output32(y[-1]) 68 | ) 69 | ) 70 | yp = y[:,:,4:-4,4:-4] 71 | return yp, ce_loss_func(yp, yt, weight=weight, ignore_index=ignore_index) 72 | -------------------------------------------------------------------------------- /models/SegForestComponents.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.functional import relu 3 | 4 | 5 | class Line(): 6 | num_params = 3 7 | 8 | @staticmethod 9 | def compute(x, sample_points): 10 | return (x[:,:2] * sample_points).sum(1) - x[:,-1] 11 | 12 | class Square(): 13 | num_params = 3 14 | 15 | @staticmethod 16 | def compute(x, sample_points): 17 | t = x[:,:2] - sample_points 18 | return t.abs().max(1)[0] - x[:,-1] 19 | 20 | class Circle(): 21 | num_params = 3 22 | 23 | @staticmethod 24 | def compute(x, sample_points): 25 | t = x[:,:2] - sample_points 26 | return (t**2).sum(1) - x[:,-1] 27 | 28 | class Ellipse(): 29 | num_params = 5 30 | 31 | @staticmethod 32 | def compute(x, sample_points): 33 | d0 = ((x[:,:2] - sample_points)**2).sum(1).sqrt() 34 | d1 = ((x[:,2:4] - sample_points)**2).sum(1).sqrt() 35 | return d0 + d1 - x[:,-1] 36 | 37 | class Hyperbola(): 38 | num_params = 5 39 | 40 | @staticmethod 41 | def compute(x, sample_points): 42 | d0 = ((x[:,:2] - sample_points)**2).sum(1).sqrt() 43 | d1 = ((x[:,2:4] - sample_points)**2).sum(1).sqrt() 44 | return (d0 - d1).abs() - x[:,-1] 45 | 46 | class Parabola(): 47 | num_params = 5 48 | 49 | @staticmethod 50 | def compute(x, sample_points): 51 | d0 = ((x[:,:2] - sample_points)**2).sum(1).sqrt() 52 | d1 = Line.compute(x[:,2:], sample_points) 53 | return d0 - d1 54 | 55 | 56 | class Leaf(): 57 | def __init__(self, index): 58 | self.index = index 59 | 60 | class LeafNode(): 61 | def __init__(self): 62 | self.num_params = 0 63 | self.children = [Leaf] 64 | 65 | def apply_add(self, region_map, x, sample_points, region_map_lambda): 66 | region_map[:,self.indices[0]] += region_map_lambda 67 | 68 | def apply_mul(self, region_map, x, sample_points, region_map_lambda): 69 | region_map[:,self.indices[0]] *= region_map_lambda 70 | 71 | def apply_old(self, region_map, x, sample_points, region_map_lambda): 72 | region_map[:,self.indices[0]] *= region_map_lambda 73 | 74 | 75 | class BSPNode(): 76 | def __init__(self, left, right, sdf): 77 | self.sdf = sdf 78 | 
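# Illustrative sketch, not from the repository, of how the signed-distance primitives above
# are evaluated: parameters and sample points are channel-stacked maps, so Line.compute
# returns n . p - d per pixel and its sign selects one of the two half-spaces. Grid size and
# parameter values are made up.
import torch

B, H, W = 1, 4, 4
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing="ij")
sample_points = torch.stack((xs, ys)).unsqueeze(0)    # (B, 2, H, W)

params = torch.zeros(B, 3, H, W)
params[:, 0] = 1.0                                    # line normal n = (1, 0)
params[:, 2] = 0.0                                    # offset d = 0, i.e. split at x == 0

signed = (params[:, :2] * sample_points).sum(1) - params[:, 2]   # same formula as Line.compute
left, right = signed < 0, signed > 0
assert left.sum() + right.sum() == B * H * W          # no sample lies exactly on this line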
self.num_params = sdf.num_params 79 | self.children = [left, right] 80 | 81 | def apply_add(self, region_map, x, sample_points, region_map_lambda): 82 | t = self.sdf.compute(x, sample_points) 83 | t = (region_map_lambda * t).unsqueeze(1) 84 | region_map[:,self.indices[0]] += relu(t) 85 | region_map[:,self.indices[1]] += relu(-t) 86 | 87 | def apply_mul(self, region_map, x, sample_points, region_map_lambda): 88 | t = self.sdf.compute(x, sample_points) 89 | t[t.sign()==0] = 10**-6 90 | t = t.unsqueeze(1) 91 | region_map[:,self.indices[0]] *= relu(t) 92 | region_map[:,self.indices[1]] *= relu(-t) 93 | 94 | def apply_old(self, region_map, x, sample_points, region_map_lambda): 95 | t = self.sdf.compute(x, sample_points) 96 | t[t.sign()==0] = 10**-6 97 | t = torch.sigmoid(region_map_lambda[1] * t).unsqueeze(1) 98 | region_map[:,self.indices[0]] *= region_map_lambda[0] * t 99 | region_map[:,self.indices[1]] *= region_map_lambda[0] * (1-t) 100 | 101 | def BSPTree(depth, sdf): 102 | return Leaf if depth==0 else BSPNode(BSPTree(depth-1,sdf), BSPTree(depth-1,sdf), sdf) 103 | 104 | 105 | class QuadtreeNode(): 106 | def __init__(self, a, b, c, d): 107 | self.num_params = 2 108 | self.children = [a, b, c, d] 109 | 110 | def apply_add(self, region_map, x, sample_points, region_map_lambda): 111 | t = x - sample_points 112 | t0 = relu(t[:,0]).unsqueeze(1) 113 | t1 = relu(-t[:,0]).unsqueeze(1) 114 | t2 = relu(t[:,1]).unsqueeze(1) 115 | t3 = relu(-t[:,1]).unsqueeze(1) 116 | region_map[:,self.indices[0]] += region_map_lambda * t0 * t2 117 | region_map[:,self.indices[1]] += region_map_lambda * t0 * t3 118 | region_map[:,self.indices[2]] += region_map_lambda * t1 * t2 119 | region_map[:,self.indices[3]] += region_map_lambda * t1 * t3 120 | 121 | def apply_mul(self, region_map, x, sample_points, region_map_lambda): 122 | t = x - sample_points 123 | t[t.sign()==0] = 10**-3 124 | t0 = relu(t[:,0]).unsqueeze(1) 125 | t1 = relu(-t[:,0]).unsqueeze(1) 126 | t2 = relu(t[:,1]).unsqueeze(1) 127 | t3 = relu(-t[:,1]).unsqueeze(1) 128 | region_map[:,self.indices[0]] *= t0 * t2 129 | region_map[:,self.indices[1]] *= t0 * t3 130 | region_map[:,self.indices[2]] *= t1 * t2 131 | region_map[:,self.indices[3]] *= t1 * t3 132 | 133 | def apply_old(self, region_map, x, sample_points, region_map_lambda): 134 | t = x - sample_points 135 | t[t.sign()==0] = 10**-3 136 | t0 = torch.sigmoid(region_map_lambda[1] * t[:,0]).unsqueeze(1) 137 | t1 = torch.sigmoid(region_map_lambda[1] * t[:,1]).unsqueeze(1) 138 | region_map[:,self.indices[0]] *= region_map_lambda[0] * t0 * t1 139 | region_map[:,self.indices[1]] *= region_map_lambda[0] * t0 * (1-t1) 140 | region_map[:,self.indices[2]] *= region_map_lambda[0] * (1-t0) * t1 141 | region_map[:,self.indices[3]] *= region_map_lambda[0] * (1-t0) * (1-t1) 142 | 143 | def Quadtree(depth): 144 | return Leaf if depth==0 else QuadtreeNode(Quadtree(depth-1), Quadtree(depth-1), Quadtree(depth-1), Quadtree(depth-1)) 145 | 146 | 147 | class KDTreeNode(): 148 | def __init__(self, left, right, dim): 149 | self.dim = dim 150 | self.num_params = 1 151 | self.children = [left, right] 152 | 153 | def apply_add(self, region_map, x, sample_points, region_map_lambda): 154 | t = x - sample_points[:,self.dim].unsqueeze(1) 155 | t *= region_map_lambda 156 | region_map[:,self.indices[0]] += relu(t) 157 | region_map[:,self.indices[1]] += relu(-t) 158 | 159 | def apply_mul(self, region_map, x, sample_points, region_map_lambda): 160 | t = x - sample_points[:,self.dim].unsqueeze(1) 161 | t[t.sign()==0] = 10**-6 162 | 
region_map[:,self.indices[0]] *= relu(t) 163 | region_map[:,self.indices[1]] *= relu(-t) 164 | 165 | def apply_old(self, region_map, x, sample_points, region_map_lambda): 166 | t = x - sample_points[:,self.dim].unsqueeze(1) 167 | t[t.sign()==0] = 10**-6 168 | t = torch.sigmoid(region_map_lambda[1] * t) 169 | region_map[:,self.indices[0]] *= region_map_lambda[0] * t 170 | region_map[:,self.indices[1]] *= region_map_lambda[0] * (1-t) 171 | 172 | def KDTree(depth, root_dim): 173 | child_dim = (root_dim + 1) % 2 174 | return Leaf if depth==0 else KDTreeNode(KDTree(depth-1,child_dim), KDTree(depth-1,child_dim), root_dim) 175 | 176 | 177 | class DynKDTreeNode(): 178 | def __init__(self, left, right): 179 | self.num_params = 3 180 | self.children = [left, right] 181 | 182 | def apply_add(self, region_map, x, sample_points, region_map_lambda): 183 | t = x[:,2] - torch.where(x[:,:2].argmax(1)==0, sample_points[:,0], sample_points[:,1]) 184 | t = (region_map_lambda * t).unsqueeze(1) 185 | region_map[:,self.indices[0]] += relu(t) 186 | region_map[:,self.indices[1]] += relu(-t) 187 | 188 | def apply_mul(self, region_map, x, sample_points, region_map_lambda): 189 | t = x[:,2] - torch.where(x[:,:2].argmax(1)==0, sample_points[:,0], sample_points[:,1]) 190 | t[t.sign()==0] = 10**-6 191 | t = t.unsqueeze(1) 192 | region_map[:,self.indices[0]] *= relu(t) 193 | region_map[:,self.indices[1]] *= relu(-t) 194 | 195 | def apply_old(self, region_map, x, sample_points, region_map_lambda): 196 | t = x[:,2] - torch.where(x[:,:2].argmax(1)==0, sample_points[:,0], sample_points[:,1]) 197 | t[t.sign()==0] = 10**-6 198 | t = torch.sigmoid(region_map_lambda[1] * t).unsqueeze(1) 199 | region_map[:,self.indices[0]] *= region_map_lambda[0] * t 200 | region_map[:,self.indices[1]] *= region_map_lambda[0] * (1-t) 201 | 202 | def DynKDTree(depth): 203 | return Leaf if depth==0 else DynKDTreeNode(DynKDTree(depth-1), DynKDTree(depth-1)) 204 | -------------------------------------------------------------------------------- /models/SegForestTree.py: -------------------------------------------------------------------------------- 1 | import core 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.functional import relu, leaky_relu 6 | import types 7 | import functools 8 | from .SegForestComponents import * 9 | from .SegForestTreeDecoder import TreeFeatureDecoder, BSPSegNetDecoder 10 | 11 | 12 | class PartitioningTree(): 13 | def __init__(self, config, params, encoder, tree): 14 | self.config = tree 15 | self.downsampling_factor = encoder.downsampling_factor 16 | self.region_map_rendering = types.SimpleNamespace( 17 | func = getattr(PartitioningTree, f"render_region_map_{config.region_map.accumulation}"), 18 | node_weight = config.region_map.node_weight 19 | ) 20 | if config.region_map.accumulation == "mul2": 21 | self.region_map_rendering.node_weight[0] = getattr( 22 | PartitioningTree, 23 | f"distance_transform_{self.region_map_rendering.node_weight[0]}" 24 | ) 25 | if self.region_map_rendering.node_weight[3]: 26 | self.region_map_rendering_params = nn.Parameter(torch.as_tensor( 27 | self.region_map_rendering.node_weight[1], 28 | dtype=torch.float32, device=core.device 29 | )) 30 | self.region_map_rendering.node_weight[1] = self.region_map_rendering_params 31 | self.per_region_outputs = tree.outputs.shape[0] if len(tree.classifier) == 0 else tree.classifier[0] 32 | 33 | self.inner_nodes = [] 34 | num_params = 0 35 | self.num_leaf_nodes = 0 36 | 37 | queue = [eval(tree.graph)] 38 | while len(queue) > 
0: 39 | node = queue.pop(0) 40 | assert config.region_map.accumulation != "mul2" or type(node) == BSPNode 41 | self.inner_nodes.append(node) 42 | num_params += node.num_params 43 | for i, child in enumerate(node.children): 44 | if child == Leaf: 45 | node.children[i] = Leaf(self.num_leaf_nodes) 46 | self.num_leaf_nodes += 1 47 | else: 48 | queue.append(child) 49 | self.region_map_rendering.base_shape = (self.num_leaf_nodes, *params.input_shape[1:]) 50 | 51 | num_params = max(num_params, 1) 52 | self.num_tree_parameters = num_params + self.num_leaf_nodes * tree.outputs.shape[0] 53 | 54 | for node in reversed(self.inner_nodes): 55 | node.indices = [] 56 | for child in node.children: 57 | if type(child) == Leaf: 58 | node.indices.append(np.asarray([child.index], dtype=np.int32)) 59 | else: 60 | node.indices.append(np.concatenate(child.indices)) 61 | node.children = tuple(node.children) 62 | node.indices = tuple(node.indices) 63 | self.inner_nodes = tuple(self.inner_nodes) 64 | 65 | decoder_factory = globals()[getattr(config.decoder, "type", "TreeFeatureDecoder")] 66 | self.decoder0 = decoder_factory( 67 | config, encoder, tree.actual_num_features.shape, num_params, is_shape_decoder = True 68 | ) 69 | assert tree.shape_to_content in (0, 1, 2) 70 | num_params = (0, tree.actual_num_features.shape, num_params)[tree.shape_to_content] 71 | self.decoder1 = decoder_factory( 72 | config, encoder, tree.actual_num_features.content + num_params, 73 | self.num_leaf_nodes * self.per_region_outputs, tree.actual_num_features.content 74 | ) 75 | 76 | if len(tree.classifier) > 0: 77 | num_features = tree.classifier.copy() 78 | if tree.classifier_skip_from == 0: 79 | num_features[0] += params.input_shape[0] 80 | elif tree.classifier_skip_from > 0: 81 | num_features[0] += encoder.num_features[tree.classifier_skip_from-1] 82 | self.classifier = [] 83 | for in_features, out_features in zip(num_features[:-1], num_features[1:]): 84 | if tree.classifier_context > 0: 85 | self.classifier.append(encoder.padding(tree.classifier_context)) 86 | self.classifier.extend([ 87 | nn.Conv2d(in_features, out_features, 1+2*tree.classifier_context), 88 | encoder.normalization(out_features), 89 | encoder.activation(out_features) 90 | ]) 91 | if tree.classifier_context > 0: 92 | self.classifier.append(encoder.padding(tree.classifier_context)) 93 | self.classifier.append(nn.Conv2d(num_features[-1], tree.outputs.shape[0], 1+2*tree.classifier_context)) 94 | self.classifier = nn.Sequential(*self.classifier).to(core.device) 95 | 96 | def render(self, x, z, y, sample_points): 97 | # seperate features into shape and content and decode them seperately 98 | shape, content, z = ( 99 | z[:,:self.config.num_features.shape], 100 | z[:,-self.config.num_features.content:], 101 | z[:,self.config.num_features.shape:-self.config.num_features.content] 102 | ) 103 | if self.config.shape_to_content > 0: 104 | self.tree_parameters = (self.decoder0(shape), None) # delay decoding 105 | if self.config.shape_to_content == 1: 106 | self.tree_parameters = ( 107 | self.tree_parameters[0], 108 | self.decoder1(torch.cat((self.decoder0[0][0](shape).detach(), content), dim=1)) 109 | ) 110 | elif self.config.shape_to_content == 2: 111 | self.tree_parameters = ( 112 | self.tree_parameters[0], 113 | self.decoder1(torch.cat((self.tree_parameters[0].detach(), content), dim=1)) 114 | ) 115 | else: 116 | self.tree_parameters = (self.decoder0(shape), self.decoder1(content)) 117 | 118 | #d_shape = shape - shape.mean(dim=(0, 2, 3))[None,:,None,None] 119 | #d_content = content 
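# Small illustration, not from the repository, of the tree statistics gathered by the
# breadth-first walk above: a full BSP tree of depth d has 2**d - 1 inner nodes and 2**d
# leaves, so with the 3-parameter Line primitive the shape decoder has to predict
# 3 * (2**d - 1) values per feature-map cell.
def bsp_tree_counts(depth, params_per_node=3):
    inner_nodes = 2 ** depth - 1
    leaves = 2 ** depth
    return inner_nodes, leaves, inner_nodes * params_per_node

assert bsp_tree_counts(2) == (3, 4, 9)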
- content.mean(dim=(0, 2, 3))[None,:,None,None] 120 | #self.cov = d_shape[:,:,None,:,:] * d_content[:,None,:,:,:] 121 | #self.cov = self.cov.mean(dim=(0, 3, 4)) 122 | 123 | # process shape features 124 | shape = self.tree_parameters[0].repeat_interleave(self.downsampling_factor, 3) 125 | shape = shape.repeat_interleave(self.downsampling_factor, 2) 126 | 127 | # render partitioning tree 128 | self.region_map = self.region_map_rendering.func( 129 | (x[0].shape[0], *self.region_map_rendering.base_shape), self.region_map_rendering.node_weight, 130 | self.region_map_rendering.softmax_temperature, shape, self.inner_nodes, sample_points, 131 | ) 132 | 133 | # process content features 134 | content = self.tree_parameters[1].reshape(z.shape[0], self.num_leaf_nodes, self.per_region_outputs, *z.shape[2:]) 135 | content = content.repeat_interleave(self.downsampling_factor, 4).repeat_interleave(self.downsampling_factor, 3) 136 | for i in range(self.num_leaf_nodes): # update prediction 137 | if hasattr(self, "classifier"): 138 | if self.config.classifier_skip_from >= 0: 139 | logits = self.classifier(torch.cat((x[self.config.classifier_skip_from], content[:,i]), dim=1)) 140 | else: 141 | logits = self.classifier(content[:,i]) 142 | else: 143 | logits = content[:,i] 144 | y[:,self.config.outputs] += self.region_map[:,i].unsqueeze(1) * logits 145 | 146 | return z 147 | 148 | @staticmethod 149 | def render_region_map_add(region_map_shape, node_weight, softmax_temperature, shape, nodes, sample_points): 150 | region_map = torch.zeros(region_map_shape, dtype=torch.float32, device=sample_points.device) 151 | for node in nodes: 152 | node_params, shape = shape[:,:node.num_params], shape[:,node.num_params:] 153 | node.apply_add(region_map, node_params, sample_points, node_weight) 154 | region_map = region_map / softmax_temperature 155 | return region_map.softmax(1) 156 | 157 | @staticmethod 158 | def render_region_map_mul(region_map_shape, node_weight, softmax_temperature, shape, nodes, sample_points): 159 | region_map = torch.ones(region_map_shape, dtype=torch.float32, device=sample_points.device) 160 | for node in nodes: 161 | node_params, shape = shape[:,:node.num_params], shape[:,node.num_params:] 162 | node.apply_mul(region_map, node_params, sample_points, node_weight) 163 | return region_map / region_map.sum(1, keepdim=True) 164 | 165 | @staticmethod 166 | def render_region_map_mul2(region_map_shape, node_weight, softmax_temperature, shape, nodes, sample_points): 167 | region_map = torch.empty(region_map_shape, dtype=torch.float32, device=sample_points.device) 168 | distance_maps = torch.empty((region_map_shape[0], 2, len(nodes), *region_map_shape[2:]), dtype=torch.float32, device=sample_points.device) 169 | 170 | transform_func, transform_weights, leaky_slope = node_weight[:3] 171 | for i, node in enumerate(nodes): 172 | node_params, shape = shape[:,:node.num_params], shape[:,node.num_params:] 173 | distance_maps[:,0,i] = node.sdf.compute(node_params, sample_points) 174 | distance_maps[:,1] = -distance_maps[:,0] 175 | distance_maps = transform_func( 176 | transform_weights[0] * distance_maps + transform_weights[1], 177 | transform_weights[2], leaky_slope 178 | ) 179 | 180 | for i in range(region_map_shape[1]): 181 | distances = [distance_maps[:,0,j] for j,node in enumerate(nodes) if i in node.indices[0]] 182 | distances.extend([distance_maps[:,1,j] for j,node in enumerate(nodes) if i in node.indices[1]]) 183 | region_map[:,i] = functools.reduce(lambda x,y: x*y, distances) 184 | 185 | region_map = 
region_map / softmax_temperature 186 | return region_map.softmax(1) 187 | 188 | @staticmethod 189 | def distance_transform_relu(distance_maps, weight, leaky_slope): 190 | return weight * relu(distance_maps) 191 | 192 | @staticmethod 193 | def distance_transform_leaky_relu(distance_maps, weight, leaky_slope): 194 | return weight * leaky_relu(distance_maps, negative_slope=leaky_slope) 195 | 196 | @staticmethod 197 | def distance_transform_sigmoid(distance_maps, weight, leaky_slope): 198 | return weight * torch.sigmoid(distance_maps) 199 | 200 | @staticmethod 201 | def distance_transform_clamp(distance_maps, weight, leaky_slope): 202 | return weight * torch.clamp(distance_maps, 0, 1) 203 | 204 | @staticmethod 205 | def distance_transform_leaky_clamp(distance_maps, weight, leaky_slope): 206 | distance_maps[distance_maps<0] *= leaky_slope 207 | i = distance_maps>1 208 | distance_maps[i] = 1 + (distance_maps[i] - 1) * leaky_slope 209 | return weight * distance_maps 210 | 211 | @staticmethod 212 | def distance_transform_smoothstep(distance_maps, weight, leaky_slope): 213 | distance_maps = torch.clamp(distance_maps, 0, 1) 214 | distance_maps = 3 * distance_maps**2 - 2 * distance_maps**3 215 | return weight * distance_maps 216 | 217 | @staticmethod 218 | def distance_transform_leaky_smoothstep(distance_maps, weight, leaky_slope): 219 | distance_maps[distance_maps<0] *= leaky_slope 220 | i = distance_maps>1 221 | distance_maps[i] = 1 + (distance_maps[i] - 1) * leaky_slope 222 | distance_maps = 3 * distance_maps**2 - 2 * distance_maps**3 223 | return weight * distance_maps 224 | 225 | @staticmethod 226 | def render_region_map_old(region_map_shape, node_weight, softmax_temperature, shape, nodes, sample_points): 227 | region_map = torch.ones(region_map_shape, dtype=torch.float32, device=sample_points.device) 228 | for node in nodes: 229 | node_params, shape = shape[:,:node.num_params], shape[:,node.num_params:] 230 | node.apply_old(region_map, node_params, sample_points, node_weight) 231 | region_map = region_map / softmax_temperature 232 | return region_map.softmax(1) 233 | 234 | def get_region_distributions(self, yt, ignore_class=-1, instances=False): 235 | assert ignore_class >= 0 or instances 236 | 237 | if yt.shape[-1] != self.config.outputs.shape[0] and not instances: 238 | temp = torch.zeros([*yt.shape[:-1], self.config.outputs.shape[0]+1], dtype=yt.dtype, device=core.device) 239 | temp[:,:,:,:-1] = yt[:,:,:,self.config.outputs] 240 | yt = temp 241 | i = torch.nonzero(yt.sum(-1)==0, as_tuple=True) 242 | yt[i[0],i[1],i[2],-1] = 1 243 | 244 | downsampled_shape = ( 245 | yt.shape[0], 246 | yt.shape[1] // self.downsampling_factor, self.downsampling_factor, 247 | yt.shape[2] // self.downsampling_factor, self.downsampling_factor, 248 | yt.shape[3] 249 | ) 250 | 251 | p = torch.empty((self.num_leaf_nodes, *downsampled_shape[:2], downsampled_shape[3], yt.shape[3]), dtype=torch.float32, device=yt.device) 252 | s = torch.empty((*p.shape[:-1],), dtype=torch.float32, device=yt.device) 253 | 254 | for i in range(self.num_leaf_nodes): 255 | p[i] = (yt * self.region_map[:,i,:,:,None]).reshape(*downsampled_shape).sum(-2).sum(-3) 256 | s[i] = p[i].sum(-1) 257 | s[i][s[i]<10**-6] = 10**-6 258 | p[i] = p[i].clone() / s[i].unsqueeze(-1).clone() 259 | 260 | if self.config.outputs.shape[0] > 1 and ignore_class in self.config.outputs and not instances: 261 | index = np.nonzero(self.config.outputs==ignore_class)[0][0] 262 | p[:,:,:,:,index] = 0 263 | i = p.sum(-1) < 10**-6 264 | j = p[i] 265 | j[:,0 if index>0 else 1] 
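# Illustrative sketch, not from the repository, of the "add" region-map accumulation above
# for a single split node with two leaves: the signed distance is pushed through relu into
# the two leaf channels and a temperature-scaled softmax turns the result into a soft
# partition. node_weight and the temperature are made-up values.
import torch

signed = torch.linspace(-2, 2, steps=4).reshape(1, 1, 4)    # toy signed distances, (B, 1, pixels)
node_weight, temperature = 1.0, 0.1

region_map = torch.zeros(1, 2, 4)                            # (batch, leaves, pixels)
region_map[:, 0] += torch.relu(node_weight * signed).squeeze(1)
region_map[:, 1] += torch.relu(-node_weight * signed).squeeze(1)
soft = (region_map / temperature).softmax(dim=1)             # soft leaf assignment per pixel

hard = soft.argmax(dim=1)                                    # pixels far from the split are near one-hot
assert hard.tolist() == [[1, 1, 0, 0]]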
= 1 266 | p[i] = j 267 | p = p / p.sum(-1).unsqueeze(-1) 268 | 269 | return p, s 270 | -------------------------------------------------------------------------------- /models/SegForestTreeDecoder.py: -------------------------------------------------------------------------------- 1 | import core 2 | import torch 3 | import torch.nn as nn 4 | import utils 5 | 6 | 7 | class TreeFeatureDecoder(nn.ModuleList): 8 | def __init__(self, config, encoder, num_input_features, num_output_features, effective_num_input_features = -1, is_shape_decoder = False): 9 | super().__init__() 10 | self.use_residuals = config.decoder.use_residual_blocks 11 | c = (config.decoder.context, 2*config.decoder.context + 1) 12 | f = config.decoder.intermediate_features 13 | assert type(c[0]) == int and 0 <= c[0] 14 | 15 | if effective_num_input_features < 0: 16 | effective_num_input_features = num_input_features 17 | if effective_num_input_features >= num_output_features: 18 | print(f"[WARNING] features for tree parameter prediction not bottlenecked! (input: {effective_num_input_features}, output: {num_output_features})") 19 | 20 | vq_config = config.decoder.vq.__dict__.copy() 21 | vq_config["type"] = vq_config["type"][0 if is_shape_decoder else 1] 22 | if num_input_features == effective_num_input_features: 23 | vq_layer = utils.create_vector_quantization_layer(feature_size=num_input_features, **vq_config) 24 | else: 25 | assert num_input_features > effective_num_input_features 26 | 27 | class PartialVectorQuantization(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.vq_layer = utils.create_vector_quantization_layer(feature_size=effective_num_input_features, **vq_config) 31 | self.num_other_features = num_input_features - effective_num_input_features 32 | 33 | def prepare_for_epoch(self, epoch, epochs): 34 | self.vq_layer.prepare_for_epoch(epoch, epochs) 35 | 36 | def forward(self, x): 37 | x = x[:,:self.num_other_features], self.vq_layer(x[:,self.num_other_features:]) 38 | self.loss = self.vq_layer.loss 39 | return torch.cat(x, dim=1) 40 | 41 | vq_layer = PartialVectorQuantization() 42 | 43 | self.append( 44 | nn.Sequential( 45 | vq_layer, 46 | nn.Conv2d(num_input_features, f, 1), 47 | encoder.normalization(f) 48 | ).to(core.device) 49 | ) 50 | 51 | for i in range(config.decoder.num_blocks): 52 | block = [encoder.activation(f)] 53 | if c[0] > 0: 54 | block.extend([ 55 | encoder.padding(c[0]), 56 | nn.Conv2d(f, f, c[1], groups=f), 57 | encoder.normalization(f), 58 | encoder.activation(f) 59 | ]) 60 | block.extend([ 61 | nn.Conv2d(f, f, 1), 62 | encoder.normalization(f) 63 | ]) 64 | self.append( 65 | nn.Sequential(*block).to(core.device) 66 | ) 67 | 68 | self.append( 69 | nn.Sequential( 70 | encoder.activation(f), 71 | nn.Conv2d(f, num_output_features, 1) 72 | ).to(core.device) 73 | ) 74 | 75 | def forward(self, x): 76 | for i, layer in enumerate(self): 77 | x = layer(x) if (i==0 or i+1 == len(self) or not self.use_residuals) else layer(x) + x 78 | return x 79 | 80 | 81 | def BSPSegNetDecoder(config, encoder, num_input_features, num_output_features, effective_num_input_features = -1, is_shape_decoder = False): 82 | features = [num_input_features] 83 | features.extend(config.decoder.intermediate_features) 84 | 85 | decoder = [nn.Sequential(utils.create_vector_quantization_layer(type=0))] 86 | for i in range(len(features)-1): 87 | decoder.extend([ 88 | nn.Conv2d(features[i], features[i+1], 1), 89 | encoder.normalization(features[i+1]), 90 | encoder.activation(features[i+1]) 91 | ]) 92 | 
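# Illustrative sketch, not from the repository, of the per-region statistics computed by
# get_region_distributions above: for every leaf the one-hot ground truth is weighted by
# that leaf's soft region map and pooled, giving a class distribution p and a pixel mass s
# per region. Toy 2x2 patch with 2 classes and 2 leaves.
import numpy as np

yt = np.array([[[1, 0], [1, 0]],
               [[0, 1], [0, 1]]], dtype=np.float64)    # (H, W, C) one-hot labels, top row class 0
region_map = np.array([[[1.0, 1.0], [0.0, 0.0]],
                       [[0.0, 0.0], [1.0, 1.0]]])      # (leaves, H, W), leaf 0 covers the top row

p = np.stack([(yt * region_map[i][..., None]).sum((0, 1)) for i in range(2)])
s = p.sum(-1)                                          # pixel mass per leaf
p = p / np.maximum(s[:, None], 1e-6)                   # class distribution per leaf
assert np.allclose(p, [[1.0, 0.0], [0.0, 1.0]]) and np.allclose(s, [2.0, 2.0])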
decoder.append(nn.Conv2d(features[-1], num_output_features, 1)) 93 | 94 | return nn.Sequential(*decoder).to(core.device) 95 | -------------------------------------------------------------------------------- /models/UNet.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/1505.04597 2 | import core 3 | import torch 4 | import torch.nn as nn 5 | import utils 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | def __init__(self, parent, channels, pooling=False, upsampling=False): 10 | super().__init__() 11 | 12 | if pooling: 13 | block = [nn.MaxPool2d(2)] 14 | else: 15 | block = [] 16 | 17 | block.extend([ 18 | parent.padding(1), 19 | nn.Conv2d(channels[0], channels[1], kernel_size=3), 20 | parent.normalization(channels[1]), 21 | parent.activation(channels[1]), 22 | parent.padding(1), 23 | nn.Conv2d(channels[1], channels[2], kernel_size=3), 24 | parent.normalization(channels[2]), 25 | parent.activation(channels[2]) 26 | ]) 27 | 28 | if upsampling: 29 | block.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)) 30 | 31 | self.conv_block = nn.Sequential(*block) 32 | 33 | def forward(self, x): 34 | return self.conv_block(x) 35 | 36 | 37 | class UNetEncoder(nn.Module): 38 | def __init__(self, config, params, num_features): 39 | super().__init__() 40 | 41 | self.activation = utils.relu_wrapper(getattr(config,"relu_type","LeakyReLU")) # default from paper: ReLU 42 | self.padding = getattr(nn, getattr(config, "padding", "ZeroPad2d")) 43 | self.normalization = utils.norm_wrapper(getattr(config,"norm_type","BatchNorm2d")) # default from paper: BatchNorm2d 44 | self.downsampling = 4 45 | 46 | backbone = [ 47 | ConvBlock(self, [params.input_shape[0], num_features[0], num_features[0]]), 48 | ConvBlock(self, [num_features[0], num_features[1], num_features[1]], pooling=True), 49 | ConvBlock(self, [num_features[1], num_features[2], num_features[2]], pooling=True), 50 | ConvBlock(self, [num_features[2], num_features[3], num_features[3]], pooling=True), 51 | ConvBlock(self, [num_features[3], num_features[4], num_features[4]], pooling=True) 52 | ] 53 | self.backbone = nn.Sequential(*backbone).to(core.device) 54 | 55 | def forward(self, x): 56 | y = [x] 57 | for block in self.backbone: 58 | y.append(block(y[-1])) 59 | return tuple(y[1:]) 60 | 61 | 62 | class UNet(nn.Module): 63 | def __init__(self, config, params): 64 | super().__init__() 65 | 66 | if params.input_shape[1] != params.input_shape[2]: 67 | raise RuntimeError("input shape for UNet must be square") 68 | 69 | self.model_name = "UNet" 70 | num_features = (64, 128, 256, 512, 1024) 71 | if getattr(config, "small", False): 72 | num_features = tuple([i//2 for i in num_features]) 73 | 74 | self.encoder = UNetEncoder(config, params, num_features) 75 | self.decoder = [ 76 | ConvBlock(self.encoder, [sum(num_features[3:5]), num_features[3], num_features[3]], upsampling=True), 77 | ConvBlock(self.encoder, [sum(num_features[2:4]), num_features[2], num_features[2]], upsampling=True), 78 | ConvBlock(self.encoder, [sum(num_features[1:3]), num_features[1], num_features[1]], upsampling=True), 79 | ConvBlock(self.encoder, [sum(num_features[0:2]), num_features[0], num_features[0]]), 80 | nn.Conv2d(num_features[0], params.num_classes, kernel_size=1) 81 | ] 82 | self.decoder = nn.Sequential(*self.decoder).to(core.device) 83 | 84 | def forward(self, x, yt, ce_loss_func, weight, ignore_index): 85 | y0 = self.encoder(x) 86 | y1 = nn.functional.interpolate(y0[-1], scale_factor=2, mode="bilinear", 
align_corners=False) 87 | for i, block in enumerate(self.decoder): 88 | if i < 4: 89 | y1 = torch.cat([y0[3-i], y1], dim=1) 90 | y1 = block(y1) 91 | return y1, ce_loss_func(y1, yt, weight=weight, ignore_index=ignore_index) 92 | -------------------------------------------------------------------------------- /models/UNetpp.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/1807.10165 2 | import core 3 | import torch 4 | import torch.nn as nn 5 | import utils 6 | from .UNet import ConvBlock, UNetEncoder 7 | 8 | 9 | class UNetpp(nn.Module): 10 | def __init__(self, config, params): 11 | super().__init__() 12 | 13 | if params.input_shape[1] != params.input_shape[2]: 14 | raise RuntimeError("input shape for UNet must be square") 15 | 16 | self.model_name = "UNetpp" 17 | self.deep_supervision = getattr(config, "deep_supervision", True) 18 | 19 | self.encoder = UNetEncoder(config, params, (32, 64, 128, 256, 512)) 20 | 21 | self.x31 = ConvBlock(self.encoder, [768, 256, 256]).to(core.device) 22 | 23 | self.x21 = ConvBlock(self.encoder, [384, 128, 128]).to(core.device) 24 | self.x22 = ConvBlock(self.encoder, [512, 128, 128]).to(core.device) 25 | 26 | self.x11 = ConvBlock(self.encoder, [192, 64, 64]).to(core.device) 27 | self.x12 = ConvBlock(self.encoder, [256, 64, 64]).to(core.device) 28 | self.x13 = ConvBlock(self.encoder, [320, 64, 64]).to(core.device) 29 | 30 | self.x01 = ConvBlock(self.encoder, [96, 32, 32]).to(core.device) 31 | self.x02 = ConvBlock(self.encoder, [128, 32, 32]).to(core.device) 32 | self.x03 = ConvBlock(self.encoder, [160, 32, 32]).to(core.device) 33 | self.x04 = ConvBlock(self.encoder, [192, 32, 32]).to(core.device) 34 | 35 | if self.deep_supervision: 36 | self.classify1 = nn.Conv2d(32, params.num_classes, kernel_size=1).to(core.device) 37 | self.classify2 = nn.Conv2d(32, params.num_classes, kernel_size=1).to(core.device) 38 | self.classify3 = nn.Conv2d(32, params.num_classes, kernel_size=1).to(core.device) 39 | self.classify4 = nn.Conv2d(32, params.num_classes, kernel_size=1).to(core.device) 40 | 41 | def forward(self, x, yt, ce_loss_func, weight, ignore_index): 42 | x00, x10, x20, x30, x40 = self.encoder(x) 43 | 44 | x31 = nn.functional.interpolate(x40, scale_factor=2, mode="bilinear", align_corners=False) 45 | x31 = self.x31(torch.cat((x30, x31), dim=1)) 46 | 47 | x21 = nn.functional.interpolate(x30, scale_factor=2, mode="bilinear", align_corners=False) 48 | x21 = self.x21(torch.cat((x20, x21), dim=1)) 49 | 50 | x22 = nn.functional.interpolate(x31, scale_factor=2, mode="bilinear", align_corners=False) 51 | x22 = self.x22(torch.cat((x20, x21, x22), dim=1)) 52 | 53 | x11 = nn.functional.interpolate(x20, scale_factor=2, mode="bilinear", align_corners=False) 54 | x11 = self.x11(torch.cat((x10, x11), dim=1)) 55 | 56 | x12 = nn.functional.interpolate(x21, scale_factor=2, mode="bilinear", align_corners=False) 57 | x12 = self.x12(torch.cat((x10, x11, x12), dim=1)) 58 | 59 | x13 = nn.functional.interpolate(x22, scale_factor=2, mode="bilinear", align_corners=False) 60 | x13 = self.x13(torch.cat((x10, x11, x12, x13), dim=1)) 61 | 62 | x01 = nn.functional.interpolate(x10, scale_factor=2, mode="bilinear", align_corners=False) 63 | x01 = self.x01(torch.cat((x00, x01), dim=1)) 64 | 65 | x02 = nn.functional.interpolate(x11, scale_factor=2, mode="bilinear", align_corners=False) 66 | x02 = self.x02(torch.cat((x00, x01, x02), dim=1)) 67 | 68 | x03 = nn.functional.interpolate(x12, scale_factor=2, mode="bilinear", align_corners=False) 
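# UNet++ dense skip pathway: each node concatenates all earlier nodes at the same resolution with the upsampled output of the node one level deeper, then applies its ConvBlock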
69 | x03 = self.x03(torch.cat((x00, x01, x02, x03), dim=1)) 70 | 71 | x04 = nn.functional.interpolate(x13, scale_factor=2, mode="bilinear", align_corners=False) 72 | x04 = self.x04(torch.cat((x00, x01, x02, x03, x04), dim=1)) 73 | 74 | if self.deep_supervision: 75 | yp = self.classify1(x01) + self.classify2(x02) + self.classify3(x03) + self.classify4(x04) 76 | else: 77 | yp = self.classify4(x04) 78 | return yp, ce_loss_func(yp, yt, weight=weight, ignore_index=ignore_index) 79 | -------------------------------------------------------------------------------- /models/Xception.py: -------------------------------------------------------------------------------- 1 | # https://arxiv.org/abs/1610.02357 2 | import core 3 | import torch 4 | import torch.nn as nn 5 | import utils 6 | 7 | 8 | class EntryFlow(nn.ModuleList): 9 | def __init__(self, parent, in_filters, out_filters, additional_activation=True, stride=2): 10 | super().__init__() 11 | block = [parent.activation(in_filters)] if additional_activation else [] 12 | block.extend([ 13 | parent.padding(1), 14 | nn.Conv2d(in_filters, in_filters, 3, groups=in_filters), 15 | nn.Conv2d(in_filters, out_filters, 1), 16 | parent.normalization(out_filters), 17 | parent.activation(out_filters), 18 | parent.padding(1), 19 | nn.Conv2d(out_filters, out_filters, 3, groups=out_filters), 20 | nn.Conv2d(out_filters, out_filters, 1), 21 | parent.normalization(out_filters), 22 | parent.padding(1), 23 | nn.MaxPool2d(3, stride=stride) 24 | ]) 25 | self.append(nn.Sequential(*block)) 26 | self.append(nn.Sequential( 27 | nn.Conv2d(in_filters, out_filters, 1, stride=stride), 28 | parent.normalization(out_filters) 29 | )) 30 | 31 | def forward(self, x): 32 | return self[0](x) + self[1](x) 33 | 34 | 35 | class MiddleFlow(nn.ModuleList): 36 | def __init__(self, parent, n): 37 | super().__init__() 38 | for _ in range(n): 39 | self.append(nn.Sequential( 40 | parent.activation(728), 41 | parent.padding(1), 42 | nn.Conv2d(728, 728, 3, groups=728), 43 | nn.Conv2d(728, 728, 1), 44 | parent.normalization(728), 45 | parent.activation(728), 46 | parent.padding(1), 47 | nn.Conv2d(728, 728, 3, groups=728), 48 | nn.Conv2d(728, 728, 1), 49 | parent.normalization(728), 50 | parent.activation(728), 51 | parent.padding(1), 52 | nn.Conv2d(728, 728, 3, groups=728), 53 | nn.Conv2d(728, 728, 1), 54 | parent.normalization(728), 55 | )) 56 | 57 | def forward(self, x): 58 | for layer in self: 59 | x = layer(x) + x 60 | return x 61 | 62 | 63 | class ExitFlow(nn.ModuleList): 64 | def __init__(self, parent, stride): 65 | super().__init__() 66 | self.append(nn.Sequential( 67 | parent.activation(728), 68 | parent.padding(1), 69 | nn.Conv2d(728, 728, 3, groups=728), 70 | nn.Conv2d(728, 728, 1), 71 | parent.normalization(728), 72 | parent.activation(728), 73 | parent.padding(1), 74 | nn.Conv2d(728, 728, 3, groups=728), 75 | nn.Conv2d(728, 1024, 1), 76 | parent.normalization(1024), 77 | parent.padding(1), 78 | nn.MaxPool2d(3, stride=stride), 79 | )) 80 | self.append(nn.Sequential( 81 | nn.Conv2d(728, 1024, 1, stride=stride), 82 | parent.normalization(1024), 83 | )) 84 | 85 | def forward(self, x): 86 | return self[0](x) + self[1](x) 87 | 88 | 89 | class XceptionFeatureExtractor(nn.ModuleList): 90 | def __init__(self, config, num_input_channels): 91 | super().__init__() 92 | 93 | if not (3 <= config.downsampling <= 5): 94 | raise RuntimeError(f"unsupported downsampling setting '{config.downsampling}'") 95 | self.model_name = "XceptionFeatureExtractor" 96 | self.downsampling = config.downsampling 97 | 
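# output channel counts of the five feature stages returned by forward(): stem, three entry-flow blocks (the last followed by the middle flow), and the exit-flow stage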
self.num_features = (64, 128, 256, 728, 2048) 98 | self.activation = utils.relu_wrapper(getattr(config,"relu_type","LeakyReLU")) # default from paper: ReLU 99 | self.padding = getattr(nn, getattr(config, "padding", "ZeroPad2d")) 100 | self.normalization = utils.norm_wrapper(getattr(config,"norm_type","BatchNorm2d")) # default from paper: BatchNorm2d 101 | 102 | self.append(nn.Sequential( 103 | self.padding(1), 104 | nn.Conv2d(num_input_channels, 32, 3, stride=2), 105 | self.normalization(32), 106 | self.activation(32), 107 | self.padding(1), 108 | nn.Conv2d(32, 64, 3), 109 | self.normalization(64), 110 | self.activation(64), 111 | ).to(core.device)) 112 | 113 | self.append(nn.Sequential( 114 | EntryFlow(self, 64, 128, False), 115 | ).to(core.device)) 116 | 117 | self.append(nn.Sequential( 118 | EntryFlow(self, 128, 256), 119 | ).to(core.device)) 120 | 121 | self.append(nn.Sequential( 122 | EntryFlow(self, 256, 728, stride=2 if config.downsampling > 3 else 1), 123 | MiddleFlow(self, 8), 124 | ).to(core.device)) 125 | 126 | self.append(nn.Sequential( 127 | ExitFlow(self, 2 if config.downsampling == 5 else 1), 128 | self.padding(1), 129 | nn.Conv2d(1024, 1024, 3, groups=1024), 130 | nn.Conv2d(1024, 1536, 1), 131 | self.normalization(1536), 132 | self.activation(1536), 133 | self.padding(1), 134 | nn.Conv2d(1536, 1536, 3, groups=1536), 135 | nn.Conv2d(1536, 2048, 1), 136 | self.normalization(2048), 137 | self.activation(2048), 138 | ).to(core.device)) 139 | 140 | if getattr(config, "pretrained_encoder", False): 141 | print("using pretrained model weights") 142 | with core.open(core.user.model_weights_paths["Xception"], "rb") as f: 143 | weights = torch.load(f, map_location=core.device) 144 | import bz2 145 | import json 146 | with bz2.open("models/xception.json.bz2", "r") as f: 147 | mapping = json.load(f) 148 | new_weights = {k: (torch.zeros_like(v) if "bias" in k else v.clone()) for k, v in self.state_dict().items()} 149 | for k, v in mapping.items(): 150 | new_weights[k] = weights[v] 151 | self.load_state_dict(new_weights) 152 | 153 | def forward(self, x): 154 | y = [x,] 155 | for layer in self: 156 | y.append(layer(y[-1])) 157 | return tuple(y[1:]) 158 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .Xception import XceptionFeatureExtractor 2 | from .UNet import UNet 3 | from .UNetpp import UNetpp 4 | from .FCN import FCN 5 | from .RAFCN import RAFCN 6 | from .FarSeg import FarSeg 7 | from .PFNet import PFNet 8 | from .DeepLabv3p import DeepLabv3p 9 | from .SegForestNet import SegForestNet 10 | -------------------------------------------------------------------------------- /models/xception.json.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gritzner/SegForestNet/b7b88ab86677b1b0f2455ff3d41975bc79212831/models/xception.json.bz2 -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "aethon" 3 | version = "0.1.0" 4 | authors = ["Daniel Gritzner "] 5 | edition = "2018" 6 | 7 | [dependencies] 8 | num = "0.4" 9 | rand = "0.8.4" 10 | 11 | [lib] 12 | crate-type = ["cdylib"] 13 | 14 | [profile.release] 15 | debug-assertions = false 16 | codegen-units = 1 17 | lto = "fat" 18 | panic = "abort" 19 | strip = "symbols" 20 | 
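# release profile: a single codegen unit with fat LTO, aborting panics, disabled debug assertions, and stripped symbols keep the cdylib small and fast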
-------------------------------------------------------------------------------- /rust/__init__.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | import types 4 | import itertools 5 | import subprocess 6 | import sys 7 | import json 8 | import ctypes 9 | import numpy as np 10 | 11 | 12 | def init(run_cargo, cache_path): 13 | def generate_type_combinations(export): 14 | keys = tuple(export.ty.keys()) 15 | for ty in itertools.product(*[export.ty[k] for k in keys]): 16 | yield {keys[i][0]:ty[i] for i in range(len(keys))} 17 | 18 | def get_type_string(ty, order): 19 | return "".join([f"_{ty[order[i]]}" for i in range(len(order))]) 20 | 21 | lib_path = "rust/target/release/libaethon.so" 22 | if os.path.exists(lib_path): 23 | lib_modified_time = os.path.getmtime(lib_path) 24 | for fn in glob.iglob("rust/src/*.rs"): 25 | if os.path.getmtime(fn) > lib_modified_time: 26 | run_cargo = True 27 | else: 28 | run_cargo = True 29 | 30 | if run_cargo: 31 | exports = [] 32 | for src_fn in glob.iglob("rust/src/*.rs"): 33 | fn = src_fn.split("/")[-1].split(".")[0] 34 | if fn == "exports": 35 | continue 36 | with open(src_fn, "r") as f: 37 | content = "".join(f.readlines()).replace("\n", " ") 38 | start = content.find("//PYTHON_EXPORT") 39 | while start != -1: 40 | end = content.find("{", start) 41 | exports.append((fn, content[start:end])) 42 | start = content.find("//PYTHON_EXPORT", end) 43 | 44 | for i, (fn, content) in enumerate(exports): 45 | export = types.SimpleNamespace(fn=fn, contains_ndarrays=False) 46 | 47 | index = content.find("pub fn ") 48 | assert index != - 1 49 | ty, content = content[15:index].strip(), content[index+7:] 50 | export.ty = [] 51 | if len(ty) > 0: 52 | for ty in ty.split(";"): 53 | export.ty.append([ 54 | ty.strip() for ty in ty.split(",") 55 | ]) 56 | 57 | index = content.find("(") 58 | assert index != -1 59 | export.name, content = content[:index], content[index+1:] 60 | index = export.name.find("<") 61 | if index != -1: 62 | export.name, generics = export.name[:index], export.name[index+1:] 63 | for j, ty in enumerate(generics.split(",")): 64 | export.ty[j].append(ty.split(":")[0].strip()) 65 | export.ty_order = {j:ty[-1] for j,ty in enumerate(export.ty)} 66 | export.ty = {ty[-1]:[*ty[:-1]] for ty in export.ty} 67 | 68 | index = content.find(")") 69 | assert index != -1 70 | params, content = content[:index], content[index+1:] 71 | export.params = [] 72 | for param in params.split(","): 73 | index = param.find(":") 74 | if index == -1: 75 | continue 76 | param = param[index+1:].strip() 77 | while " " in param: 78 | param = param.replace(" ", " ") 79 | index = param.find("<") 80 | if index != -1: 81 | export.contains_ndarrays = True 82 | param = param[:index].strip(), param[index+1:-1].strip() 83 | assert param[0] == "NdArray" 84 | else: 85 | param = param, 86 | export.params.append(param) 87 | 88 | index = content.find("->") 89 | export.ret_type = content[index+2:].strip() if index != -1 else None 90 | 91 | exports[i] = export 92 | 93 | def generate_export(f, export, ty): 94 | type_string = get_type_string(ty, getattr(export, "ty_order", {})) 95 | f.write(f"#[no_mangle]\npub extern \"C\" fn _{export.name}{type_string}(\n") 96 | for i, param in enumerate(export.params): 97 | suffix = "," if i < len(export.params)-1 else "" 98 | if len(param) > 1: 99 | t = ty[param[1]] if param[1] in ty else param[1] 100 | f.write(f"\tp{i}: *const {t}, p{i}_shape: *const i32, p{i}_strides: *const i32, p{i}_ndims: 
i32{suffix}\n") 101 | else: 102 | f.write(f"\tp{i}: {param[0]}{suffix}\n") 103 | f.write(") ") 104 | if not export.ret_type is None: 105 | f.write(f"-> {export.ret_type} ") 106 | f.write("{\n") 107 | if export.contains_ndarrays: 108 | for i, param in enumerate(export.params): 109 | if len(param) == 1: 110 | continue 111 | f.write(f"\tlet p{i} = NdArray::from_ptr(p{i}, p{i}_shape, p{i}_strides, p{i}_ndims);\n") 112 | f.write("\n") 113 | f.write(f"\t{export.name}(") 114 | params = ", ".join([f"p{i}" for i in range(len(export.params))]) 115 | f.write(params) 116 | f.write(")\n}\n\n") 117 | 118 | with open("rust/src/exports.rs", "w") as f: 119 | for fn in set([export.fn for export in exports]): 120 | funcs = "{%s}"%",".join([export.name for export in exports if export.fn == fn]) 121 | f.write(f"use crate::{fn}::{funcs};\n") 122 | f.write("use crate::utils::*;\n\n") 123 | for export in exports: 124 | if len(export.ty) > 0: 125 | for ty in generate_type_combinations(export): 126 | generate_export(f, export, ty) 127 | else: 128 | generate_export(f, export, {}) 129 | 130 | ret_val = subprocess.call("cargo build --release --manifest-path rust/Cargo.toml", shell=True) 131 | if ret_val != 0: 132 | sys.exit(ret_val) 133 | 134 | with open(f"{cache_path}/rust_exports.json", "w") as f: 135 | json.dump([export.__dict__ for export in exports], f) 136 | else: 137 | with open(f"{cache_path}/rust_exports.json", "r") as f: 138 | exports = [types.SimpleNamespace(**export) for export in json.load(f)] 139 | for export in exports: 140 | if not hasattr(export, "ty_order"): 141 | continue 142 | export.ty_order = {int(k):v for k,v in export.ty_order.items()} 143 | 144 | type_map = { 145 | "bool": (ctypes.c_bool, None), 146 | "i8": (ctypes.c_int8, np.int8), 147 | "i16": (ctypes.c_int16, np.int16), 148 | "i32": (ctypes.c_int32, np.int32), 149 | "i64": (ctypes.c_int64, np.int64), 150 | "isize": (ctypes.c_ssize_t, None), 151 | "u8": (ctypes.c_uint8, np.uint8), 152 | "u16": (ctypes.c_uint16, np.uint16), 153 | "u32": (ctypes.c_uint32, np.uint32), 154 | "u64": (ctypes.c_uint64, np.uint64), 155 | "usize": (ctypes.c_size_t, None), 156 | "f32": (ctypes.c_float, np.float32), 157 | "f64": (ctypes.c_double, np.float64) 158 | } 159 | reverse_type_map = {np.dtype(numpy_type):rust_type for rust_type, (_, numpy_type) in type_map.items() if not numpy_type is None} 160 | type_map = {rust_type:python_type for rust_type, (python_type, _) in type_map.items()} 161 | 162 | lib = ctypes.CDLL(lib_path) 163 | 164 | class FunctionWrapper(): 165 | def __init__(self, export, name_suffix="", params=None): 166 | self.func = lib[f"_{export.name}{name_suffix}"] 167 | params = export.params if params is None else params 168 | argtypes = [] 169 | for param in params: 170 | if len(param) > 1: 171 | argtypes.extend([ 172 | np.ctypeslib.ndpointer(type_map[param[1]]), 173 | np.ctypeslib.ndpointer(ctypes.c_int32), 174 | np.ctypeslib.ndpointer(ctypes.c_int32), 175 | ctypes.c_int32 176 | ]) 177 | else: 178 | argtypes.append(type_map[param[0]]) 179 | self.func.argtypes = argtypes 180 | if not export.ret_type is None: 181 | self.func.restype = type_map[export.ret_type] 182 | 183 | def __call__(self, *args): 184 | rust_args = [] 185 | for arg in args: 186 | if isinstance(arg, np.ndarray): 187 | rust_args.extend([arg, np.asarray(arg.shape,dtype=np.int32), np.asarray(arg.strides,dtype=np.int32), len(arg.shape)]) 188 | else: 189 | rust_args.append(arg) 190 | return self.func(*rust_args) 191 | 192 | class GenericFunctionWrapper(): 193 | def __init__(self, export): 
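# build one FunctionWrapper per exported type combination (keyed by its type suffix) and record, for each generic type parameter, the index of the first ndarray argument whose dtype determines it; __call__ uses those dtypes to select the matching specialization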
194 | self.funcs = {} 195 | for ty in generate_type_combinations(export): 196 | type_string = get_type_string(ty, export.ty_order) 197 | params = [(param[0],ty[param[1]]) if len(param)>1 and param[1] in ty else param for param in export.params] 198 | self.funcs[type_string] = FunctionWrapper(export, type_string, params) 199 | self.type_args = [] 200 | for i in range(len(export.ty_order)): 201 | ty = export.ty_order[i] 202 | j = [j for j,param in enumerate(export.params) if len(param)>1 and param[1]==ty][0] 203 | self.type_args.append(j) 204 | 205 | def __call__(self, *args): 206 | type_string = "" 207 | for i in self.type_args: 208 | t = reverse_type_map[args[i].dtype] 209 | type_string = f"{type_string}_{t}" 210 | return self.funcs[type_string](*args) 211 | 212 | for export in exports: 213 | if export.contains_ndarrays: 214 | if len(export.ty) == 0: 215 | func = FunctionWrapper(export) 216 | else: 217 | func = GenericFunctionWrapper(export) 218 | else: 219 | func = lib[f"_{export.name}"] 220 | func.argtypes = [type_map[param[0]] for param in export.params] 221 | if not export.ret_type is None: 222 | func.restype = type_map[export.ret_type] 223 | globals()[export.name] = func 224 | -------------------------------------------------------------------------------- /rust/src/drawcircle.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::*; 2 | 3 | //PYTHON_EXPORT 4 | pub fn draw_circle( 5 | mut img: NdArray, mut gt: NdArray, center: NdArray, radius: f32, 6 | color: NdArray, class_index: i32, yrange: NdArray, xrange: NdArray 7 | ) { 8 | let radius = radius.powi(2); 9 | 10 | for y in yrange[0]..yrange[1] { 11 | let dy = ((center[0] - y) as i32).pow(2); 12 | 13 | for x in xrange[0]..xrange[1] { 14 | let dx = ((center[1] - x) as i32).pow(2); 15 | if ((dy + dx) as f32) > radius { 16 | continue; 17 | } 18 | 19 | for c in 0..3 { 20 | img[[y, x, c]] = color[c]; 21 | } 22 | gt[[y, x]] = class_index; 23 | } 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /rust/src/lib.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | mod utils; 3 | mod rasterize; 4 | mod sparseimage; 5 | mod pixelstats; 6 | mod metrics; 7 | mod drawcircle; 8 | mod exports; 9 | -------------------------------------------------------------------------------- /rust/src/metrics.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::*; 2 | 3 | //PYTHON_EXPORT 4 | pub fn add_to_conf_mat(mut conf_mat: NdArray, yt: NdArray, yp: NdArray) { 5 | for i in 0..yt.shape[0] { 6 | for j in 0..yt.shape[1] { 7 | for k in 0..yt.shape[2] { 8 | let true_class = yt[[i,j,k]]; 9 | let predicted_class = yp[[i,j,k]] as i32; 10 | conf_mat[[true_class,predicted_class]] += 1; 11 | } 12 | } 13 | } 14 | } 15 | 16 | //PYTHON_EXPORT 17 | pub fn prepare_per_pixel_entropy(mut p: NdArray, threshold: f64) { 18 | for y in 0..p.shape[0] { 19 | for x in 0..p.shape[1] { 20 | for c in 0..p.shape[2] { 21 | let i = [y, x, c]; 22 | let v = p[i]; 23 | if v < threshold { 24 | p[i] = 0.0; 25 | } else { 26 | p[i] = -v * v.ln(); 27 | } 28 | } 29 | } 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /rust/src/pixelstats.rs: -------------------------------------------------------------------------------- 1 | use num::{Integer, ToPrimitive}; 2 | use crate::utils::*; 3 | 4 | //PYTHON_EXPORT u8,u16,i32; u8,u16,i32 5 | pub fn 
accumulate_pixel_statistics( 6 | mut accum_buffer: NdArray, mut accum_buffer_depth: NdArray, 7 | mut class_counts: NdArray, image: NdArray, gt: NdArray, 8 | depth: NdArray) { 9 | 10 | for y in 0..image.shape[0] { 11 | for x in 0..image.shape[1] { 12 | for c in 0..image.shape[2] { 13 | let val: T = image[[y, x, c]]; 14 | let val = val.to_u64().unwrap(); 15 | accum_buffer[[c,0]] += val; 16 | accum_buffer[[c,1]] += val * val; 17 | } 18 | 19 | let val = depth[[y, x]] as f64; 20 | accum_buffer_depth[0] += val; 21 | accum_buffer_depth[1] += val * val; 22 | 23 | let class: U = gt[[y,x]]; 24 | let class = class.to_i32().unwrap(); 25 | class_counts[class] += 1; 26 | } 27 | } 28 | } 29 | 30 | //PYTHON_EXPORT u8,u16,i32 31 | pub fn accumulate_pixel_statistics_sparse( 32 | mut accum_buffer: NdArray, mut accum_buffer_depth: NdArray, mut class_counts: NdArray, 33 | image: NdArray, instances: NdArray, masks: NdArray, depth: NdArray) { 34 | 35 | for y in 0..image.shape[0] { 36 | for x in 0..image.shape[1] { 37 | for c in 0..image.shape[2] { 38 | let val: T = image[[y, x, c]]; 39 | let val = val.to_u64().unwrap(); 40 | accum_buffer[[c,0]] += val; 41 | accum_buffer[[c,1]] += val * val; 42 | } 43 | 44 | let val = depth[[y, x]] as f64; 45 | accum_buffer_depth[0] += val; 46 | accum_buffer_depth[1] += val * val; 47 | } 48 | } 49 | 50 | class_counts[0] += (image.shape[0] * image.shape[1]) as u64; 51 | for i in 0..instances.shape[0] { 52 | let class = instances[[i, 0]] as i32; 53 | 54 | let height = instances[[i, 2]] - instances[[i, 1]]; 55 | let width = instances[[i, 4]] - instances[[i, 3]]; 56 | 57 | let begin = instances[[i, 5]] as i32; 58 | let end = begin + (height * width) as i32; 59 | 60 | for j in begin..end { 61 | if masks[j] != 0 { 62 | class_counts[class] += 1; 63 | class_counts[0] -= 1_u64; 64 | } 65 | } 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /rust/src/rasterize.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::*; 2 | use std::sync::Arc; 3 | 4 | fn is_clockwise(polygon: &NdArray) -> bool { 5 | let mut area = 0.; 6 | for i in 0..polygon.shape[0]-1 { 7 | area += polygon[[i, 0]] * polygon[[i+1, 1]]; 8 | area -= polygon[[i, 1]] * polygon[[i+1, 0]]; 9 | } 10 | area >= 0. 11 | } 12 | 13 | fn is_inside(polygon: &NdArray, x: f64, y: f64) -> bool { 14 | let mut crossings: usize = 0; 15 | 16 | for i in 0..polygon.shape[0]-1 { 17 | let xs = [polygon[[i, 0]], polygon[[i+1, 0]]]; 18 | let ys = [polygon[[i, 1]], polygon[[i+1, 1]]]; 19 | 20 | if y < ys[0].min(ys[1]) || y > ys[0].max(ys[1]) { 21 | continue; 22 | } 23 | 24 | if ys[0] != ys[1] { 25 | let mut intersection = (y - ys[0]) / (ys[1] - ys[0]); 26 | intersection *= xs[1] - xs[0]; 27 | intersection += xs[0]; 28 | crossings += if x <= intersection && intersection <= 2.0 { 1 } else { 0 }; 29 | } else { 30 | crossings += if x <= xs[0] && xs[0] <= 2.0 { 1 } else { 0 }; 31 | crossings += if x <= xs[1] && xs[1] <= 2.0 { 1 } else { 0 }; 32 | } 33 | } 34 | 35 | (crossings % 2) == 1 36 | } 37 | 38 | //PYTHON_EXPORT 39 | pub fn rasterize_objects( 40 | img: NdArray, geometry_lengths: NdArray, polygon_lengths: NdArray, 41 | polygons: NdArray, label: u8, num_threads: i32) { 42 | 43 | assert!(num_threads > 0); 44 | 45 | let height = img.shape[0] as f64; 46 | let width = img.shape[1] as f64; 47 | 48 | let height_factor = 1. / height; 49 | let width_factor = 1. 
/ width; 50 | 51 | let half_pel_y = 0.5 * height_factor; 52 | let half_pel_x = 0.5 * width_factor; 53 | 54 | let img = Arc::new(img); 55 | let mut threads = Vec::new(); 56 | 57 | for i in 0..geometry_lengths.shape[0] { 58 | let mut left = f64::INFINITY; 59 | let mut right = -f64::INFINITY; 60 | let mut top = f64::INFINITY; 61 | let mut bottom = -f64::INFINITY; 62 | let mut geometry = Vec::new(); 63 | 64 | for j in 0..geometry_lengths[i] { 65 | let k = polygon_lengths[[i, j]]; 66 | let polygon = polygons.view(0, i).view(0, j).view_range(0, 0, k); 67 | 68 | for l in 0..polygon.shape[0] { 69 | let x = polygon[[l, 0]]; 70 | left = left.min(x); 71 | right = right.max(x); 72 | 73 | let y = polygon[[l, 1]]; 74 | top = top.min(y); 75 | bottom = bottom.max(y); 76 | } 77 | 78 | let clockwise = is_clockwise(&polygon); 79 | geometry.push((polygon, clockwise)); 80 | } 81 | 82 | let left = (img.shape[1] as i32).min((left * width) as i32).max(0); 83 | let right = (img.shape[1] as i32).min(1 + (right * width) as i32).max(0); 84 | let top = (img.shape[0] as i32).min((top * height) as i32).max(0); 85 | let bottom = (img.shape[0] as i32).min(1 + (bottom * height) as i32).max(0); 86 | 87 | let left = Arc::new(left); 88 | let right = Arc::new(right); 89 | let top = Arc::new(top); 90 | let bottom = Arc::new(bottom); 91 | let geometry = Arc::new(geometry); 92 | 93 | for thread_id in 0..num_threads { 94 | let img = img.clone(); 95 | let left = left.clone(); 96 | let right = right.clone(); 97 | let top = top.clone(); 98 | let bottom = bottom.clone(); 99 | let geometry = geometry.clone(); 100 | 101 | let handle = std::thread::spawn(move || { 102 | let mut img = (*img).clone(); 103 | 104 | for y in *top..*bottom { 105 | if y % num_threads != thread_id { 106 | continue; 107 | } 108 | 109 | let fy = ((y as f64) * height_factor) + half_pel_y; 110 | 111 | for x in *left..*right { 112 | let fx = ((x as f64) * width_factor) + half_pel_x; 113 | 114 | let mut inside_object = false; 115 | let mut inside_hole = false; 116 | 117 | for (polygon, clockwise) in geometry.iter() { 118 | if is_inside(polygon, fx, fy) { 119 | if *clockwise { 120 | inside_object = true; 121 | } else { 122 | inside_hole = true; 123 | } 124 | } 125 | } 126 | 127 | if inside_object && !inside_hole { 128 | img[[y, x]] = label; 129 | } 130 | } 131 | } 132 | }); 133 | 134 | threads.push(handle); 135 | } 136 | 137 | while !threads.is_empty() { 138 | threads.pop().unwrap().join().unwrap(); 139 | } 140 | } 141 | } 142 | 143 | fn fill(img: &mut NdArray, cy: i32, cx: i32, votes: &mut NdArray, ignored_votes: usize) -> bool { 144 | for i in 0..votes.shape[0] { 145 | votes[i] = 0; 146 | } 147 | 148 | for y in (cy-1).max(0)..(cy+2).min(img.shape[0]) { 149 | for x in (cx-1).max(0)..(cx+2).min(img.shape[1]) { 150 | let mut val = img[[y, x]]; 151 | if val == 255 { 152 | val = 15; 153 | } 154 | votes[val] += 1; 155 | } 156 | } 157 | votes[15] -= ignored_votes; 158 | 159 | let mut majority = 0; 160 | for i in 1..votes.shape[0] { 161 | if votes[i] > votes[majority] { 162 | majority = i; 163 | } 164 | } 165 | 166 | if majority < 15 { 167 | img[[cy, cx]] = majority as u8; 168 | return true; 169 | } 170 | 171 | false 172 | } 173 | 174 | //PYTHON_EXPORT 175 | pub fn flood_fill_holes(mut img: NdArray) { 176 | let mut undefined_pixels = Vec::new(); 177 | for y in 0..img.shape[0] { 178 | for x in 0..img.shape[1] { 179 | if img[[y, x]] == 255 { 180 | undefined_pixels.push((y, x)); 181 | } 182 | } 183 | } 184 | 185 | ndarray!(votes -> 0usize; 16); 186 | let mut ignored_votes = 1; 
// don't include the center pixel 187 | let mut changed_pixels = Vec::new(); 188 | 189 | while !undefined_pixels.is_empty() { 190 | changed_pixels.clear(); 191 | for i in 0..undefined_pixels.len() { 192 | let (y, x) = undefined_pixels[i]; 193 | let result = fill(&mut img, y, x, &mut votes, ignored_votes); 194 | if result { 195 | changed_pixels.push(i); 196 | } 197 | } 198 | 199 | if changed_pixels.is_empty() { 200 | ignored_votes += 1; // ignore an additional 'undefined' vote to eventually grow defined region again 201 | } else { 202 | ignored_votes = 1; 203 | changed_pixels.reverse(); 204 | for i in &changed_pixels { 205 | undefined_pixels.remove(*i); 206 | } 207 | } 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /rust/src/sparseimage.rs: -------------------------------------------------------------------------------- 1 | use crate::utils::*; 2 | use std::collections::BTreeMap; 3 | 4 | //PYTHON_EXPORT 5 | pub fn rgb2ids(mut gt_ids: NdArray, gt_rgb: NdArray, mut inst_ids: NdArray, inst_rgb: NdArray, class_map_array: NdArray) -> usize { 6 | let mut class_map = BTreeMap::new(); 7 | for i in 0..class_map_array.shape[0] { 8 | class_map.insert((class_map_array[[i,0]], class_map_array[[i,1]], class_map_array[[i,2]]), i as u8); 9 | } 10 | 11 | let mut instance_map = BTreeMap::new(); 12 | 13 | for y in 0..gt_rgb.shape[0] { 14 | for x in 0..gt_rgb.shape[1] { 15 | let class_id = class_map[&(gt_rgb[[y,x,0]], gt_rgb[[y,x,1]], gt_rgb[[y,x,2]])]; 16 | gt_ids[[y, x]] = class_id; 17 | 18 | let mut instance_id = 0; 19 | if class_id > 0 { 20 | let new_id = (instance_map.len() + 1) as i32; 21 | let instance_rgb = (inst_rgb[[y,x,0]], inst_rgb[[y,x,1]], inst_rgb[[y,x,2]]); 22 | let instance = instance_map.entry(instance_rgb).or_insert(new_id); 23 | instance_id = *instance; 24 | } 25 | inst_ids[[y, x]] = instance_id; 26 | } 27 | } 28 | 29 | instance_map.len() 30 | } 31 | 32 | //PYTHON_EXPORT 33 | pub fn extract_instances(mut instances: NdArray, gt: NdArray, inst_img: NdArray) { 34 | for i in 0..instances.shape[0] { 35 | instances[[i, 0]] = 0; // class ID 36 | instances[[i, 1]] = u64::MAX; // first y (inclusive) 37 | instances[[i, 2]] = u64::MIN; // last y (exclusive) 38 | instances[[i, 3]] = u64::MAX; // first x (inclusive) 39 | instances[[i, 4]] = u64::MIN; // last x (exclusive) 40 | } 41 | 42 | for y in 0..gt.shape[0] { 43 | let y64 = y as u64; 44 | 45 | for x in 0..gt.shape[1] { 46 | let i = inst_img[[y, x]]; 47 | if i == 0 { 48 | continue; 49 | } 50 | 51 | let i = i - 1; 52 | let x64 = x as u64; 53 | 54 | instances[[i, 0]] = gt[[y, x]] as u64; 55 | instances[[i, 1]] = instances[[i, 1]].min(y64); 56 | instances[[i, 2]] = instances[[i, 2]].max(y64+1); 57 | instances[[i, 3]] = instances[[i, 3]].min(x64); 58 | instances[[i, 4]] = instances[[i, 4]].max(x64+1); 59 | } 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /rust/src/utils.rs: -------------------------------------------------------------------------------- 1 | pub struct NdArray { 2 | data: *const T, 3 | pub shape: Vec, 4 | strides: Vec 5 | } 6 | 7 | unsafe impl Send for NdArray {} 8 | unsafe impl Sync for NdArray {} 9 | 10 | impl NdArray where T: Copy { 11 | pub fn new(buffer: &mut Vec, initial: T, shape: Vec) -> NdArray { 12 | let mut strides = vec![1; shape.len()]; 13 | for i in (0..shape.len()-1).rev() { 14 | strides[i] = strides[i+1] * (shape[i+1] as isize); 15 | } 16 | buffer.resize((strides[0] as usize) * (shape[0] as usize), initial); 17 | 
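// strides describe a dense row-major layout (innermost dimension has stride 1); the caller-provided buffer is resized to hold all elements and the struct only keeps a raw pointer into it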
18 | NdArray { 19 | data: buffer.as_ptr(), 20 | shape: shape.clone(), 21 | strides: strides 22 | } 23 | } 24 | 25 | // use this function very carefully as it can be abused to break Rust's safety guarantees! 26 | pub fn from_ptr(ptr: *const T, shape: *const i32, strides_in: *const i32, num_dims: i32) -> NdArray { 27 | let num_dims = num_dims as usize; 28 | let mut strides = Vec::with_capacity(num_dims); 29 | let item_size = std::mem::size_of::() as isize; 30 | 31 | unsafe { 32 | let shape = std::slice::from_raw_parts(shape, num_dims).to_vec(); 33 | for stride in std::slice::from_raw_parts(strides_in, num_dims) { 34 | strides.push((*stride as isize) / item_size); 35 | } 36 | 37 | NdArray { 38 | data: ptr, 39 | shape: shape.clone(), 40 | strides: strides 41 | } 42 | } 43 | } 44 | 45 | // use this function very carefully as it can be abused to break Rust's safety guarantees! 46 | pub fn clone(&self) -> NdArray { 47 | NdArray { 48 | data: self.data, 49 | shape: self.shape.clone(), 50 | strides: self.strides.clone() 51 | } 52 | } 53 | 54 | // use this function very carefully as it can be abused to break Rust's safety guarantees! 55 | pub fn view(&self, dim: usize, i: i32) -> NdArray { 56 | debug_assert!(dim < self.shape.len()); 57 | debug_assert!(i < self.shape[dim]); 58 | 59 | let mut shape = self.shape.clone(); 60 | shape.remove(dim); 61 | let mut strides = self.strides.clone(); 62 | strides.remove(dim); 63 | 64 | unsafe { 65 | NdArray { 66 | data: self.data.offset((i as isize) * self.strides[dim]), 67 | shape: shape, 68 | strides: strides 69 | } 70 | } 71 | } 72 | 73 | // use this function very carefully as it can be abused to break Rust's safety guarantees! 74 | pub fn view_range(&self, dim: usize, low: i32, high: i32) -> NdArray { 75 | debug_assert!(dim < self.shape.len()); 76 | debug_assert!(low < high); 77 | debug_assert!(0 <= low); 78 | debug_assert!(high <= self.shape[dim]); 79 | 80 | let mut shape = self.shape.clone(); 81 | shape[dim] = high - low; 82 | 83 | unsafe { 84 | NdArray { 85 | data: self.data.offset((low as isize) * self.strides[dim]), 86 | shape: shape, 87 | strides: self.strides.clone() 88 | } 89 | } 90 | } 91 | } 92 | 93 | macro_rules! 
generate_index_ops { 94 | ( compute_index: $s: ident, $i: ident, $n: literal -> $result: ident ) => { 95 | let mut $result = 0; 96 | for i in 0..$n { 97 | $result += ($i[i] as isize) * $s.strides[i]; 98 | } 99 | }; 100 | 101 | ( $n: literal, $t: ty ) => { 102 | impl std::ops::Index<[$t; $n]> for NdArray { 103 | type Output = T; 104 | 105 | fn index(&self, i: [$t; $n]) -> &Self::Output { 106 | debug_assert_eq!(self.shape.len(), $n); 107 | generate_index_ops!(compute_index: self, i, $n -> result); 108 | unsafe { 109 | &*self.data.offset(result) 110 | } 111 | } 112 | } 113 | 114 | impl std::ops::IndexMut<[$t; $n]> for NdArray { 115 | fn index_mut(&mut self, i: [$t; $n]) -> &mut Self::Output { 116 | debug_assert_eq!(self.shape.len(), $n); 117 | generate_index_ops!(compute_index: self, i, $n -> result); 118 | unsafe { 119 | &mut *(self.data.offset(result) as *mut T) 120 | } 121 | } 122 | } 123 | }; 124 | 125 | ( $t: ty ) => { 126 | impl std::ops::Index<$t> for NdArray { 127 | type Output = T; 128 | 129 | fn index(&self, i: $t) -> &Self::Output { 130 | debug_assert_eq!(self.shape.len(), 1); 131 | unsafe { 132 | &*self.data.offset((i as isize) * self.strides[0]) 133 | } 134 | } 135 | } 136 | 137 | impl std::ops::IndexMut<$t> for NdArray { 138 | fn index_mut(&mut self, i: $t) -> &mut Self::Output { 139 | debug_assert_eq!(self.shape.len(), 1); 140 | unsafe { 141 | &mut *(self.data.offset((i as isize) * self.strides[0]) as *mut T) 142 | } 143 | } 144 | } 145 | 146 | generate_index_ops!(1, $t); 147 | generate_index_ops!(2, $t); 148 | generate_index_ops!(3, $t); 149 | generate_index_ops!(4, $t); 150 | generate_index_ops!(5, $t); 151 | generate_index_ops!(6, $t); 152 | generate_index_ops!(7, $t); 153 | generate_index_ops!(8, $t); 154 | generate_index_ops!(9, $t); 155 | } 156 | } 157 | 158 | generate_index_ops!(i8); 159 | generate_index_ops!(i16); 160 | generate_index_ops!(i32); 161 | generate_index_ops!(i64); 162 | generate_index_ops!(isize); 163 | generate_index_ops!(u8); 164 | generate_index_ops!(u16); 165 | generate_index_ops!(u32); 166 | generate_index_ops!(u64); 167 | generate_index_ops!(usize); 168 | 169 | macro_rules! 
ndarray { 170 | ( $name: ident -> $initial: expr; $($shape: expr),+ ) => { 171 | let mut $name = vec![$initial]; 172 | let mut $name = NdArray::new(&mut $name, $initial, vec![ $($shape),+ ] ); 173 | }; 174 | } 175 | -------------------------------------------------------------------------------- /samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gritzner/SegForestNet/b7b88ab86677b1b0f2455ff3d41975bc79212831/samples.png -------------------------------------------------------------------------------- /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .array import array 2 | from .archive import archive 3 | from .remotesync import remotesync 4 | from .monitor import monitor 5 | from .slurm import slurm 6 | from .semanticsegmentation import SemanticSegmentation 7 | from .evalmodel import EvalSemanticSegmentationModel 8 | -------------------------------------------------------------------------------- /tasks/archive.py: -------------------------------------------------------------------------------- 1 | import core 2 | from datetime import datetime 3 | from tempfile import TemporaryDirectory 4 | import subprocess 5 | import sys 6 | import glob 7 | import os 8 | 9 | 10 | def add_path_to_archive(archive_fn, path): 11 | for fn in glob.iglob(f"{path}*"): 12 | if fn in ("git_log.txt", "hpo", "tmp", "rust/Cargo.lock", "rust/src/exports.rs", "rust/target", archive_fn): 13 | continue 14 | if os.path.isfile(fn): 15 | core.call(f"tar -rf {archive_fn} {fn}") 16 | elif fn.split("/")[-1] != "__pycache__": 17 | add_path_to_archive(archive_fn, f"{fn}/") 18 | 19 | 20 | def archive(): 21 | dt = datetime.now() 22 | archive_fn = f"Aethon-{dt.year}_{dt.month:02}_{dt.day:02}-{dt.hour:02}_{dt.minute:02}_{dt.second:02}.tar" 23 | 24 | with TemporaryDirectory() as temp_dir: 25 | core.call(f"git log -{core.user.git_log_num_lines} > {temp_dir}/git_log.txt") 26 | core.call(f"tar -cf {archive_fn} -C {temp_dir} git_log.txt") 27 | core.call(f"mkdir -p {temp_dir}/.cargo") 28 | with open(".cargo/config.toml", "r") as f_in: 29 | with open(f"{temp_dir}/.cargo/config.toml", "w") as f_out: 30 | f_out.write(f'''[source.crates-io] 31 | replace-with = "vendored-sources" 32 | 33 | [source.vendored-sources] 34 | directory = "vendor" 35 | 36 | {f_in.read()} 37 | ''') 38 | core.call(f"tar -rf {archive_fn} -C {temp_dir} .cargo/config.toml") 39 | ret_val = subprocess.call(f"cargo vendor --manifest-path {core.base_path}/rust/Cargo.toml", shell=True, cwd=temp_dir) 40 | if ret_val != 0: 41 | sys.exit(ret_val) 42 | for fn in glob.iglob(f"{temp_dir}/vendor/**", recursive=True): 43 | core.call(f"tar -rf {archive_fn} -C {temp_dir} {fn[len(temp_dir)+1:]}") 44 | 45 | add_path_to_archive(archive_fn, "") 46 | core.call(f"bzip2 -f {archive_fn}") 47 | -------------------------------------------------------------------------------- /tasks/array.py: -------------------------------------------------------------------------------- 1 | import core 2 | import subprocess 3 | import sys 4 | from concurrent.futures import ThreadPoolExecutor 5 | import threading 6 | import signal 7 | import time 8 | import argparse 9 | import os 10 | import json 11 | import types 12 | from itertools import repeat 13 | 14 | 15 | def run_task(args, i): 16 | process = subprocess.run( 17 | (sys.executable, f"cfgs/{args.configuration}.py"), input=str(i[1]), text=True, capture_output=True 18 | ) 19 | if process.returncode != 0: 20 | 
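# the configuration generator failed: report its stderr and skip launching aethon for this index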
print(f"[error@{i}] {process.stderr.strip()}") 21 | return process, None 22 | s = process.stdout.strip() 23 | print(f"[{i}] {s}") 24 | if not args.print_only: 25 | if len(args.log) > 0: 26 | return process, subprocess.run(("python", "aethon.py", *s.split(" ")), text=True, capture_output=True) 27 | else: 28 | return process, types.SimpleNamespace(returncode=core.call(f"python aethon.py {s}"), stdout="", stderr="") 29 | return process, None 30 | 31 | 32 | def proxy_func(args, log, lock, i, j): 33 | i = i, j 34 | process0, process1 = run_task(args, i) 35 | returncode = 0 36 | if process0.returncode != 0: 37 | returncode = process0.returncode 38 | elif (not process1 is None) and process1.returncode != 0: 39 | returncode = process1.returncode 40 | if len(args.log) == 0: 41 | return returncode 42 | 43 | results = {} 44 | for j, process in enumerate((process0, process1)): 45 | if process is None: 46 | continue 47 | results[f"process{j}"] = { 48 | "returncode": process.returncode, 49 | "stdout": process.stdout.strip(), 50 | "stderr": process.stderr.strip() 51 | } 52 | 53 | lock.acquire() 54 | log[str(i)] = results 55 | with open(args.log, "w") as f: 56 | json.dump(log, f) 57 | lock.release() 58 | 59 | return returncode 60 | 61 | 62 | def monitor(config, shutdown): 63 | process = subprocess.Popen((sys.executable, "aethon.py", "@monitor", config)) 64 | while not shutdown.wait(5): 65 | pass 66 | process.send_signal(signal.SIGINT) 67 | process.wait() 68 | 69 | 70 | def array(): 71 | parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0])+" @array", description="Task Array") 72 | parser.add_argument("configuration", type=str, help="configuration file") 73 | parser.add_argument("array_range", type=str, help="array range, e.g., 0-9 (end inclusive)") 74 | parser.add_argument("--threads", default=0, type=int, help="set the number of worker threads to use") 75 | parser.add_argument("--print-only", action="store_true", help="only print resulting Aethon parameters but do not execute") 76 | parser.add_argument("--monitor", default="", type=str, help="start SLURM monitor thread using the supplied configuration") 77 | parser.add_argument("--log", default="", type=str, help="capture results of tasks and log them to a JSON file") 78 | args = parser.parse_args([f"--{arg[1:]}" if arg[0]=="@" else arg for arg in core.args.parameters]) 79 | 80 | if len(args.monitor) > 0: 81 | shutdown = threading.Event() 82 | monitor_thread = threading.Thread(target=monitor, args=(args.monitor, shutdown)) 83 | monitor_thread.start() 84 | print("waiting for SLURM monitoring to start ...") 85 | time.sleep(15) 86 | 87 | if "," in args.array_range: 88 | temp = args.array_range.split(",") 89 | else: 90 | temp = [args.array_range] 91 | args.array_range = [] 92 | for array_range in temp: 93 | if "x" in array_range: 94 | n, array_range = array_range.split("x") 95 | args.array_range.extend([array_range] * int(n)) 96 | else: 97 | args.array_range.append(array_range) 98 | del temp 99 | for i, array_range in enumerate(args.array_range): 100 | if "-" in array_range: 101 | array_range = [int(x) for x in array_range.split("-")] 102 | assert len(array_range) == 2 103 | else: 104 | array_range = [int(array_range)] * 2 105 | array_range[1] += 1 106 | args.array_range[i] = range(*array_range) 107 | 108 | tasks = [] 109 | log = {} 110 | returncode = 0 111 | try: 112 | if args.threads > 0: 113 | lock = threading.Lock() 114 | with ThreadPoolExecutor(max_workers=args.threads) as thread_pool: 115 | tasks = [thread_pool.map(lambda i, j: 
proxy_func(args, log, lock, i, j), repeat(i), array_range) for i, array_range in enumerate(args.array_range)] 116 | else: 117 | lock = types.SimpleNamespace(acquire=lambda : None, release=lambda : None) 118 | tasks = [[proxy_func(args, log, lock, i, j) for j in array_range] for i, array_range in enumerate(args.array_range)] 119 | for task_set in tasks: 120 | for i in task_set: 121 | if i != 0: 122 | returncode = i 123 | except KeyboardInterrupt: 124 | returncode = 1 125 | 126 | if len(args.monitor) > 0: 127 | shutdown.set() 128 | monitor_thread.join() 129 | 130 | sys.exit(returncode) 131 | -------------------------------------------------------------------------------- /tasks/evalmodel.py: -------------------------------------------------------------------------------- 1 | import core 2 | import datasets 3 | import models 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import utils 8 | import time 9 | import cv2 as cv 10 | 11 | 12 | #TODO: may be broken and in need of fixing 13 | class EvalSemanticSegmentationModel(utils.ModelFitter): 14 | def __init__(self, config, params): 15 | super().__init__(config) 16 | 17 | self.dataset = core.create_object(datasets, config.dataset) 18 | print() # add a new line for nicer output formatting 19 | 20 | self.num_mini_batches = 1 21 | self.config.epochs = 1 22 | if config.ignore_dataset_class_weights: 23 | self.dataset.class_weights[:] = 1 24 | self.dataset.class_weights[self.dataset.ignore_class] = 0 25 | 26 | self.conf_mat = utils.ConfusionMatrix(self.dataset.num_classes, self.dataset.ignore_class) 27 | self.conf_mat.reset() 28 | self.output_set.update(("val_loss", "val_acc", "val_miou", "test_loss", "test_acc", "test_miou")) 29 | 30 | lut = np.flip(np.asarray(self.dataset.lut, dtype=np.uint8), axis=1) 31 | core.call(f"mkdir -p {core.output_path}/val_images") 32 | for image_id in np.unique(self.dataset.validation.x.patch_info[:,0]): 33 | img = self.dataset.base.images[image_id].base[:,:,self.dataset.validation.x_vis.channels] 34 | if img.dtype != np.uint8: 35 | img = np.asarray(img, dtype=np.uint8) 36 | cv.imwrite(f"{core.output_path}/val_images/{image_id}_input.png", np.flip(img, axis=2), (cv.IMWRITE_PNG_COMPRESSION, 9)) 37 | cv.imwrite(f"{core.output_path}/val_images/{image_id}_gt.png", lut[self.dataset.base.images[image_id].gt], (cv.IMWRITE_PNG_COMPRESSION, 9)) 38 | 39 | def pre_evaluate(self, epoch): 40 | self.model_prepare_func(*self.eval_params.config.model_epochs, True) 41 | self.num_mini_batches = int(np.ceil(self.eval_params.dataset.x.shape[0] / self.eval_params.config.mini_batch_size)) 42 | self.eval_params.loss = 0 43 | self.eval_params.predictions = { 44 | image_id: [ 45 | np.zeros((*self.dataset.base.images[image_id].base.shape[:2], self.dataset.num_classes), dtype=np.float64), 46 | np.zeros((*self.dataset.base.images[image_id].base.shape[:2], 1), dtype=np.uint64) 47 | ] for image_id in np.unique(self.eval_params.dataset.x.patch_info[:,0]) 48 | } 49 | self.eval_params.time = time.perf_counter() 50 | 51 | def pre_train(self, epoch, batch, batch_data): 52 | if not self.eval_params.enabled: 53 | return 54 | indices = slice(batch * self.eval_params.config.mini_batch_size, (batch+1) * self.eval_params.config.mini_batch_size) 55 | batch_data.x = self.eval_params.dataset.x_gt[indices] if self.eval_params.config.autoencoder else self.eval_params.dataset.x[indices] 56 | batch_data.yt = self.eval_params.dataset.y[indices] 57 | batch_data.index_map = self.eval_params.dataset.index_map[indices] 58 | 59 | def train(self, epoch, 
batch, batch_data, metrics): 60 | if not self.eval_params.enabled: 61 | return 62 | with torch.no_grad(): 63 | x = torch.from_numpy(batch_data.x).float().to(core.device) 64 | yt = torch.from_numpy(batch_data.yt).long().to(core.device).detach() 65 | yp = self.model(x) 66 | loss = self.loss_func(yp, yt, weight=self.dataset.class_weights, ignore_index=self.dataset.ignore_class) 67 | self.eval_params.loss += loss.item() 68 | batch_data.yp = yp.softmax(1).cpu().numpy() 69 | 70 | def post_train(self, epoch, batch, batch_data, metrics): 71 | if not self.eval_params.enabled: 72 | return 73 | for yp, index_map in zip(batch_data.yp, batch_data.index_map): 74 | p = self.eval_params.predictions[index_map[0,0,0]] 75 | p[0][index_map[1],index_map[2]] += np.moveaxis(yp, 0, 2) 76 | p[1][index_map[1],index_map[2]] += 1 77 | 78 | def post_evaluate(self, epoch): 79 | self.eval_params.time = time.perf_counter() - self.eval_params.time 80 | self.eval_params.metrics[f"{self.eval_params.prefix}time"] = self.eval_params.time 81 | self.eval_params.metrics[f"{self.eval_params.prefix}time_per_sample"] = self.eval_params.time / self.eval_params.dataset.x.shape[0] 82 | self.eval_params.metrics[f"{self.eval_params.prefix}loss"] = self.eval_params.loss / self.num_mini_batches 83 | 84 | lut = np.flip(np.asarray(self.dataset.lut, dtype=np.uint8), axis=1) 85 | self.conf_mat.reset() 86 | for image_id, p in self.eval_params.predictions.items(): 87 | yt = self.dataset.base.images[image_id].gt 88 | yt = yt if isinstance(yt,np.ndarray) else yt.get_semantic_image() 89 | yt = yt if yt.dtype==np.int32 else np.asarray(yt, dtype=np.int32) 90 | yp = np.argmax(p[0] / p[1], axis=2) 91 | self.conf_mat.add( 92 | np.expand_dims(yt, axis=0), 93 | np.expand_dims(yp, axis=0) 94 | ) 95 | if self.eval_params.save_predicitions: 96 | cv.imwrite( 97 | f"{core.output_path}/val_images/{self.eval_params.config.key}/{image_id}_prediction.png", 98 | lut[yp], (cv.IMWRITE_PNG_COMPRESSION, 9) 99 | ) 100 | 101 | for key, value in self.conf_mat.compute_metrics().__dict__.items(): 102 | self.eval_params.metrics[f"{self.eval_params.prefix}{key}"] = value 103 | self.eval_params.metrics[f"{self.eval_params.prefix}conf_mat"] = self.conf_mat.to_dict() 104 | 105 | def finalize(self): 106 | for config in self.config.models: 107 | print(f"evaluating configuration '{config.key}' ...") 108 | 109 | self.dataset.replace_normalization_params(np.load(f"{core.output_path}/normalization_params.npy")) 110 | if hasattr(config, "normalization_parameters"): 111 | norm_params = np.load(config.normalization_parameters.path) 112 | if config.normalization_parameters.std_only: 113 | norm_params = norm_params[:,1] 114 | self.dataset.replace_normalization_params(norm_params) 115 | 116 | shape = list(self.dataset.training.x.shape[-3:]) 117 | if config.autoencoder: 118 | shape[0] = 1 119 | self.model = core.create_object(models.segmentation, config.model, input_shape=shape, num_classes=self.dataset.num_classes) 120 | with core.open(config.model_weights, "rb") as f: 121 | self.model.load_state_dict(torch.load(f, map_location=core.device)) 122 | self.model.eval() 123 | 124 | if config.ignore_model_loss or not hasattr(self.model, "get_loss_function"): 125 | self.loss_func = nn.functional.cross_entropy 126 | else: 127 | self.loss_func = self.model.get_loss_function() 128 | self.model_prepare_func = getattr(self.model, "prepare_for_epoch", lambda x,y,z: None) 129 | 130 | metrics = {} 131 | core.call(f"mkdir -p {core.output_path}/val_images/{config.key}") 132 | self.evaluate( 133 | 0, 
metrics=metrics, config=config, 134 | dataset=self.dataset.validation, prefix="val_", save_predicitions=True 135 | ) 136 | self.evaluate( 137 | 0, metrics=metrics, config=config, 138 | dataset=self.dataset.test, prefix="test_", save_predicitions=False 139 | ) 140 | self.history[config.key] = metrics 141 | -------------------------------------------------------------------------------- /tasks/monitor.py: -------------------------------------------------------------------------------- 1 | import core 2 | import yaml 3 | import threading 4 | import signal 5 | import functools 6 | import types 7 | import socket 8 | import subprocess 9 | 10 | 11 | def monitor(): 12 | assert len(core.args.parameters) > 0 13 | with open(f"{core.base_path}/cfgs/{core.args.parameters[0]}.yaml", "r") as f: 14 | config = core.parse_dict(yaml.safe_load(f.read())) 15 | 16 | def monitor(state): 17 | prefix = ("ssh", config.servers.login) if hasattr(config, "servers") else () 18 | while True: 19 | try: 20 | process = subprocess.run( 21 | (*prefix, "squeue", "-u", config.slurm.monitor.user, "-o", "%i"), 22 | timeout=config.slurm.monitor.remote_timeout, text=True, capture_output=True 23 | ) 24 | except subprocess.TimeoutExpired: 25 | continue 26 | if process.returncode == 0: 27 | jobs = set() 28 | ignored_jobs = set() 29 | for jobid in process.stdout.split("\n")[1:-1]: 30 | if "_" in jobid: 31 | ignored_jobs.add(jobid) 32 | else: 33 | jobs.add(int(jobid)) 34 | if config.slurm.monitor.verbose: 35 | print("jobs in queue:", jobs) 36 | print("ignored jobs:", ignored_jobs) 37 | with state.lock: 38 | state.jobs.clear() 39 | state.jobs.update(jobs) 40 | if state.shutdown.wait(config.slurm.monitor.timer): 41 | return 42 | 43 | state = types.SimpleNamespace( 44 | shutdown = threading.Event(), 45 | lock = threading.Lock(), 46 | jobs = set() 47 | ) 48 | def signal_handler(sig, frame, state): 49 | print("shutting down ...") 50 | state.shutdown.set() 51 | signal.signal(signal.SIGINT, functools.partial(signal_handler, state=state)) 52 | 53 | monitor_thread = threading.Thread(target=monitor, args=(state,)) 54 | monitor_thread.start() 55 | 56 | server = socket.create_server(("localhost", config.slurm.monitor.port)) 57 | server.settimeout(config.slurm.monitor.timeout) 58 | 59 | print(f"server monitoring '{core.args.parameters[0]}' is running on port {config.slurm.monitor.port} ...") 60 | while not state.shutdown.is_set(): 61 | try: 62 | conn, _addr = server.accept() 63 | data = conn.recv(256).decode() 64 | tokens = data.split(":") 65 | if len(tokens) != 2: 66 | print("ignoring invalid message:", data) 67 | else: 68 | msg, jobid = tokens[0].lower(), int(tokens[1]) 69 | if msg == "register": 70 | with state.lock: 71 | state.jobs.add(jobid) 72 | elif msg == "query": 73 | with state.lock: 74 | done = not jobid in state.jobs 75 | conn.send(str(done).encode()) 76 | else: 77 | print("ignoring unknown command:", data) 78 | except TimeoutError: 79 | pass 80 | 81 | server.close() 82 | monitor_thread.join() 83 | -------------------------------------------------------------------------------- /tasks/remotesync.py: -------------------------------------------------------------------------------- 1 | import core 2 | import yaml 3 | import sys 4 | 5 | 6 | def remotesync(): 7 | assert len(core.args.parameters) > 1 8 | config, dataset = core.args.parameters[:2] 9 | with open(f"{core.base_path}/cfgs/{config}.yaml", "r") as f: 10 | config = core.parse_dict(yaml.safe_load(f.read())) 11 | target_path = f"{config.primary_path}/{dataset}" 12 | if 
len(core.args.parameters) == 2 or not core.str2bool(core.args.parameters[2]): 13 | dataset_paths = core.user.dataset_paths[dataset] 14 | returncode = core.call(f"ssh {config.servers.login} mkdir -p {target_path}") 15 | if returncode != 0: 16 | sys.exit(returncode) 17 | if type(dataset_paths) == str: 18 | dataset_paths = (dataset_paths,) 19 | for i, dataset_path in enumerate(dataset_paths): 20 | returncode = core.call(f"rsync -avz {dataset_path} {config.servers.transfer}:{target_path}/{i}") 21 | if returncode != 0: 22 | sys.exit(returncode) 23 | else: 24 | sys.exit(core.call(f"ssh {config.servers.transfer} rsync -avz {target_path} {config.secondary_path}")) 25 | -------------------------------------------------------------------------------- /tasks/slurm.py: -------------------------------------------------------------------------------- 1 | import core 2 | import yaml 3 | import subprocess 4 | import re 5 | import sys 6 | import socket 7 | import time 8 | 9 | 10 | def slurm(): 11 | assert len(core.args.parameters) > 6 12 | config, jobname, jobtime, mem, code, output_path = core.args.parameters[:6] 13 | 14 | with open(f"{core.base_path}/cfgs/{config}.yaml", "r") as f: 15 | config = core.parse_dict(yaml.safe_load(f.read())) 16 | prefix = ("ssh", config.servers.login) if hasattr(config, "servers") else () 17 | config.slurm.options["job-name"] = jobname 18 | config.slurm.options["time"] = jobtime 19 | config.slurm.options["mem"] = mem 20 | config.slurm.options["chdir"] = config.slurm.jobs_path 21 | slurm_options = "\n".join([f"#SBATCH --{str(k)}={str(v)}" for k, v in config.slurm.options.items()]) 22 | script = config.slurm.script.format( 23 | slurm_options = slurm_options, 24 | jobs_path = config.slurm.jobs_path, 25 | code = code, 26 | parameters = " ".join([str(p) for p in core.args.parameters[6:]]) 27 | ) 28 | 29 | process = subprocess.run( 30 | (*prefix, "sbatch"), input=script, text=True, capture_output=True 31 | ) 32 | print("[stdout]", process.stdout.strip()) 33 | print("[stderr]", process.stderr.strip()) 34 | print() 35 | if process.returncode != 0: 36 | print("sbatch failed") 37 | sys.exit(process.returncode) 38 | 39 | m = re.search(r"Submitted batch job (?P<job_id>\d+)", process.stdout) 40 | if m is None: 41 | print("job ID not found") 42 | sys.exit(1) 43 | jobid = int(m["job_id"]) 44 | print("job ID:", jobid) 45 | 46 | with socket.create_connection(("localhost", config.slurm.monitor.port)) as client: 47 | client.send(f"register:{jobid}".encode()) 48 | while True: 49 | time.sleep(config.slurm.monitor.timer) 50 | with socket.create_connection(("localhost", config.slurm.monitor.port)) as client: 51 | client.send(f"query:{jobid}".encode()) 52 | if core.str2bool(client.recv(256).decode()): 53 | break 54 | 55 | log_path = f"{config.slurm.jobs_path}/slurm-{jobid}.out" 56 | output_path = f"{output_path}/{jobid}" 57 | core.call(f"mkdir -p {output_path}") 58 | if hasattr(config, "servers"): 59 | core.call(f"ssh {config.servers.login} mv {log_path} {config.slurm.jobs_path}/{jobid}/tmp") 60 | core.call(f"rsync -aqz {config.servers.transfer}:{config.slurm.jobs_path}/{jobid}/tmp/ {output_path}") 61 | core.call(f"ssh {config.servers.login} rm -rf {config.slurm.jobs_path}/{jobid}") 62 | else: 63 | core.call(f"mv {log_path} {config.slurm.jobs_path}/{jobid}/tmp") 64 | core.call(f"rsync -aqz {config.slurm.jobs_path}/{jobid}/tmp/ {output_path}") 65 | core.call(f"rm -rf {config.slurm.jobs_path}/{jobid}") 66 | -------------------------------------------------------------------------------- /user.yaml:
-------------------------------------------------------------------------------- 1 | git_log_num_lines: 12 2 | dataset_paths:map: 3 | - [hannover, /path/to/datasets/LGN/Hannover] 4 | - [buxtehude, /path/to/datasets/LGN/Buxtehude] 5 | - [nienburg, /path/to/datasets/LGN/Nienburg] 6 | - [toulouse, /path/to/datasets/SemCityToulouse/tmp] 7 | - [toulouse_full, /path/to/datasets/SemCityToulouse/tmp] 8 | - [toulouse_pan, /path/to/datasets/SemCityToulouse/tmp] 9 | - [vaihingen, /path/to/datasets/ISPRS/Vaihingen/tmp] 10 | - [potsdam, /path/to/datasets/ISPRS/Potsdam/tmp] 11 | - [isaid, [/path/to/datasets/DOTA/v10, /path/to/datasets/iSAID]] 12 | - [dlr_landcover, /path/to/datasets/DLR-Datasets/landcover/tmp] 13 | - [dlr_roadmaps, /path/to/datasets/DLR-Datasets/roadmaps/tmp] 14 | - [dlr_roadsegmentation, /path/to/datasets/DLR-Datasets/roadsegmentation/tmp] 15 | - [hameln, /path/to/datasets/IPI/Hameln] 16 | - [schleswig, /path/to/datasets/IPI/Schleswig] 17 | - [mecklenburg_vorpommern, /path/to/datasets/IPI/MV] 18 | - [hameln_DA, /path/to/datasets/IPI/Schleswig-Hameln/preprocessed/Hameln20cm] 19 | - [schleswig_DA, /path/to/datasets/IPI/Schleswig-Hameln/preprocessed/Schleswig20cm] 20 | - [synthinel, /path/to/datasets/Synthinel-1/tmp] 21 | - [synthinel_redroof, /path/to/datasets/Synthinel-1/tmp] 22 | - [synthinel_paris, /path/to/datasets/Synthinel-1/tmp] 23 | - [synthinel_ancient, /path/to/datasets/Synthinel-1/tmp] 24 | - [synthinel_scifi, /path/to/datasets/Synthinel-1/tmp] 25 | - [synthinel_palace, /path/to/datasets/Synthinel-1/tmp] 26 | - [synthinel_austin, /path/to/datasets/Synthinel-1/tmp] 27 | - [synthinel_venice, /path/to/datasets/Synthinel-1/tmp] 28 | - [synthinel_modern, /path/to/datasets/Synthinel-1/tmp] 29 | model_weights_paths:map: 30 | - [Xception, /path/to/downloaded/model/weights/xception.pt.gz] 31 | 32 | 33 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelfitter import ModelFitter 2 | from .lrscheduler import LearningRateScheduler 3 | from .confusionmatrix import ConfusionMatrix 4 | from .vectorquantization import create_vector_quantization_layer 5 | from .sam import SAM 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import types 10 | 11 | 12 | def hsv2rgb(h, s, v): 13 | def f(n): 14 | k = (n + 6*h) % 6 15 | return v - v*s*max(0,min(k,4-k,1)) 16 | return round(255*f(5)), round(255*f(3)), round(255*f(1)) 17 | 18 | 19 | def relu_wrapper(relu_type): 20 | if type(relu_type) == list: 21 | relu_params = relu_type[1:] 22 | relu_type = relu_type[0] 23 | else: 24 | relu_params = [] 25 | relu_func = getattr(nn, relu_type) 26 | if relu_type == "PReLU" and len(relu_params) > 0 and relu_params[0] != 1: 27 | return lambda x: relu_func(x) 28 | else: 29 | return lambda _: relu_func(*relu_params) 30 | 31 | 32 | def norm_wrapper(norm_type): 33 | if norm_type == "LayerNormNd": 34 | return lambda x: nn.GroupNorm(1, x) 35 | elif norm_type == "PixelNorm2d": 36 | class PixelNorm2d(nn.Module): 37 | def __init__(self, k): 38 | super().__init__() 39 | self.alpha = nn.Parameter(torch.ones((k, 1, 1))) 40 | self.beta = nn.Parameter(torch.zeros((k, 1, 1))) 41 | 42 | def forward(self, x): 43 | d = x - x.mean(1, keepdim=True) 44 | s = d.pow(2).mean(1, keepdim=True) 45 | x = d / torch.sqrt(s + 10**-6) 46 | return self.alpha * x + self.beta 47 | 48 | return PixelNorm2d 49 | return getattr(nn, norm_type) 50 | 51 | def optim_wrapper(optim_type): 52 | if 
optim_type == "SAM": 53 | return SAM 54 | else: 55 | return getattr(torch.optim, optim_type) 56 | 57 | def get_mini_batch_iterator(mini_batch_size, return_indices=False): 58 | class MiniBatchIterator(): 59 | def __init__(self): 60 | self.mini_batch_size = mini_batch_size 61 | self.return_indices = return_indices 62 | 63 | def __call__(self, data, *data2): 64 | if len(data2) > 0: 65 | for i in data2: 66 | assert data.shape[0] == i.shape[0] 67 | with torch.no_grad(): 68 | for j in range(0, data.shape[0], self.mini_batch_size): 69 | batch_data = data[j:j+self.mini_batch_size] 70 | if len(data2) == 0 and not self.return_indices: 71 | yield batch_data 72 | else: 73 | if self.return_indices: 74 | batch_data = [np.asarray(range(j,j+self.mini_batch_size))[:batch_data.shape[0]], batch_data] 75 | else: 76 | batch_data = [batch_data] 77 | batch_data.extend([ 78 | i[j:j+self.mini_batch_size] for i in data2 79 | ]) 80 | yield tuple(batch_data) 81 | 82 | return MiniBatchIterator() 83 | 84 | def get_scheduler(config): 85 | if isinstance(config, dict): 86 | config = types.SimpleNamespace(**config) 87 | 88 | class Scheduler(): 89 | def __init__(self): 90 | params = ",".join(config.parameters) 91 | self.func = eval(f"lambda {params}: {config.func}") 92 | self.value_range = config.value_range 93 | self.value = self.value_range[0] 94 | 95 | def __call__(self, *params): 96 | self.value = self.func(*params) 97 | self.value = (1-self.value)*self.value_range[0] + self.value*self.value_range[1] 98 | 99 | return Scheduler() 100 | -------------------------------------------------------------------------------- /utils/confusionmatrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import types 3 | import rust 4 | 5 | 6 | class ConfusionMatrix(): 7 | def __init__(self, num_classes, ignore_class): 8 | assert num_classes > 1 9 | self.C = np.empty([num_classes, num_classes], dtype=np.uint64) 10 | self.ignore_class = ignore_class 11 | 12 | def reset(self): 13 | self.C[:] = 0 14 | 15 | def add(self, yt, yp): 16 | rust.add_to_conf_mat(self.C, yt, yp) 17 | 18 | def compute_metrics(self): 19 | result = {} 20 | 21 | for c in range(self.C.shape[0]): 22 | tp = self.C[c,c] 23 | fp = np.sum(self.C[:,c]) - tp 24 | if self.ignore_class >= 0 and c != self.ignore_class: 25 | fp -= self.C[self.ignore_class,c] 26 | fn = np.sum(self.C[c,:]) - tp 27 | 28 | result[f"iou{c}"] = float(tp / max(tp+fp+fn,1)) 29 | result[f"p{c}"] = float(tp / max(tp+fp,1)) 30 | result[f"r{c}"] = float(tp / max(tp+fn,1)) 31 | result[f"f1_{c}"] = float((2*tp) / max(2*tp+fp+fn,1)) 32 | 33 | result["acc"] = float(np.sum(np.diag(self.C)) / max(np.sum(self.C),1)) 34 | result["miou"] = float(np.mean([result[f"iou{c}"] for c in range(self.C.shape[0]) if c != self.ignore_class])) 35 | result["mp"] = float(np.mean([result[f"p{c}"] for c in range(self.C.shape[0]) if c != self.ignore_class])) 36 | result["mr"] = float(np.mean([result[f"r{c}"] for c in range(self.C.shape[0]) if c != self.ignore_class])) 37 | result["mf1"] = float(np.mean([result[f"f1_{c}"] for c in range(self.C.shape[0]) if c != self.ignore_class])) 38 | 39 | return types.SimpleNamespace(**result) 40 | 41 | def to_dict(self): 42 | return {c: [int(self.C[c,c2]) for c2 in range(self.C.shape[1])] for c in range(self.C.shape[0])} 43 | 44 | @staticmethod 45 | def from_dict(C, ignore_class): 46 | conf_mat = ConfusionMatrix(len(C), ignore_class) 47 | for c, cs in C.items(): 48 | c = int(c) 49 | assert 0 <= c < len(C) 50 | assert type(cs) == list 51 | 
assert len(cs) == len(C) 52 | for c2, n in enumerate(cs): 53 | assert type(n) == int 54 | assert 0 <= n 55 | conf_mat.C[c,c2] = n 56 | return conf_mat 57 | -------------------------------------------------------------------------------- /utils/lrscheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class LearningRateScheduler(): 6 | def __init__(self, optimizer, min_learning_rate, num_cycles, cycle_length_factor, num_iterations): 7 | self.index = 0 8 | self.optimizer = optimizer 9 | self.min_learning_rate = min_learning_rate 10 | 11 | self.cycles = LearningRateScheduler.get_cycle_breakpoints( 12 | num_cycles, cycle_length_factor, num_iterations 13 | ) 14 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 15 | self.optimizer, self.cycles[0], eta_min=self.min_learning_rate 16 | ) 17 | 18 | @staticmethod 19 | def get_cycle_breakpoints(num_cycles, cycle_length_factor, num_iterations): 20 | b = num_iterations / np.sum([cycle_length_factor**i for i in range(num_cycles)]) 21 | cycles = [round(b*cycle_length_factor**i) for i in range(num_cycles)] 22 | cycles = [np.sum(cycles[:i+1]) for i in range(num_cycles)] 23 | cycles[-1] = num_iterations 24 | return np.asarray(cycles) - 1 25 | 26 | def get_lr(self): 27 | return self.scheduler.get_last_lr()[0] 28 | 29 | def step(self, step_num): 30 | if step_num < self.cycles[self.index]: 31 | self.scheduler.step() 32 | elif self.index+1 < len(self.cycles): 33 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 34 | self.optimizer, (self.cycles[self.index+1] - self.cycles[self.index]) - 1, eta_min=self.min_learning_rate 35 | ) 36 | self.index += 1 37 | -------------------------------------------------------------------------------- /utils/modelfitter.py: -------------------------------------------------------------------------------- 1 | import core 2 | import os 3 | import time 4 | import datetime 5 | import numpy as np 6 | import json 7 | import concurrent.futures 8 | import queue 9 | import numbers 10 | import traceback 11 | import types 12 | import torch 13 | 14 | 15 | class Queue(queue.Queue): 16 | def put(self, item): 17 | try: 18 | super().put(item, timeout=2) 19 | return False 20 | except queue.Full: 21 | return True 22 | 23 | def get(self): 24 | try: 25 | return super().get(timeout=2) 26 | except queue.Empty: 27 | return None 28 | 29 | 30 | class ModelFitter(): 31 | def __init__(self, config): 32 | self.config = config 33 | self.output_float_precision = getattr(config, "output_float_precision", 4) 34 | self.output_set = set() 35 | self.history_filename = getattr(config, "history_filename", "history.json.bz2") 36 | self.history = {} 37 | if core.device.type == "cuda": 38 | cuda_total_mem = torch.cuda.mem_get_info(core.device)[1] 39 | self.history["cuda_total_mem"] = cuda_total_mem, round(10 * cuda_total_mem/1024**2) / 10 40 | max_queue_size = getattr(config, "max_queue_size", 4) 41 | self.max_queue_size = (max_queue_size - 1) if max_queue_size > 1 else 1 42 | 43 | def pre_epoch(self, epoch): 44 | pass 45 | 46 | def pre_evaluate(self, epoch): 47 | pass 48 | 49 | def pre_train(self, epoch, batch, batch_data): 50 | pass 51 | 52 | def train(self, epoch, batch, batch_data, metrics): 53 | pass 54 | 55 | def post_train(self, epoch, batch, batch_data, metrics): 56 | pass 57 | 58 | def post_evaluate(self, epoch): 59 | pass 60 | 61 | def post_epoch(self, epoch, metrics): 62 | pass 63 | 64 | def finalize(self): 65 | pass 66 | 67 | def run(self): 68 | 
epochs = self.config.terminate_early if getattr(self.config,"terminate_early",-1)>0 else self.config.epochs 69 | for epoch in range(epochs): 70 | epoch_timestamp = time.perf_counter() 71 | print(f"starting epoch {epoch+1} at", datetime.datetime.now().strftime("%Y-%b-%d, %H:%M:%S")) 72 | self.pre_epoch(epoch) 73 | self._fit_epoch(epoch) 74 | metrics = {key: np.mean(self.history[key][-self.num_mini_batches:]) if isinstance(self.history[key][-1], numbers.Number) and not isinstance(self.history[key][-1], bool) and "loss" in key else self.history[key][-1] for key in self._metrics_keys} 75 | metrics = types.SimpleNamespace(**metrics) 76 | self.post_epoch(epoch, metrics) 77 | if core.device.type == "cuda": 78 | metrics.cuda_free_mem = torch.cuda.mem_get_info(core.device)[0] 79 | metrics.cuda_free_mem = metrics.cuda_free_mem, round(10 * metrics.cuda_free_mem/1024**2) / 10 80 | metrics = metrics.__dict__ 81 | for key in self._metrics_keys: 82 | del metrics[key] 83 | for key, value in metrics.items(): 84 | if not key in self.history: 85 | self.history[key] = [] 86 | self.history[key].append(value) 87 | self._save_history_to_file() 88 | epoch_timestamp = time.perf_counter() - epoch_timestamp 89 | self._progress(epoch, -1, epoch_timestamp, metrics) 90 | self.finalize() 91 | self._save_history_to_file() 92 | 93 | def evaluate(self, epoch, **params): 94 | num_train_mini_batches = self.num_mini_batches 95 | self.num_mini_batches = 0 96 | torch.set_grad_enabled(False) 97 | self.eval_params = types.SimpleNamespace(enabled=True, **params) 98 | self.pre_evaluate(epoch) 99 | with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: 100 | try: 101 | self._terminate_workers = False 102 | q_train = Queue(maxsize=self.max_queue_size) 103 | executor.submit(self._get_batches, epoch, q_train) 104 | q_metrics = Queue(maxsize=self.max_queue_size) 105 | executor.submit(self._train, epoch, q_train, q_metrics) 106 | for batch in range(self.num_mini_batches): 107 | data = q_metrics.get() 108 | while not data: 109 | if self._terminate_workers: 110 | raise RuntimeError("terminating main thread") 111 | data = q_metrics.get() 112 | data, metrics = data 113 | self.post_train(epoch, batch, data, metrics) 114 | except: 115 | traceback.print_exc() 116 | self._terminate_workers = True 117 | if self._terminate_workers: 118 | raise RuntimeError("terminated") 119 | self.post_evaluate(epoch) 120 | torch.set_grad_enabled(True) 121 | self.num_mini_batches = num_train_mini_batches 122 | 123 | def _fit_epoch(self, epoch): 124 | self._last_line_length = None 125 | self._metrics_keys = set() 126 | self.eval_params = types.SimpleNamespace(enabled=False) 127 | with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: 128 | try: 129 | self._terminate_workers = False 130 | q_train = Queue(maxsize=self.max_queue_size) 131 | executor.submit(self._get_batches, epoch, q_train) 132 | q_metrics = Queue(maxsize=self.max_queue_size) 133 | executor.submit(self._train, epoch, q_train, q_metrics) 134 | batch_timestamp = time.perf_counter() 135 | for batch in range(self.num_mini_batches): 136 | data = q_metrics.get() 137 | while not data: 138 | if self._terminate_workers: 139 | raise RuntimeError("terminating main thread") 140 | data = q_metrics.get() 141 | data, metrics = data 142 | self.post_train(epoch, batch, data, metrics) 143 | metrics = metrics.__dict__ 144 | for key, value in metrics.items(): 145 | if not key in self._metrics_keys: 146 | self._metrics_keys.add(key) 147 | if not key in self.history: 148 | self.history[key] 
= [] 149 | values = self.history[key] 150 | values.append(value) 151 | if isinstance(value, numbers.Number) and not isinstance(value, bool): 152 | if "loss" in key: 153 | metrics[key] = np.mean(values[-(batch+1):]) 154 | self._progress(epoch, batch, (time.perf_counter() - batch_timestamp) / (batch + 1), metrics) 155 | except: 156 | traceback.print_exc() 157 | self._terminate_workers = True 158 | if self._terminate_workers: 159 | raise RuntimeError("terminated") 160 | print("") 161 | 162 | def _get_batches(self, epoch, q): 163 | if self.eval_params.enabled: 164 | torch.set_grad_enabled(False) 165 | try: 166 | for batch in range(self.num_mini_batches): 167 | data = types.SimpleNamespace() 168 | self.pre_train(epoch, batch, data) 169 | while q.put(data): 170 | if self._terminate_workers: 171 | raise RuntimeError("terminating first worker") 172 | except: 173 | traceback.print_exc() 174 | self._terminate_workers = True 175 | if self.eval_params.enabled: 176 | torch.set_grad_enabled(True) 177 | 178 | def _train(self, epoch, q_in, q_out): 179 | if self.eval_params.enabled: 180 | torch.set_grad_enabled(False) 181 | try: 182 | for batch in range(self.num_mini_batches): 183 | data = q_in.get() 184 | while not data: 185 | if self._terminate_workers: 186 | raise RuntimeError("terminating second worker during get()") 187 | data = q_in.get() 188 | metrics = types.SimpleNamespace() 189 | self.train(epoch, batch, data, metrics) 190 | while q_out.put((data, metrics)): 191 | if self._terminate_workers: 192 | raise RuntimeError("terminating second worker during put()") 193 | except: 194 | traceback.print_exc() 195 | self._terminate_workers = True 196 | if self.eval_params.enabled: 197 | torch.set_grad_enabled(True) 198 | 199 | def _progress(self, epoch, batch, elapsed_time, metrics): 200 | epochs = self.config.terminate_early if getattr(self.config,"terminate_early",-1)>0 else self.config.epochs 201 | s = f"epoch = {epoch+1}/{epochs}" 202 | if batch >= 0: 203 | s = f"{s}, iteration = {batch+1}/{self.num_mini_batches}" 204 | if elapsed_time < 1: 205 | elapsed_time *= 1000 206 | unit = "ms" 207 | if elapsed_time < 1: 208 | elapsed_time *= 1000 209 | unit = "us" 210 | elapsed_time = round(elapsed_time) 211 | s = f"{s}, time = {elapsed_time}{unit}" 212 | elif elapsed_time < 60: 213 | s = f"{s}, time = {elapsed_time:.2f}s" 214 | else: 215 | elapsed_time /= 60 216 | unit = "min" 217 | if elapsed_time >= 60: 218 | elapsed_time /= 60 219 | unit = "h" 220 | if elapsed_time >= 24: 221 | elapsed_time /= 24 222 | unit = "d" 223 | s = f"{s}, time = {elapsed_time:.1f}{unit}" 224 | for key, value in metrics.items(): 225 | if len(self.output_set) > 0 and key not in self.output_set: 226 | continue 227 | t = type(value) 228 | if t == float or t == np.float32 or t == np.float64: 229 | format_s = f"%s, %s = %.{self.output_float_precision}f" 230 | s = format_s%(s, key, value) 231 | else: 232 | s = f"{s}, {key} = {value}" 233 | if self._last_line_length and batch >= 0 and len(s) < self._last_line_length: 234 | s += (self._last_line_length-len(s))*" " 235 | print(s, end="\r" if batch >= 0 else "\n\n") 236 | self._last_line_length = len(s) 237 | 238 | def _save_history_to_file(self): 239 | if core.output_path == None or self.history_filename == None: 240 | return 241 | 242 | filename = f"{core.output_path}/{self.history_filename}" 243 | print(f"saving history to '{filename}'") 244 | with core.open(filename, "wb") as f: 245 | f.write(json.dumps(self.history).encode("utf-8")) 246 | 
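Editor's note on utils/modelfitter.py above: ModelFitter is an abstract training loop that runs three stages concurrently, with bounded Queue objects connecting them: pre_train() prepares a mini-batch on a loader thread, train() consumes it on a worker thread, and post_train() collects results on the main thread. The following is a minimal, hypothetical subclass sketch for illustration only; the name MyFitter, the toy model/data and hyperparameters are assumptions, and it presumes a working checkout in which the aethon core module (core.device, core.output_path) has already been initialized. The only contract taken from the code above is that a subclass sets self.num_mini_batches and fills the batch_data/metrics namespaces.

```python
# Hypothetical sketch (not part of the repository): a toy ModelFitter subclass.
# MyFitter and its toy data are made up for illustration; assumes core is set up.
import types
import numpy as np
import torch
import torch.nn as nn
import core
from utils import ModelFitter


class MyFitter(ModelFitter):
    def __init__(self, config):
        super().__init__(config)
        self.model = nn.Linear(8, 2).to(core.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.num_mini_batches = 10  # run()/_fit_epoch() iterate over this many batches

    def pre_train(self, epoch, batch, batch_data):
        # loader thread: attach numpy arrays to the per-batch namespace
        batch_data.x = np.random.rand(4, 8).astype(np.float32)
        batch_data.yt = np.random.randint(0, 2, size=4)

    def train(self, epoch, batch, batch_data, metrics):
        # worker thread: forward/backward pass; metric keys containing "loss"
        # are reported as running means by _fit_epoch()/_progress()
        x = torch.from_numpy(batch_data.x).to(core.device)
        yt = torch.from_numpy(batch_data.yt).to(core.device)
        loss = nn.functional.cross_entropy(self.model(x), yt)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        metrics.loss = loss.item()

    def post_train(self, epoch, batch, batch_data, metrics):
        pass  # main thread: e.g. accumulate predictions for evaluation


MyFitter(types.SimpleNamespace(epochs=2)).run()
```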
-------------------------------------------------------------------------------- /utils/preprocess/dlr_landcover.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import tifffile 4 | import numpy as np 5 | 6 | home = os.environ["HOME"] 7 | with open(f"{home}/.aethon/user.yaml", "r") as f: 8 | user_config = yaml.full_load("\n".join(f.readlines())) 9 | 10 | for dataset_path in user_config["dataset_paths:map"]: 11 | if dataset_path[0] == "dlr_landcover": 12 | dataset_path = dataset_path[1] 13 | break 14 | 15 | for city, city_gt, swap in (("Berlin","berlin",False), ("Munich","munich",True)): 16 | sar = tifffile.imread(f"{dataset_path}/images/{city}_s1.tif") 17 | img = tifffile.imread(f"{dataset_path}/images/{city}_s2.tif") 18 | assert sar.dtype == img.dtype 19 | if swap: 20 | img = np.moveaxis(img, 0, 2) 21 | assert sar.shape == img.shape[:2] 22 | new_img = np.empty((img.shape[0]*(img.shape[2]+1), img.shape[1]), dtype=img.dtype) 23 | for i in range(img.shape[2]): 24 | offset = img.shape[0] * i 25 | new_img[offset:offset+img.shape[0]] = img[:,:,i] 26 | offset = img.shape[0] * img.shape[2] 27 | new_img[offset:offset+img.shape[0]] = sar 28 | tifffile.imwrite(f"{dataset_path}/images/{city}_converted.tif", new_img, compression="zlib") 29 | img = tifffile.imread(f"{dataset_path}/annotations/{city_gt}_anno.tif") 30 | tifffile.imwrite(f"{dataset_path}/annotations/{city}_converted.tif", img, compression="zlib") 31 | -------------------------------------------------------------------------------- /utils/preprocess/model_weights.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import gzip 3 | import torch 4 | 5 | for modelname, filename in (("mobilenetv2_100", "mobilenetv2.pt.gz"), ("legacy_xception", "xception.pt.gz")): 6 | weights = timm.create_model(modelname, pretrained=True).state_dict() 7 | with gzip.open(filename, "wb") as f: 8 | torch.save(weights, f) 9 | -------------------------------------------------------------------------------- /utils/preprocess/toulouse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import subprocess 4 | import glob 5 | import tifffile 6 | import numpy as np 7 | 8 | home = os.environ["HOME"] 9 | with open(f"{home}/.aethon/user.yaml", "r") as f: 10 | user_config = yaml.full_load("\n".join(f.readlines())) 11 | 12 | for dataset_path in user_config["dataset_paths:map"]: 13 | if dataset_path[0] == "toulouse": 14 | dataset_path = dataset_path[1] 15 | break 16 | 17 | dataset_path = f"{dataset_path}/img_multispec_05/TLS_BDSD_M" 18 | converted_files_path = f"{dataset_path}/bands_to_rows" 19 | subprocess.call(f"mkdir -p {converted_files_path}", shell=True) 20 | 21 | for fn in glob.iglob(f"{dataset_path}/TLS_BDSD_M_*.tif"): 22 | img = tifffile.imread(fn) 23 | new_img = np.empty((img.shape[0]*img.shape[2], img.shape[1]), dtype=img.dtype) 24 | actual_fn = fn.split("/")[-1] 25 | print(f"{actual_fn}:", img.shape, img.dtype, "->", new_img.shape, new_img.dtype) 26 | for channel in range(img.shape[2]): 27 | offset = img.shape[0] * channel 28 | new_img[offset:offset+img.shape[0]] = img[:,:,channel] 29 | tifffile.imwrite(f"{converted_files_path}/{actual_fn}", new_img, compression="zlib") 30 | 31 | dataset_path = "/".join(dataset_path.split("/")[:-2]) 32 | dataset_path = f"{dataset_path}/instances_building_05/TLS_instances_building_indMap" 33 | converted_files_path = f"{dataset_path}/u16/" 34 | 
subprocess.call(f"mkdir -p {converted_files_path}", shell=True) 35 | 36 | next_id = 1 37 | for fn in glob.iglob(f"{dataset_path}/*.tif"): 38 | img = tifffile.imread(fn) 39 | new_img = np.empty(img.shape, dtype=np.uint16) 40 | for instance_id in np.unique(img): 41 | if instance_id == 0: 42 | new_img[img==instance_id] = 0 43 | else: 44 | new_img[img==instance_id] = next_id 45 | next_id += 1 46 | actual_fn = fn.split("/")[-1] 47 | print(actual_fn, next_id-1) 48 | tifffile.imwrite(f"{converted_files_path}/{actual_fn}", new_img, compression="zlib") 49 | -------------------------------------------------------------------------------- /utils/sam.py: -------------------------------------------------------------------------------- 1 | # copy of sam.py from https://github.com/davda54/sam [1,2] 2 | # minor change: base_optimizer defaults to SGD as used in https://github.com/val-iisc/SDAT [3] 3 | # 4 | # paper references: 5 | # [1] https://arxiv.org/abs/2010.01412 6 | # [2] https://arxiv.org/abs/2102.11600 7 | # [3] https://arxiv.org/abs/2206.08213 8 | # 9 | import torch 10 | 11 | 12 | class SAM(torch.optim.Optimizer): 13 | def __init__(self, params, base_optimizer=torch.optim.SGD, rho=0.05, adaptive=False, **kwargs): 14 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 15 | 16 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) 17 | super(SAM, self).__init__(params, defaults) 18 | 19 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 20 | self.param_groups = self.base_optimizer.param_groups 21 | self.defaults.update(self.base_optimizer.defaults) 22 | 23 | @torch.no_grad() 24 | def first_step(self, zero_grad=False): 25 | grad_norm = self._grad_norm() 26 | for group in self.param_groups: 27 | scale = group["rho"] / (grad_norm + 1e-12) 28 | 29 | for p in group["params"]: 30 | if p.grad is None: continue 31 | self.state[p]["old_p"] = p.data.clone() 32 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 33 | p.add_(e_w) # climb to the local maximum "w + e(w)" 34 | 35 | if zero_grad: self.zero_grad() 36 | 37 | @torch.no_grad() 38 | def second_step(self, zero_grad=False): 39 | for group in self.param_groups: 40 | for p in group["params"]: 41 | if p.grad is None: continue 42 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 43 | 44 | self.base_optimizer.step() # do the actual "sharpness-aware" update 45 | 46 | if zero_grad: self.zero_grad() 47 | 48 | @torch.no_grad() 49 | def step(self, closure=None): 50 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 51 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 52 | 53 | self.first_step(zero_grad=True) 54 | closure() 55 | self.second_step() 56 | 57 | def _grad_norm(self): 58 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 59 | norm = torch.norm( 60 | torch.stack([ 61 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 62 | for group in self.param_groups for p in group["params"] 63 | if p.grad is not None 64 | ]), 65 | p=2 66 | ) 67 | return norm 68 | 69 | def load_state_dict(self, state_dict): 70 | super().load_state_dict(state_dict) 71 | self.base_optimizer.param_groups = self.param_groups 72 | --------------------------------------------------------------------------------