├── .gitignore ├── LICENSE ├── README.md ├── configs ├── multi-request-multi-support │ ├── mrms_allnorm.yml │ ├── mrms_occdeg.yml │ ├── mrms_randcom.yml │ ├── mrms_when2com.yml │ └── mrms_who2com.yml └── single-request-multiple-support │ ├── srms_allnorm.yml │ ├── srms_occdeg.yml │ ├── srms_randcom.yml │ ├── srms_when2com.yml │ └── srms_who2com.yml ├── ptsemseg ├── __init__.py ├── augmentations │ ├── __init__.py │ └── augmentations.py ├── loader │ ├── __init__.py │ └── airsim_loader.py ├── loss │ ├── __init__.py │ └── loss.py ├── metrics.py ├── models │ ├── __init__.py │ ├── agent.py │ ├── backbone.py │ └── utils.py ├── optimizers │ └── __init__.py ├── probe.py ├── process_img.py ├── schedulers │ ├── __init__.py │ └── schedulers.py ├── trainer.py └── utils.py ├── requirements.txt ├── teaser ├── 1359-teaser.gif └── pytorch-logo-dark.png ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Torch Models 7 | *.pkl 8 | *.pth 9 | current_train.py 10 | video_test*.py 11 | *.swp 12 | data 13 | ckpt 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | local_test.py 21 | .DS_STORE 22 | .idea/ 23 | .vscode/ 24 | env/ 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *,cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # IPython Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # dotenv 92 | .env 93 | 94 | # virtualenv 95 | venv/ 96 | ENV/ 97 | 98 | # Spyder project settings 99 | .spyderproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | runs 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 GT-RIPL, Yen-Cheng Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## When2com: Multi-Agent Perception via Communication Graph Grouping 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 3 | 4 | This is the PyTorch implementation of our paper:
5 | **When2com: Multi-Agent Perception via Communication Graph Grouping**
6 | [__***Yen-Cheng Liu***__](https://ycliu93.github.io/), [Junjiao Tian](https://www.linkedin.com/in/junjiao-tian-42b9758a/), [Nathaniel Glaser](https://sites.google.com/view/nathanglaser/), [Zsolt Kira](https://www.cc.gatech.edu/~zk15/)
7 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020
8 | 9 | 10 | [[Paper](http://openaccess.thecvf.com/content_CVPR_2020/papers/Liu_When2com_Multi-Agent_Perception_via_Communication_Graph_Grouping_CVPR_2020_paper.pdf)] [[GitHub](https://github.gatech.edu/RIPL/multi-agent-perception)] [[Project](https://ycliu93.github.io/projects/multi-agent-perception.html)] 11 | 12 |

13 | 14 |

15 | 16 | ## Prerequisites 17 | - Python 3.6 18 | - PyTorch 0.4.1 19 | - Other required packages in `requirements.txt` 20 | 21 | 22 | ## Getting started 23 | ### Download and install miniconda 24 | ``` 25 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 26 | bash Miniconda3-latest-Linux-x86_64.sh 27 | ``` 28 | 29 | ### Create conda environment 30 | ``` 31 | conda create -n semseg python=3.6 32 | source activate semseg 33 | ``` 34 | 35 | ### Install the required packages 36 | ``` 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ### Download the AirSim-MAP dataset and unzip it 41 | - Download the zip file for the setting you would like to run (see "Expected dataset layout" below) 42 | 43 | [![Alt text](https://ycliu93.github.io/projects/cvpr20_assets/airsim_map.png)](https://gtvault-my.sharepoint.com/:f:/g/personal/yliu3133_gatech_edu/Ett0G1_5YYdBpgojk0uWESgBi95dO79LkbYaKRhlBIkVJQ?e=vdjklb/) 44 | 45 | 46 | ### Move the datasets to the dataset path 47 | ``` 48 | mkdir dataset 49 | mv (dataset folder name) dataset/ 50 | ``` 51 | 52 | ### Training 53 | ``` 54 | # [Single-request multi-support] All norm 55 | python train.py --config configs/single-request-multiple-support/srms_allnorm.yml --gpu=0 56 | 57 | # [Multi-request multi-support] when2com model 58 | python train.py --config configs/multi-request-multi-support/mrms_when2com.yml --gpu=0 59 | 60 | ``` 61 | 62 | ### Testing 63 | ``` 64 | # [Single-request multi-support] All norm 65 | python test.py --config configs/single-request-multiple-support/srms_allnorm.yml --model_path <path_to_checkpoint> --gpu=0 66 | 67 | # [Multi-request multi-support] when2com model 68 | python test.py --config configs/multi-request-multi-support/mrms_when2com.yml --model_path <path_to_checkpoint> --gpu=0 69 | ``` 70 | 71 | ## Acknowledgments 72 | - This work was supported by ONR grant N00014-18-1-2829. 73 | - This code is built upon the implementation from [Pytorch-semseg](https://github.com/meetshah1995/pytorch-semseg). 
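## Expected dataset layout
The AirSim loader (`ptsemseg/loader/airsim_loader.py`) looks for `scene` and `segmentation_decoded` images under a fixed weather folder and reads a communication-label file from the dataset root. A typical layout for the multi-request data is sketched below; `<trajectory>` and `<agent>` are placeholders for the per-path and per-view folder names shipped inside the zip, and the single-request data provides `gt_when_to_communicate.txt` instead of `gt_mimo_communicate.txt`.
```
dataset/
└── airsim-mrms-noise-data/                  # path used by the mrms configs
    ├── scene/
    │   └── async_rotate_fog_000_clear/
    │       └── <trajectory>/<agent>/*.png
    ├── segmentation_decoded/
    │   └── async_rotate_fog_000_clear/
    │       └── <trajectory>/<agent>/*.png
    └── gt_mimo_communicate.txt              # selection labels for the 'mimo' setting
```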
74 | 75 | ## Citation 76 | If you find this repository useful, please cite our paper: 77 | 78 | ``` 79 | @inproceedings{liu2020when2com, 80 | title={When2com: Multi-Agent Perception via Communication Graph Grouping}, 81 | author={Yen-Cheng Liu and Junjiao Tian and Nathaniel Glaser and Zsolt Kira}, 82 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 83 | year={2020} 84 | } 85 | ``` 86 | -------------------------------------------------------------------------------- /configs/multi-request-multi-support/mrms_allnorm.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: Single_agent 3 | shuffle_features: None 4 | agent_num: 6 5 | enc_backbone: resnet_encoder 6 | dec_backbone: simple_decoder 7 | feat_squeezer: -1 8 | feat_channel: 512 9 | multiple_output: True 10 | data: 11 | dataset: airsim 12 | train_split: train 13 | val_split: val 14 | test_split: test 15 | img_rows: 512 16 | img_cols: 512 17 | path: dataset/airsim-mrms-data 18 | noisy_type: None 19 | target_view: '6agent' 20 | training: 21 | train_iters: 12 22 | batch_size: 2 23 | val_interval: 6 24 | n_workers: 4 25 | print_interval: 2 26 | optimizer: 27 | name: 'adam' 28 | lr: 1.0e-5 29 | loss: 30 | name: 'cross_entropy' 31 | size_average: True 32 | lr_schedule: 33 | resume: None 34 | -------------------------------------------------------------------------------- /configs/multi-request-multi-support/mrms_occdeg.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: Single_agent 3 | shuffle_features: None 4 | agent_num: 6 5 | enc_backbone: resnet_encoder 6 | dec_backbone: simple_decoder 7 | feat_squeezer: -1 8 | feat_channel: 512 9 | multiple_output: True 10 | data: 11 | dataset: airsim 12 | train_split: train 13 | val_split: val 14 | test_split: test 15 | img_rows: 512 16 | img_cols: 512 17 | path: dataset/airsim-mrms-noise-data 18 | noisy_type: None 19 | target_view: '6agent' 20 | training: 21 | train_iters: 200000 22 | batch_size: 2 23 | val_interval: 1000 24 | n_workers: 4 25 | print_interval: 50 26 | optimizer: 27 | name: 'adam' 28 | lr: 1.0e-5 29 | loss: 30 | name: 'cross_entropy' 31 | size_average: True 32 | lr_schedule: 33 | resume: None 34 | -------------------------------------------------------------------------------- /configs/multi-request-multi-support/mrms_randcom.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: MIMO_All_agents 3 | shuffle_features: 'selection' 4 | agent_num: 6 5 | enc_backbone: resnet_encoder 6 | dec_backbone: simple_decoder 7 | feat_squeezer: -1 8 | feat_channel: 512 9 | multiple_output: True 10 | data: 11 | dataset: airsim 12 | train_split: train 13 | val_split: val 14 | test_split: test 15 | img_rows: 512 16 | img_cols: 512 17 | path: dataset/airsim-mrms-noise-data 18 | noisy_type: None 19 | target_view: '6agent' 20 | commun_label: 'mimo' 21 | training: 22 | train_iters: 200000 23 | batch_size: 1 24 | val_interval: 1000 25 | n_workers: 4 26 | print_interval: 50 27 | optimizer: 28 | name: 'adam' 29 | lr: 1.0e-5 30 | loss: 31 | name: 'cross_entropy' 32 | size_average: True 33 | lr_schedule: 34 | resume: None 35 | -------------------------------------------------------------------------------- /configs/multi-request-multi-support/mrms_when2com.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: MIMOcom 3 | agent_num: 6 4 | shared_policy: 
True 5 | shared_img_encoder: 'unified' 6 | attention: 'general' 7 | sparse: False 8 | query: True 9 | query_size: 32 10 | key_size: 1024 11 | enc_backbone: resnet_encoder 12 | dec_backbone: simple_decoder 13 | feat_squeezer: -1 14 | feat_channel: 512 15 | multiple_output: True 16 | data: 17 | dataset: airsim 18 | train_split: train 19 | val_split: val 20 | test_split: test 21 | img_rows: 512 22 | img_cols: 512 23 | path: dataset/airsim-mrms-noise-data 24 | noisy_type: None 25 | target_view: '6agent' 26 | commun_label: 'mimo' 27 | training: 28 | train_iters: 200000 29 | batch_size: 2 30 | val_interval: 1000 31 | n_workers: 8 32 | print_interval: 50 33 | optimizer: 34 | name: 'adam' 35 | lr: 1.0e-5 36 | loss: 37 | name: 'cross_entropy' 38 | size_average: True 39 | lr_schedule: 40 | resume: None 41 | -------------------------------------------------------------------------------- /configs/multi-request-multi-support/mrms_who2com.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: MIMOcomWho 3 | agent_num: 6 4 | shared_policy: True 5 | shared_img_encoder: 'unified' 6 | attention: 'general' 7 | sparse: False 8 | query: False 9 | query_size: 32 10 | key_size: 1024 11 | enc_backbone: resnet_encoder 12 | dec_backbone: simple_decoder 13 | feat_squeezer: -1 14 | feat_channel: 512 15 | multiple_output: True 16 | data: 17 | dataset: airsim 18 | train_split: train 19 | val_split: val 20 | test_split: test 21 | img_rows: 512 22 | img_cols: 512 23 | path: dataset/airsim-mrms-noise-data 24 | noisy_type: None 25 | target_view: '6agent' 26 | commun_label: 'mimo' 27 | training: 28 | train_iters: 200000 29 | batch_size: 2 30 | val_interval: 1000 31 | n_workers: 8 32 | print_interval: 50 33 | optimizer: 34 | name: 'adam' 35 | lr: 1.0e-5 36 | loss: 37 | name: 'cross_entropy' 38 | size_average: True 39 | lr_schedule: 40 | resume: None 41 | -------------------------------------------------------------------------------- /configs/single-request-multiple-support/srms_allnorm.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: Single_agent 3 | shuffle_features: None 4 | agent_num: 5 5 | enc_backbone: resnet_encoder 6 | dec_backbone: simple_decoder 7 | feat_squeezer: -1 8 | feat_channel: 512 9 | multiple_output: False 10 | data: 11 | dataset: airsim 12 | train_split: train 13 | val_split: val 14 | test_split: test 15 | img_rows: 512 16 | img_cols: 512 17 | path: dataset/airsim-srms-data 18 | noisy_type: None 19 | target_view: 'target' 20 | commun_label: 'None' 21 | training: 22 | train_iters: 200000 23 | batch_size: 2 24 | val_interval: 1000 25 | n_workers: 8 26 | print_interval: 50 27 | optimizer: 28 | name: 'adam' 29 | lr: 1.0e-5 30 | loss: 31 | name: 'cross_entropy' 32 | size_average: True 33 | lr_schedule: 34 | resume: None 35 | -------------------------------------------------------------------------------- /configs/single-request-multiple-support/srms_occdeg.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: Single_agent 3 | shuffle_features: None 4 | agent_num: 5 5 | enc_backbone: resnet_encoder 6 | dec_backbone: simple_decoder 7 | feat_squeezer: -1 8 | feat_channel: 512 9 | multiple_output: False 10 | data: 11 | dataset: airsim 12 | train_split: train 13 | val_split: val 14 | test_split: test 15 | img_rows: 512 16 | img_cols: 512 17 | path: dataset/airsim-srms-noise-data 18 | noisy_type: None 19 | target_view: 'target' 20 | commun_label: 'None' 
21 | training: 22 | train_iters: 200000 23 | batch_size: 2 24 | val_interval: 1000 25 | n_workers: 8 26 | print_interval: 50 27 | optimizer: 28 | name: 'adam' 29 | lr: 1.0e-5 30 | loss: 31 | name: 'cross_entropy' 32 | size_average: True 33 | lr_schedule: 34 | resume: None 35 | -------------------------------------------------------------------------------- /configs/single-request-multiple-support/srms_randcom.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: All_agents 3 | shuffle_features: 'selection' 4 | agent_num: 5 5 | enc_backbone: resnet_encoder 6 | dec_backbone: simple_decoder 7 | feat_squeezer: -1 8 | feat_channel: 512 9 | multiple_output: False 10 | data: 11 | dataset: airsim 12 | train_split: train 13 | val_split: val 14 | test_split: test 15 | img_rows: 512 16 | img_cols: 512 17 | path: dataset/airsim-srms-noise-data 18 | noisy_type: None 19 | target_view: 'target' 20 | commun_label: 'when2com' 21 | training: 22 | train_iters: 200000 23 | batch_size: 2 24 | val_interval: 1000 25 | n_workers: 8 26 | print_interval: 50 27 | optimizer: 28 | name: 'adam' 29 | lr: 1.0e-5 30 | loss: 31 | name: 'cross_entropy' 32 | size_average: True 33 | lr_schedule: 34 | resume: None 35 | -------------------------------------------------------------------------------- /configs/single-request-multiple-support/srms_when2com.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: LearnWhen2Com 3 | agent_num: 5 4 | shared_policy: True 5 | shared_img_encoder: 'unified' 6 | attention: 'general' 7 | sparse: False 8 | query: True 9 | query_size: 8 10 | key_size: 1024 11 | enc_backbone: resnet_encoder 12 | dec_backbone: simple_decoder 13 | feat_squeezer: -1 14 | feat_channel: 512 15 | multiple_output: False 16 | data: 17 | dataset: airsim 18 | train_split: train 19 | val_split: val 20 | test_split: test 21 | img_rows: 512 22 | img_cols: 512 23 | path: dataset/airsim-srms-noise-data 24 | noisy_type: None 25 | target_view: 'target' 26 | commun_label: 'when2com' 27 | training: 28 | train_iters: 200000 29 | batch_size: 2 30 | val_interval: 1000 31 | n_workers: 8 32 | print_interval: 50 33 | optimizer: 34 | name: 'adam' 35 | lr: 1.0e-5 36 | loss: 37 | name: 'cross_entropy' 38 | size_average: True 39 | lr_schedule: 40 | resume: None 41 | -------------------------------------------------------------------------------- /configs/single-request-multiple-support/srms_who2com.yml: -------------------------------------------------------------------------------- 1 | model: 2 | arch: LearnWho2Com 3 | agent_num: 5 4 | shared_policy: True 5 | shared_img_encoder: 'only_normal_agents' 6 | attention: 'general' 7 | sparse: False 8 | query: True 9 | query_size: 8 10 | key_size: 1024 11 | enc_backbone: resnet_encoder 12 | dec_backbone: simple_decoder 13 | feat_squeezer: -1 14 | feat_channel: 512 15 | multiple_output: False 16 | data: 17 | dataset: airsim 18 | train_split: train 19 | val_split: val 20 | test_split: test 21 | img_rows: 512 22 | img_cols: 512 23 | path: dataset/airsim-srms-noise-data 24 | noisy_type: None 25 | target_view: 'target' 26 | commun_label: 'when2com' 27 | training: 28 | train_iters: 200000 29 | batch_size: 2 30 | val_interval: 1000 31 | n_workers: 8 32 | print_interval: 50 33 | optimizer: 34 | name: 'adam' 35 | lr: 1.0e-5 36 | loss: 37 | name: 'cross_entropy' 38 | size_average: True 39 | lr_schedule: 40 | resume: None 41 | 
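The YAML files above all share the same `model` / `data` / `training` schema. As a rough sketch of how such a config is consumed — assuming `train.py`/`test.py` follow the usual pytorch-semseg pattern of loading the YAML into a dict and passing it to the `ptsemseg` helpers below — the dataset name selects a loader class and the whole config dict is handed to the model factory. The checkpoint and optimizer handling of the actual scripts is not reproduced here.
```
# Illustrative sketch only: wiring a config file into the ptsemseg helpers.
# Assumes the dataset referenced by cfg["data"]["path"] exists on disk.
import yaml

from ptsemseg.loader import get_loader   # maps "airsim" -> airsimLoader
from ptsemseg.models import get_model    # builds MIMOcom, LearnWhen2Com, ...

with open("configs/multi-request-multi-support/mrms_when2com.yml") as fp:
    cfg = yaml.safe_load(fp)

loader_cls = get_loader(cfg["data"]["dataset"])
train_set = loader_cls(
    root=cfg["data"]["path"],
    split=cfg["data"]["train_split"],
    is_transform=True,
    img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    target_view=cfg["data"]["target_view"],
    commun_label=cfg["data"]["commun_label"],
)

# get_model() reads cfg["model"] (arch, backbones, query/key sizes, ...) and
# cfg["data"]["img_rows"], so it takes the full config dict plus n_classes.
model = get_model(cfg, train_set.n_classes)
```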
-------------------------------------------------------------------------------- /ptsemseg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-RIPL/MultiAgentPerception/4ef300547a7f7af2676a034f7cf742b009f57d99/ptsemseg/__init__.py -------------------------------------------------------------------------------- /ptsemseg/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from ptsemseg.augmentations.augmentations import ( 3 | AdjustContrast, 4 | AdjustGamma, 5 | AdjustBrightness, 6 | AdjustSaturation, 7 | AdjustHue, 8 | RandomCrop, 9 | RandomHorizontallyFlip, 10 | RandomVerticallyFlip, 11 | Scale, 12 | RandomSized, 13 | RandomSizedCrop, 14 | RandomRotate, 15 | RandomTranslate, 16 | CenterCrop, 17 | Compose, 18 | ) 19 | 20 | logger = logging.getLogger("ptsemseg") 21 | 22 | key2aug = { 23 | "gamma": AdjustGamma, 24 | "hue": AdjustHue, 25 | "brightness": AdjustBrightness, 26 | "saturation": AdjustSaturation, 27 | "contrast": AdjustContrast, 28 | "rcrop": RandomCrop, 29 | "hflip": RandomHorizontallyFlip, 30 | "vflip": RandomVerticallyFlip, 31 | "scale": Scale, 32 | "rsize": RandomSized, 33 | "rsizecrop": RandomSizedCrop, 34 | "rotate": RandomRotate, 35 | "translate": RandomTranslate, 36 | "ccrop": CenterCrop, 37 | } 38 | 39 | 40 | def get_composed_augmentations(aug_dict): 41 | if aug_dict is None: 42 | logger.info("Using No Augmentations") 43 | return None 44 | 45 | augmentations = [] 46 | for aug_key, aug_param in aug_dict.items(): 47 | augmentations.append(key2aug[aug_key](aug_param)) 48 | logger.info("Using {} aug with params {}".format(aug_key, aug_param)) 49 | return Compose(augmentations) 50 | -------------------------------------------------------------------------------- /ptsemseg/augmentations/augmentations.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import numpy as np 5 | import torchvision.transforms.functional as tf 6 | 7 | from PIL import Image, ImageOps 8 | 9 | 10 | class Compose(object): 11 | def __init__(self, augmentations): 12 | self.augmentations = augmentations 13 | self.PIL2Numpy = False 14 | 15 | def __call__(self, img, mask): 16 | if isinstance(img, np.ndarray): 17 | img = Image.fromarray(img, mode="RGB") 18 | mask = Image.fromarray(mask, mode="L") 19 | self.PIL2Numpy = True 20 | 21 | assert img.size == mask.size 22 | for a in self.augmentations: 23 | img, mask = a(img, mask) 24 | 25 | if self.PIL2Numpy: 26 | img, mask = np.array(img), np.array(mask, dtype=np.uint8) 27 | 28 | return img, mask 29 | 30 | 31 | class RandomCrop(object): 32 | def __init__(self, size, padding=0): 33 | if isinstance(size, numbers.Number): 34 | self.size = (int(size), int(size)) 35 | else: 36 | self.size = size 37 | self.padding = padding 38 | 39 | def __call__(self, img, mask): 40 | if self.padding > 0: 41 | img = ImageOps.expand(img, border=self.padding, fill=0) 42 | mask = ImageOps.expand(mask, border=self.padding, fill=0) 43 | 44 | assert img.size == mask.size 45 | w, h = img.size 46 | th, tw = self.size 47 | if w == tw and h == th: 48 | return img, mask 49 | if w < tw or h < th: 50 | return (img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST)) 51 | 52 | x1 = random.randint(0, w - tw) 53 | y1 = random.randint(0, h - th) 54 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))) 55 | 56 | 
57 | class AdjustGamma(object): 58 | def __init__(self, gamma): 59 | self.gamma = gamma 60 | 61 | def __call__(self, img, mask): 62 | assert img.size == mask.size 63 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask 64 | 65 | 66 | class AdjustSaturation(object): 67 | def __init__(self, saturation): 68 | self.saturation = saturation 69 | 70 | def __call__(self, img, mask): 71 | assert img.size == mask.size 72 | return ( 73 | tf.adjust_saturation(img, random.uniform(1 - self.saturation, 1 + self.saturation)), 74 | mask, 75 | ) 76 | 77 | 78 | class AdjustHue(object): 79 | def __init__(self, hue): 80 | self.hue = hue 81 | 82 | def __call__(self, img, mask): 83 | assert img.size == mask.size 84 | return tf.adjust_hue(img, random.uniform(-self.hue, self.hue)), mask 85 | 86 | 87 | class AdjustBrightness(object): 88 | def __init__(self, bf): 89 | self.bf = bf 90 | 91 | def __call__(self, img, mask): 92 | assert img.size == mask.size 93 | return tf.adjust_brightness(img, random.uniform(1 - self.bf, 1 + self.bf)), mask 94 | 95 | 96 | class AdjustContrast(object): 97 | def __init__(self, cf): 98 | self.cf = cf 99 | 100 | def __call__(self, img, mask): 101 | assert img.size == mask.size 102 | return tf.adjust_contrast(img, random.uniform(1 - self.cf, 1 + self.cf)), mask 103 | 104 | 105 | class CenterCrop(object): 106 | def __init__(self, size): 107 | if isinstance(size, numbers.Number): 108 | self.size = (int(size), int(size)) 109 | else: 110 | self.size = size 111 | 112 | def __call__(self, img, mask): 113 | assert img.size == mask.size 114 | w, h = img.size 115 | th, tw = self.size 116 | x1 = int(round((w - tw) / 2.0)) 117 | y1 = int(round((h - th) / 2.0)) 118 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th))) 119 | 120 | 121 | class RandomHorizontallyFlip(object): 122 | def __init__(self, p): 123 | self.p = p 124 | 125 | def __call__(self, img, mask): 126 | if random.random() < self.p: 127 | return (img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT)) 128 | return img, mask 129 | 130 | 131 | class RandomVerticallyFlip(object): 132 | def __init__(self, p): 133 | self.p = p 134 | 135 | def __call__(self, img, mask): 136 | if random.random() < self.p: 137 | return (img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM)) 138 | return img, mask 139 | 140 | 141 | class FreeScale(object): 142 | def __init__(self, size): 143 | self.size = tuple(reversed(size)) # size: (h, w) 144 | 145 | def __call__(self, img, mask): 146 | assert img.size == mask.size 147 | return (img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST)) 148 | 149 | 150 | class RandomTranslate(object): 151 | def __init__(self, offset): 152 | # tuple (delta_x, delta_y) 153 | self.offset = offset 154 | 155 | def __call__(self, img, mask): 156 | assert img.size == mask.size 157 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0]) 158 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1]) 159 | 160 | x_crop_offset = x_offset 161 | y_crop_offset = y_offset 162 | if x_offset < 0: 163 | x_crop_offset = 0 164 | if y_offset < 0: 165 | y_crop_offset = 0 166 | 167 | cropped_img = tf.crop( 168 | img, 169 | y_crop_offset, 170 | x_crop_offset, 171 | img.size[1] - abs(y_offset), 172 | img.size[0] - abs(x_offset), 173 | ) 174 | 175 | if x_offset >= 0 and y_offset >= 0: 176 | padding_tuple = (0, 0, x_offset, y_offset) 177 | 178 | elif x_offset >= 0 and y_offset < 0: 179 | padding_tuple = (0, abs(y_offset), x_offset, 
0) 180 | 181 | elif x_offset < 0 and y_offset >= 0: 182 | padding_tuple = (abs(x_offset), 0, 0, y_offset) 183 | 184 | elif x_offset < 0 and y_offset < 0: 185 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0) 186 | 187 | return ( 188 | tf.pad(cropped_img, padding_tuple, padding_mode="reflect"), 189 | tf.affine( 190 | mask, 191 | translate=(-x_offset, -y_offset), 192 | scale=1.0, 193 | angle=0.0, 194 | shear=0.0, 195 | fillcolor=250, 196 | ), 197 | ) 198 | 199 | 200 | class RandomRotate(object): 201 | def __init__(self, degree): 202 | self.degree = degree 203 | 204 | def __call__(self, img, mask): 205 | rotate_degree = random.random() * 2 * self.degree - self.degree 206 | return ( 207 | tf.affine( 208 | img, 209 | translate=(0, 0), 210 | scale=1.0, 211 | angle=rotate_degree, 212 | resample=Image.BILINEAR, 213 | fillcolor=(0, 0, 0), 214 | shear=0.0, 215 | ), 216 | tf.affine( 217 | mask, 218 | translate=(0, 0), 219 | scale=1.0, 220 | angle=rotate_degree, 221 | resample=Image.NEAREST, 222 | fillcolor=250, 223 | shear=0.0, 224 | ), 225 | ) 226 | 227 | 228 | class Scale(object): 229 | def __init__(self, size): 230 | self.size = size 231 | 232 | def __call__(self, img, mask): 233 | assert img.size == mask.size 234 | w, h = img.size 235 | if (w >= h and w == self.size) or (h >= w and h == self.size): 236 | return img, mask 237 | if w > h: 238 | ow = self.size 239 | oh = int(self.size * h / w) 240 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)) 241 | else: 242 | oh = self.size 243 | ow = int(self.size * w / h) 244 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST)) 245 | 246 | 247 | class RandomSizedCrop(object): 248 | def __init__(self, size): 249 | self.size = size 250 | 251 | def __call__(self, img, mask): 252 | assert img.size == mask.size 253 | for attempt in range(10): 254 | area = img.size[0] * img.size[1] 255 | target_area = random.uniform(0.45, 1.0) * area 256 | aspect_ratio = random.uniform(0.5, 2) 257 | 258 | w = int(round(math.sqrt(target_area * aspect_ratio))) 259 | h = int(round(math.sqrt(target_area / aspect_ratio))) 260 | 261 | if random.random() < 0.5: 262 | w, h = h, w 263 | 264 | if w <= img.size[0] and h <= img.size[1]: 265 | x1 = random.randint(0, img.size[0] - w) 266 | y1 = random.randint(0, img.size[1] - h) 267 | 268 | img = img.crop((x1, y1, x1 + w, y1 + h)) 269 | mask = mask.crop((x1, y1, x1 + w, y1 + h)) 270 | assert img.size == (w, h) 271 | 272 | return ( 273 | img.resize((self.size, self.size), Image.BILINEAR), 274 | mask.resize((self.size, self.size), Image.NEAREST), 275 | ) 276 | 277 | # Fallback 278 | scale = Scale(self.size) 279 | crop = CenterCrop(self.size) 280 | return crop(*scale(img, mask)) 281 | 282 | 283 | class RandomSized(object): 284 | def __init__(self, size): 285 | self.size = size 286 | self.scale = Scale(self.size) 287 | self.crop = RandomCrop(self.size) 288 | 289 | def __call__(self, img, mask): 290 | assert img.size == mask.size 291 | 292 | w = int(random.uniform(0.5, 2) * img.size[0]) 293 | h = int(random.uniform(0.5, 2) * img.size[1]) 294 | 295 | img, mask = (img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST)) 296 | 297 | return self.crop(*self.scale(img, mask)) 298 | -------------------------------------------------------------------------------- /ptsemseg/loader/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from ptsemseg.loader.airsim_loader import airsimLoader 4 | 5 | 6 | def 
get_loader(name): 7 | """get_loader 8 | 9 | :param name: 10 | """ 11 | return { 12 | "airsim": airsimLoader, 13 | }[name] 14 | 15 | -------------------------------------------------------------------------------- /ptsemseg/loader/airsim_loader.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | 4 | import os 5 | import torch 6 | import numpy as np 7 | import glob 8 | import cv2 9 | import copy 10 | from random import shuffle 11 | import random 12 | from torch.utils import data 13 | from ast import literal_eval as make_tuple 14 | import matplotlib.pyplot as plt 15 | 16 | 17 | 18 | 19 | def label_region_n_compute_distance(i,path_tuple): 20 | begin = path_tuple[0] 21 | end = path_tuple[1] 22 | 23 | # computer distance 24 | distance = ((begin[0]-end[0])**2 +(begin[1]-end[1])**2)**(0.5) 25 | 26 | 27 | # label region 28 | if begin[0] <= -400 or end[0]< -400: 29 | region = 'suburban' 30 | else: 31 | if begin[1] >= 300 or end[1] >=300: 32 | region = 'shopping' 33 | else: 34 | region = 'skyscraper' 35 | 36 | 37 | # update tuple 38 | path_tuple = (i,)+path_tuple + (distance, region,) 39 | 40 | return path_tuple 41 | 42 | 43 | 44 | 45 | class airsimLoader(data.Dataset): 46 | 47 | # segmentation decoding colors 48 | name2color = {"person": [[135, 169, 180]], 49 | "sidewalk": [[242, 107, 146]], 50 | "road": [[156,198,23],[43,79,150]], 51 | "sky": [[209,247,202]], 52 | "pole": [[249,79,73],[72,137,21],[45,157,177],[67,266,253],[206,190,59]], 53 | "building": [[161,171,27],[61,212,54],[151,161,26]], 54 | "car": [[153,108,6]], 55 | "bus": [[190,225,64]], 56 | "truck": [[112,105,191]], 57 | "vegetation": [[29,26,199],[234,21,250],[145,71,201],[247,200,111]] 58 | } 59 | 60 | 61 | name2id = {"person": 1, 62 | "sidewalk": 2, 63 | "road": 3, 64 | "sky": 4, 65 | "pole": 5, 66 | "building": 6, 67 | "car": 7, 68 | "bus": 8, 69 | "truck": 9, 70 | "vegetation": 10 } 71 | 72 | 73 | id2name = {i:name for name,i in name2id.items()} 74 | 75 | splits = ['train', 'val', 'test'] 76 | image_modes = ['scene', 'segmentation_decoded'] 77 | 78 | weathers = ['async_rotate_fog_000_clear'] 79 | 80 | # list of nodes on the maps (this needs for loading from the folder) 81 | all_edges = [ 82 | ((0, 0), (16, -74)), 83 | ((16, -74), (-86, -78)), 84 | ((-86, -78), (-94, -58)), 85 | ((-94, -58), (-94, 24)), 86 | ((-94, 24), (-143, 24)), 87 | ((-143, 24), (-219, 24)), 88 | ((-219, 24), (-219, -68)), 89 | ((-219, -68), (-214, -127)), 90 | ((-214, -127), (-336, -132)), 91 | ((-336, -132), (-335, -180)), 92 | ((-335, -180), (-216, -205)), 93 | ((-216, -205), (-226, -241)), 94 | ((-226, -241), (-240, -252)), 95 | ((-240, -252), (-440, -260)), 96 | ((-440, -260), (-483, -253)), 97 | ((-483, -253), (-494, -223)), 98 | ((-494, -223), (-493, -127)), 99 | ((-493, -127), (-441, -129)), 100 | ((-441, -129), (-443, -222)), 101 | ((-443, -222), (-339, -221)), 102 | ((-339, -221), (-335, -180)), 103 | ((-219, 24), (-248, 24)), 104 | ((-248, 24), (-302, 24)), 105 | ((-302, 24), (-337, 24)), 106 | ((-337, 24), (-593, 25)), 107 | ((-593, 25), (-597, -128)), 108 | ((-597, -128), (-597, -220)), 109 | ((-597, -220), (-748, -222)), 110 | ((-748, -222), (-744, -128)), 111 | ((-744, -128), (-746, 24)), 112 | ((-744, -128), (-597, -128)), 113 | ((-593, 25), (-746, 24)), 114 | ((-746, 24), (-832, 27)), 115 | ((-832, 27), (-804, 176)), 116 | ((-804, 176), (-747, 178)), 117 | ((-747, 178), (-745, 103)), 118 | ((-745, 103), (-696, 104)), 119 | ((-696, 104), (-596, 102)), 120 | 
((-596, 102), (-599, 177)), 121 | ((-599, 177), (-747, 178)), 122 | ((-599, 177), (-597, 253)), 123 | ((-596, 102), (-593, 25)), 124 | ((-337, 24), (-338, 172)), 125 | ((-337, 172), (-332, 251)), 126 | ((-337, 172), (-221, 172)), 127 | ((-221, 172), (-221, 264)), 128 | ((-221, 172), (-219, 90)), 129 | ((-219, 90), (-219, 24)), 130 | ((-221, 172), (-148, 172)), 131 | ((-148, 172), (-130, 172)), 132 | ((-130, 172), (-57, 172)), 133 | ((-57, 172), (-57, 194)), 134 | ((20, 192), (20, 92)), 135 | ((20, 92), (21, 76)), 136 | ((21, 76), (66, 22)), 137 | ((66, 22), (123, 28)), 138 | ((123, 28), (123, 106)), 139 | ((123, 106), (123, 135)), 140 | ((123, 135), (176, 135)), 141 | ((176, 135), (176, 179)), 142 | ((176, 179), (210, 180)), 143 | ((210, 180), (210, 107)), 144 | ((210, 107), (216, 26)), 145 | ((216, 26), (118, 21)), 146 | ((118, 21), (118, 2)), 147 | ((118, 2), (100, -62)), 148 | ((100, -62), (89, -70)), 149 | ((89, -70), (62, -76)), 150 | ((62, -76), (28, -76)), 151 | ((28, -76), (16, -74)), 152 | ((16, -74), (14, -17)), 153 | ((-494, -223), (-597, -220)), 154 | ((-597, -128), (-493, -127)), 155 | ((-493, -127), (-493, 25)), 156 | ((-336, -132), (-337, 24)), 157 | ((14, -17), (66, 22)), 158 | ((-597, 253), (-443, 253)), 159 | ((-443, 253), (-332, 251)), 160 | ((-332, 251), (-221, 264)), 161 | ((-221, 264), (-211, 493)), 162 | ((-211, 493), (-129, 493)), 163 | ((-129, 493), (23, 493)), 164 | ((23, 493), (20, 274)), 165 | ((176, 274), (176, 348)), 166 | ((176, 348), (180, 493)), 167 | ((180, 493), (175, 660)), 168 | ((175, 660), (23, 646)), 169 | ((23, 646), (-128, 646)), 170 | ((-128, 646), (-134, 795)), 171 | ((-134, 795), (-130, 871)), 172 | ((-130, 871), (20, 872)), 173 | ((175, 872), (175, 795)), 174 | ((252, 799), (175, 795)), 175 | ((175, 795), (23, 798)), 176 | ((23, 798), (-134, 795)), 177 | ((-134, 795), (-128, 676)), 178 | ((-128, 676), (-129, 493)), 179 | ((23, 493), (23, 646)), 180 | ((23, 646), (23, 798)), 181 | ((23, 798), (20, 872)), 182 | ((-338, 172), (-332, 251)), 183 | ((-57, 255), (20, 255)), 184 | ((-57, 194), (20, 192)), 185 | ((20, 255), (20, 274)), 186 | ((20, 274), (176, 267)), 187 | ((23, 493), (180, 493)), 188 | ((176, 267), (176, 348))] 189 | split_subdirs = {} 190 | ignore_index = 0 191 | mean_rgb = {"airsim": [103.939, 116.779, 123.68],} 192 | 193 | def __init__( 194 | self, 195 | root, 196 | split="train", 197 | subsplit=None, 198 | is_transform=False, 199 | img_size=(512, 512), 200 | augmentations=None, 201 | img_norm=True, 202 | commun_label='None', 203 | version="airsim", 204 | target_view="target" 205 | 206 | ): 207 | 208 | # dataloader parameters 209 | self.dataset_div = self.divide_region_n_train_val_test() 210 | self.split_subdirs = self.generate_image_path(self.dataset_div) 211 | self.commun_label = commun_label 212 | self.root = root 213 | self.split = split 214 | self.is_transform = is_transform 215 | self.augmentations = augmentations 216 | self.img_norm = img_norm 217 | self.n_classes = 11 218 | self.img_size = (img_size if isinstance(img_size, tuple) else (img_size, img_size)) 219 | self.mean = np.array(self.mean_rgb[version]) 220 | 221 | # Set the target view; first element of list is target view 222 | self.cam_pos = self.get_cam_pos(target_view) 223 | 224 | # load the communication label 225 | if self.commun_label != 'None': 226 | comm_label = self.read_selection_label(self.commun_label) 227 | 228 | # Pre-define the empty list for the images 229 | self.imgs = {s:{c:{image_mode:[] for image_mode in self.image_modes} for c in self.cam_pos} for s 
in self.splits} 230 | self.com_label = {s:[] for s in self.splits} 231 | 232 | k = 0 233 | for split in self.splits: # [train, val] 234 | for subdir in self.split_subdirs[split]: # [trajectory ] 235 | 236 | file_list = sorted(glob.glob(os.path.join(root, 'scene', 'async_rotate_fog_000_clear',subdir,self.cam_pos[0],'*.png'),recursive=True)) 237 | 238 | for file_path in file_list: 239 | ext = file_path.replace(root+"/scene/",'') 240 | file_name = ext.split("/")[-1] 241 | path_dir = ext.split("/")[1] 242 | 243 | # Check if a image file exists in all views and all modalities 244 | list_of_all_cams_n_modal = [os.path.exists(os.path.join(root,modal,'async_rotate_fog_000_clear',path_dir, cam,file_name)) for modal in self.image_modes for cam in self.cam_pos] 245 | 246 | if all(list_of_all_cams_n_modal): 247 | k += 1 248 | # Add the file path to the self.imgs 249 | for comb_modal in self.image_modes: 250 | for comb_cam in self.cam_pos: 251 | file_path = os.path.join(root,comb_modal,'async_rotate_fog_000_clear', path_dir,comb_cam,file_name) 252 | self.imgs[split][comb_cam][comb_modal].append(file_path) 253 | 254 | if self.commun_label != 'None': # Load the communication label 255 | self.com_label[split].append(comm_label[path_dir+'/'+file_name]) 256 | 257 | if not self.imgs[self.split][self.cam_pos[0]][self.image_modes[0]]: 258 | raise Exception( 259 | "No files for split=[%s] found in %s" % (self.split, self.root) 260 | ) 261 | print("Found %d %s images" % (len(self.imgs[self.split][self.cam_pos[0]][self.image_modes[0]]), self.split)) 262 | 263 | # <---- Functions for conversion of paths ----> 264 | def tuple_to_folder_name(self, path_tuple): 265 | start = path_tuple[1] 266 | end = path_tuple[2] 267 | path=str(start[0])+'_'+str(-start[1])+'__'+str(end[0])+'_'+str(-end[1])+'*' 268 | return path 269 | def generate_image_path(self, dataset_div): 270 | 271 | # Merge across regions 272 | train_path_list = [] 273 | val_path_list = [] 274 | test_path_list = [] 275 | for region in ['skyscraper','suburban','shopping']: 276 | for train_one_path in dataset_div['train'][region][1]: 277 | train_path_list.append(self.tuple_to_folder_name(train_one_path)) 278 | 279 | for val_one_path in dataset_div['val'][region][1]: 280 | val_path_list.append(self.tuple_to_folder_name(val_one_path)) 281 | 282 | for test_one_path in dataset_div['test'][region][1]: 283 | test_path_list.append(self.tuple_to_folder_name(test_one_path)) 284 | 285 | 286 | split_subdirs = {} 287 | split_subdirs['train'] = train_path_list 288 | split_subdirs['val'] = val_path_list 289 | split_subdirs['test'] = test_path_list 290 | 291 | return split_subdirs 292 | def divide_region_n_train_val_test(self): 293 | 294 | region_dict = {'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]} 295 | test_ratio = 0.25 296 | val_ratio = 0.25 297 | 298 | dataset_div= {'train':{'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]}, 299 | 'val' :{'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]}, 300 | 'test' :{'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]}} 301 | 302 | process_edges = [] 303 | # label and compute distance 304 | for i, path in enumerate(self.all_edges): 305 | process_edges.append(label_region_n_compute_distance(i,path)) 306 | 307 | region_dict[process_edges[i][4]][1].append(process_edges[i]) 308 | region_dict[process_edges[i][4]][0] = region_dict[process_edges[i][4]][0] + process_edges[i][3] 309 | 310 | 311 | for region_type, distance_and_path_list in region_dict.items(): 312 | total_distance = distance_and_path_list[0] 313 | 
test_distance = total_distance*test_ratio 314 | val_distance = total_distance*val_ratio 315 | 316 | path_list = distance_and_path_list[1] 317 | tem_list = copy.deepcopy(path_list) 318 | 319 | random.seed(2019) 320 | shuffle(tem_list) 321 | 322 | sum_distance = 0 323 | 324 | # Test Set 325 | while sum_distance < test_distance*0.8: 326 | path = tem_list.pop() 327 | sum_distance += path[3] 328 | dataset_div['test'][region_type][0] = dataset_div['test'][region_type][0] + path[3] 329 | dataset_div['test'][region_type][1].append(path) 330 | 331 | # Val Set 332 | while sum_distance < (test_distance + val_distance)*0.8: 333 | path = tem_list.pop() 334 | sum_distance += path[3] 335 | dataset_div['val'][region_type][0] = dataset_div['val'][region_type][0] + path[3] 336 | dataset_div['val'][region_type][1].append(path) 337 | 338 | # Train Set 339 | dataset_div['train'][region_type][0] = total_distance - sum_distance 340 | dataset_div['train'][region_type][1] = tem_list 341 | 342 | color=['red','green','blue'] 343 | ## Visualiaztion with respect to region 344 | fig, ax = plt.subplots(figsize=(30, 15)) 345 | div_type = 'train' 346 | 347 | vis_txt_height = 800 348 | for div_type in ['train','val','test']: 349 | for region in ['skyscraper','suburban','shopping']: 350 | vis_path_list = dataset_div[div_type][region][1] 351 | for path in vis_path_list: 352 | x = [path[1][0],path[2][0]] 353 | y = [path[1][1],path[2][1]] 354 | 355 | if region == 'skyscraper': 356 | ax.plot(x, y, color='red', zorder=1, lw=3) 357 | elif region == 'suburban': 358 | ax.plot(x, y, color='blue', zorder=1, lw=3) 359 | elif region == 'shopping': 360 | ax.plot(x, y, color='green', zorder=1, lw=3) 361 | 362 | ax.scatter(x, y,color='black', s=120, zorder=2) 363 | 364 | # Visualize distance text 365 | distance = dataset_div[div_type][region][0] 366 | if region == 'skyscraper': 367 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='red') 368 | elif region == 'suburban': 369 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='blue') 370 | elif region == 'shopping': 371 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='green') 372 | vis_txt_height-=30 373 | 374 | plt.savefig('region.png', dpi=200) 375 | plt.close() 376 | 377 | ## Visualization with respect to train/val/test 378 | fig, ax = plt.subplots(figsize=(30, 15)) 379 | div_type = 'train' 380 | vis_txt_height = 800 381 | for div_type in ['train','val','test']: 382 | for region in ['skyscraper','suburban','shopping']: 383 | vis_path_list = dataset_div[div_type][region][1] 384 | for path in vis_path_list: 385 | x = [path[1][0],path[2][0]] 386 | y = [path[1][1],path[2][1]] 387 | 388 | if div_type == 'train': 389 | ax.plot(x, y, color='red', zorder=1, lw=3) 390 | elif div_type == 'val': 391 | ax.plot(x, y, color='blue', zorder=1, lw=3) 392 | elif div_type == 'test': 393 | ax.plot(x, y, color='green', zorder=1, lw=3) 394 | 395 | ax.scatter(x, y,color='black', s=120, zorder=2) 396 | 397 | # Visualize distance text 398 | distance = dataset_div[div_type][region][0] 399 | if div_type == 'train': 400 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='red') 401 | elif div_type == 'val': 402 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='blue') 403 | elif div_type == 'test': 404 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, 
vis_txt_height),fontsize=20,color='green') 405 | vis_txt_height-=30 406 | 407 | #ax.annotate(txt, (x, y)) 408 | plt.savefig('train_val_test.png', dpi=200) 409 | plt.close() 410 | 411 | return dataset_div 412 | def read_selection_label(self, label_type): 413 | 414 | if label_type == 'when2com': 415 | with open(os.path.join(self.root,'gt_when_to_communicate.txt')) as f: 416 | content = f.readlines() 417 | 418 | com_label = {} 419 | for x in content: 420 | key = x.split(' ')[2].strip().split('/')[-3] + '/' + x.split(' ')[2].strip().split('/')[-1]+'.png' 421 | com_label[key] = int(x.split(' ')[1]) 422 | 423 | elif label_type == 'mimo': 424 | with open(os.path.join(self.root,'gt_mimo_communicate.txt')) as f: 425 | content = f.readlines() 426 | com_label = {} 427 | for x in content: 428 | file_key = x.split(' ')[-1].strip().split('/')[-3] + '/' + x.split(' ')[-1].strip().split('/')[-1]+'.png' 429 | noise_label = make_tuple(x.split(' (')[0]) 430 | link_label = make_tuple(x.split(') ')[1] + ')') 431 | com_label[file_key] = torch.tensor([noise_label, link_label]) 432 | 433 | else: 434 | raise ValueError('Unknown label file name '+ str(label_type)) 435 | 436 | 437 | print('Loaded: selection label.') 438 | return com_label 439 | 440 | def convert_link_label(self, link_label): 441 | div_list = [] 442 | for i in link_label: 443 | div_list.append(int(i/2)) 444 | 445 | new_link_label = [] 446 | for i, elem_i in enumerate(div_list): 447 | for j, elem_j in enumerate(div_list): 448 | if j != i and elem_i == elem_j: 449 | new_link_label.append(j) 450 | new_link_label = tuple(new_link_label) 451 | return new_link_label 452 | def get_cam_pos(self, target_view): 453 | if target_view == "overhead": 454 | cam_pos = [ 'overhead', 'front', 'back', 'left', 'right'] 455 | elif target_view == "front": 456 | cam_pos = [ 'front', 'back', 'left', 'right','overhead'] 457 | elif target_view == "back": 458 | cam_pos = [ 'back', 'front', 'left', 'right','overhead'] 459 | elif target_view == "left": 460 | cam_pos = [ 'left', 'back', 'front','right','overhead'] 461 | elif target_view == "target": 462 | cam_pos = [ 'target', 'normal1', 'normal2','normal3','normal4'] 463 | elif target_view == "6agent": 464 | cam_pos = ['agent1', 'agent2', 'agent3', 'agent4', 'agent5', 'agent6'] 465 | elif target_view == "5agent": 466 | cam_pos = ['agent1', 'agent2', 'agent3', 'agent4', 'agent5'] 467 | elif target_view == "DroneNP": 468 | cam_pos = ["DroneNN_main", "DroneNP_main", "DronePN_main", "DronePP_main", "DroneZZ_main"] 469 | elif target_view == "DroneNN_backNN": 470 | cam_pos = ["DroneNN_backNN", "DroneNP_backNP", "DronePN_backPN", "DroneNN_frontNN", "DroneNP_frontNP"] 471 | elif target_view == "5agentv7": 472 | cam_pos = ["agent1", "agent3", "agent5", "agent2", "agent4"] 473 | else: 474 | cam_pos = [ 'front', 'back', 'left', 'right', 'overhead'] 475 | return cam_pos 476 | # <---- Functions for conversion of paths ----> 477 | 478 | 479 | 480 | 481 | def __len__(self): 482 | """__len__""" 483 | return len(self.imgs[self.split][self.cam_pos[0]][self.image_modes[0]]) 484 | 485 | def __getitem__(self, index): 486 | """__getitem__ 487 | 488 | :param index: 489 | """ 490 | img_list = [] 491 | lbl_list = [] 492 | 493 | for k, camera in enumerate(self.cam_pos): 494 | 495 | img_path, mask_path = self.imgs[self.split][camera]['scene'][index], self.imgs[self.split][camera]['segmentation_decoded'][index] 496 | img, mask = np.array(cv2.imread(img_path),dtype=np.uint8)[:,:,:3], np.array(cv2.imread(mask_path),dtype=np.uint8)[:,:,0] 497 | 498 | img = 
cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 499 | lbl = mask 500 | if self.augmentations is not None: 501 | img, lbl, aux = self.augmentations(img, lbl) 502 | 503 | if self.is_transform: 504 | img, lbl = self.transform(img, lbl) 505 | 506 | img_list.append(img) 507 | lbl_list.append(lbl) 508 | 509 | if self.commun_label != 'None': 510 | return img_list, lbl_list, self.com_label[self.split][index] #, self.debug_file_path[self.split][index] 511 | else: 512 | return img_list, lbl_list 513 | 514 | 515 | def transform(self, img, lbl): 516 | 517 | """transform 518 | :param img: 519 | :param lbl: 520 | """ 521 | img = img[:, :, ::-1] # RGB -> BGR 522 | img = img.astype(np.float64) 523 | img -= self.mean 524 | if self.img_norm: 525 | img = img.astype(float) / 255.0 526 | # NHWC -> NCHW 527 | img = img.transpose(2, 0, 1) 528 | 529 | classes = np.unique(lbl) 530 | lbl = lbl.astype(float) 531 | lbl = lbl.astype(int) 532 | 533 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes): 534 | print("after det", classes, np.unique(lbl)) 535 | raise ValueError("Segmentation map contained invalid class values") 536 | 537 | img = torch.from_numpy(img).float() 538 | lbl = torch.from_numpy(lbl).long() 539 | 540 | return img, lbl 541 | 542 | def decode_segmap(self, temp): 543 | r = temp.copy() 544 | g = temp.copy() 545 | b = temp.copy() 546 | for i,name in self.id2name.items(): 547 | r[(temp==i)] = self.name2color[name][0][0] 548 | g[(temp==i)] = self.name2color[name][0][1] 549 | b[(temp==i)] = self.name2color[name][0][2] 550 | 551 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3)) 552 | rgb[:, :, 0] = r / 255.0 553 | rgb[:, :, 1] = g / 255.0 554 | rgb[:, :, 2] = b / 255.0 555 | return rgb 556 | 557 | 558 | def save_tensor_imag(imgs): 559 | import numpy as np 560 | import cv2 561 | bs = imgs[0].shape[0] 562 | mean_rgb = np.array([103.939, 116.779, 123.68]) 563 | 564 | for view in range(len(imgs)): 565 | for i in range(bs): 566 | image = imgs[view][i] 567 | 568 | image = image.cpu().numpy() 569 | image = np.transpose(image, (1, 2, 0)) 570 | image = image * 255 + mean_rgb 571 | cv2.imwrite('debug_tmp/img_b' + str(i) +'_v'+str(view)+ '.png', image) 572 | 573 | 574 | -------------------------------------------------------------------------------- /ptsemseg/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import functools 3 | 4 | from ptsemseg.loss.loss import ( 5 | cross_entropy2d, 6 | bootstrapped_cross_entropy2d, 7 | multi_scale_cross_entropy2d 8 | ) 9 | ## Access to loss functions 10 | 11 | logger = logging.getLogger("ptsemseg") 12 | 13 | key2loss = { 14 | "cross_entropy": cross_entropy2d, 15 | "bootstrapped_cross_entropy": bootstrapped_cross_entropy2d, 16 | "multi_scale_cross_entropy": multi_scale_cross_entropy2d 17 | } 18 | 19 | 20 | def get_loss_function(cfg): 21 | if cfg["training"]["loss"] is None: 22 | logger.info("Using default cross entropy loss") 23 | return cross_entropy2d 24 | 25 | else: 26 | loss_dict = cfg["training"]["loss"] 27 | loss_name = loss_dict["name"] 28 | loss_params = {k: v for k, v in loss_dict.items() if k != "name"} 29 | 30 | if loss_name not in key2loss: 31 | raise NotImplementedError("Loss {} not implemented".format(loss_name)) 32 | 33 | logger.info("Using {} with {} params".format(loss_name, loss_params)) 34 | return functools.partial(key2loss[loss_name], **loss_params) 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /ptsemseg/loss/loss.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import pdb 4 | 5 | def cross_entropy2d(input, target, weight=None, size_average=True): 6 | n, c, h, w = input.size() 7 | nt, ht, wt = target.size() 8 | 9 | # Handle inconsistent size between input and target 10 | if h != ht and w != wt: # upsample labels 11 | input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True) 12 | 13 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 14 | target = target.view(-1) 15 | loss = F.cross_entropy( 16 | input, target, weight=weight, size_average=size_average, ignore_index=250 17 | ) 18 | return loss 19 | 20 | 21 | def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None): 22 | if not isinstance(input, tuple): 23 | return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average) 24 | 25 | # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16] 26 | if scale_weight is None: # scale_weight: torch tensor type 27 | n_inp = len(input) 28 | scale = 0.4 29 | scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to('cuda' if target.is_cuda else 'cpu') 30 | 31 | loss = 0.0 32 | for i, inp in enumerate(input): 33 | loss = loss + scale_weight[i] * cross_entropy2d( 34 | input=inp, target=target, weight=weight, size_average=size_average 35 | ) 36 | 37 | return loss 38 | 39 | 40 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True): 41 | 42 | batch_size = input.size()[0] 43 | 44 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True): 45 | 46 | n, c, h, w = input.size() 47 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) 48 | target = target.view(-1) 49 | loss = F.cross_entropy( 50 | input, target, weight=weight, reduce=False, size_average=False, ignore_index=250 51 | ) 52 | 53 | topk_loss, _ = loss.topk(K) 54 | reduced_topk_loss = topk_loss.sum() / K 55 | 56 | return reduced_topk_loss 57 | 58 | loss = 0.0 59 | # Bootstrap from each image not entire batch 60 | for i in range(batch_size): 61 | loss += _bootstrap_xentropy_single( 62 | input=torch.unsqueeze(input[i], 0), 63 | target=torch.unsqueeze(target[i], 0), 64 | K=K, 65 | weight=weight, 66 | size_average=size_average, 67 | ) 68 | return loss / float(batch_size) 69 | -------------------------------------------------------------------------------- /ptsemseg/metrics.py: -------------------------------------------------------------------------------- 1 | # Adapted from score written by wkentaro 2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py 3 | 4 | import numpy as np 5 | import torch 6 | 7 | class runningScore(object): 8 | def __init__(self, n_classes): 9 | self.n_classes = n_classes 10 | self.confusion_matrix = np.zeros((n_classes, n_classes)) 11 | self.confusion_matrix_pos = np.zeros((n_classes, n_classes)) 12 | self.confusion_matrix_neg = np.zeros((n_classes, n_classes)) 13 | self.total_agent = 0 14 | self.correct_when2com = 0 15 | self.correct_who2com = 0 16 | self.total_bandW = 0 17 | self.count = 0 18 | 19 | def update_bandW(self, bandW): 20 | self.total_bandW += bandW 21 | self.count += 1.0 22 | 23 | def update_selection(self, if_commun_label, commun_label, action_argmax): 24 | if if_commun_label == 'when2com': 25 | action_argmax = torch.squeeze(action_argmax) 26 | commun_label = commun_label + 1 # -1,0,1,2,3 ->0, 1, 2 
,3 ,4 27 | 28 | 29 | self.total_agent += commun_label.size(0) 30 | when_to_commu_label = (commun_label == 0) 31 | 32 | if action_argmax.dim() == 2: 33 | predict_link = (action_argmax > 0.2).nonzero() 34 | link_num = predict_link.shape[0] 35 | when2com_pred = torch.zeros(commun_label.size(0), dtype=torch.int8) 36 | 37 | for row_idx in range(link_num): 38 | sample_idx = predict_link[row_idx,:][0] 39 | link_idx = predict_link[row_idx,:][1] 40 | if link_idx == commun_label[sample_idx]: 41 | self.correct_who2com = self.correct_who2com +1 42 | if link_idx != 0: 43 | when2com_pred[sample_idx] = True 44 | when2com_pred = when2com_pred.cuda() 45 | self.correct_when2com += (when2com_pred == when_to_commu_label).sum().item() 46 | elif action_argmax.dim() == 1: 47 | # Learn when to communicate accuracy 48 | when_to_commu_pred = (action_argmax == 0) 49 | self.correct_when2com += (when_to_commu_pred == when_to_commu_label).sum().item() 50 | 51 | # Learn who to communicate accuracy 52 | self.correct_who2com += (action_argmax == commun_label).sum().item() 53 | else: 54 | assert commun_label.shape == action_argmax.shape, "Shape of selection labels are different." 55 | elif if_commun_label == 'mimo': 56 | 57 | # commun_label = commun_label.cpu() 58 | self.total_agent += commun_label[:,0,:].shape[0]*commun_label[:,0,:].shape[1] 59 | # when2com 60 | id_tensor = torch.arange(action_argmax.shape[1]).repeat(action_argmax.shape[0], 1) 61 | when_to_commu_pred = (action_argmax.cpu() != id_tensor) 62 | when_to_commu_label = commun_label[:,0,:].type(torch.ByteTensor) 63 | self.correct_when2com += (when_to_commu_pred == when_to_commu_label).sum().item() 64 | 65 | # who2com (gpu) who *(need com) 66 | gt_action = commun_label[:,1,:] * commun_label[:,0,:] + id_tensor.cuda()*(1 - commun_label[:,0,:]) 67 | 68 | self.correct_who2com += (action_argmax == gt_action).sum().item() 69 | 70 | def update_div(self, if_commun_label, label_trues, label_preds, commun_label): 71 | # import pdb;pdb.set_trace() 72 | if if_commun_label == 'when2com': 73 | commun_label = commun_label.cpu().numpy() 74 | when2comlab = (commun_label == -1) # -1 ---> noraml # other --> noist need com 75 | elif if_commun_label == 'mimo': 76 | commun_label = commun_label.cpu().numpy()[:, 0, :] 77 | when2comlab = (commun_label == 0) # 0 --> normal # 1 ---> noisy need com 78 | when2comlab = when2comlab.transpose(1, 0) #[batch of agent1, batch of agent2 ] 79 | when2comlab = when2comlab.flatten() 80 | 81 | # import pdb; pdb.set_trace() 82 | pos_idx = (when2comlab == True).nonzero() 83 | neg_idx = (when2comlab == False).nonzero() 84 | 85 | label_trues_pos = label_trues[pos_idx] 86 | label_preds_pos = label_preds[pos_idx] 87 | if label_trues_pos.shape[0] != 0 : 88 | for lt, lp in zip(label_trues_pos, label_preds_pos): 89 | self.confusion_matrix_pos += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 90 | 91 | label_trues_neg = label_trues[neg_idx] 92 | label_preds_neg = label_preds[neg_idx] 93 | if label_trues_neg.shape[0] != 0: 94 | for lt, lp in zip(label_trues_neg, label_preds_neg): 95 | self.confusion_matrix_neg += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 96 | 97 | 98 | 99 | def _fast_hist(self, label_true, label_pred, n_class): 100 | mask = (label_true >= 0) & (label_true < n_class) 101 | hist = np.bincount( 102 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2 103 | ).reshape(n_class, n_class) 104 | return hist 105 | 106 | def update(self, label_trues, label_preds): 107 | for lt, lp in 
zip(label_trues, label_preds): 108 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes) 109 | 110 | def get_avg_bandW(self): 111 | return self.total_bandW/self.count 112 | 113 | def get_only_noise_scores(self): 114 | """Returns accuracy score evaluation result. 115 | - overall accuracy 116 | - mean accuracy 117 | - mean IU 118 | - fwavacc 119 | """ 120 | hist = self.confusion_matrix_neg 121 | acc = np.diag(hist).sum() / hist.sum() 122 | acc_cls = np.diag(hist) / hist.sum(axis=1) 123 | acc_cls = np.nanmean(acc_cls) 124 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 125 | mean_iu = np.nanmean(iu) 126 | freq = hist.sum(axis=1) / hist.sum() 127 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 128 | cls_iu = dict(zip(range(self.n_classes), iu)) 129 | 130 | return ( 131 | { 132 | "Overall Acc: \t": acc, 133 | "Mean Acc : \t": acc_cls, 134 | "FreqW Acc : \t": fwavacc, 135 | "Mean IoU : \t": mean_iu, 136 | }, 137 | cls_iu, 138 | ) 139 | 140 | 141 | def get_only_normal_scores(self): 142 | """Returns accuracy score evaluation result. 143 | - overall accuracy 144 | - mean accuracy 145 | - mean IU 146 | - fwavacc 147 | """ 148 | hist = self.confusion_matrix_pos 149 | acc = np.diag(hist).sum() / hist.sum() 150 | acc_cls = np.diag(hist) / hist.sum(axis=1) 151 | acc_cls = np.nanmean(acc_cls) 152 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 153 | mean_iu = np.nanmean(iu) 154 | freq = hist.sum(axis=1) / hist.sum() 155 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 156 | cls_iu = dict(zip(range(self.n_classes), iu)) 157 | 158 | return ( 159 | { 160 | "Overall Acc: \t": acc, 161 | "Mean Acc : \t": acc_cls, 162 | "FreqW Acc : \t": fwavacc, 163 | "Mean IoU : \t": mean_iu, 164 | }, 165 | cls_iu, 166 | ) 167 | 168 | def get_scores(self): 169 | """Returns accuracy score evaluation result. 
170 | - overall accuracy 171 | - mean accuracy 172 | - mean IU 173 | - fwavacc 174 | """ 175 | hist = self.confusion_matrix 176 | acc = np.diag(hist).sum() / hist.sum() 177 | acc_cls = np.diag(hist) / hist.sum(axis=1) 178 | acc_cls = np.nanmean(acc_cls) 179 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 180 | mean_iu = np.nanmean(iu) 181 | freq = hist.sum(axis=1) / hist.sum() 182 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 183 | cls_iu = dict(zip(range(self.n_classes), iu)) 184 | 185 | return ( 186 | { 187 | "Overall Acc: \t": acc, 188 | "Mean Acc : \t": acc_cls, 189 | "FreqW Acc : \t": fwavacc, 190 | "Mean IoU : \t": mean_iu, 191 | }, 192 | cls_iu, 193 | ) 194 | 195 | 196 | def get_selection_accuracy(self): 197 | when_com_accuacy = self.correct_when2com / self.total_agent * 100 198 | who_com_accuracy = self.correct_who2com / self.total_agent * 100 199 | 200 | return when_com_accuacy, who_com_accuracy 201 | 202 | 203 | 204 | 205 | def reset(self): 206 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes)) 207 | self.total_agent = 0 208 | self.correct_when2com = 0 209 | self.correct_who2com = 0 210 | self.total_bandW = 0 211 | self.count = 0 212 | 213 | 214 | def print_score(self,n_classes, score, class_iou): 215 | metric_string = "" 216 | class_string = "" 217 | 218 | for i in range(n_classes): 219 | # print(i, class_iou[i]) 220 | metric_string = metric_string + " " + str(i) 221 | class_string = class_string + " " + str(round(class_iou[i] * 100, 2)) 222 | 223 | for k, v in score.items(): 224 | metric_string = metric_string + " " + str(k) 225 | class_string = class_string + " " + str(round(v * 100, 2)) 226 | # print(k, v) 227 | print(metric_string) 228 | print(class_string) 229 | 230 | 231 | class averageMeter(object): 232 | """Computes and stores the average and current value""" 233 | 234 | def __init__(self): 235 | self.reset() 236 | 237 | def reset(self): 238 | self.val = 0 239 | self.avg = 0 240 | self.sum = 0 241 | self.count = 0 242 | 243 | def update(self, val, n=1): 244 | self.val = val 245 | self.sum += val * n 246 | self.count += n 247 | self.avg = self.sum / self.count 248 | -------------------------------------------------------------------------------- /ptsemseg/models/__init__.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torchvision.models as models 3 | 4 | 5 | from ptsemseg.models.agent import Single_agent, All_agents, LearnWho2Com, LearnWhen2Com, MIMOcom,MIMO_All_agents, MIMOcomWho 6 | 7 | 8 | def get_model(model_dict, n_classes, version=None): 9 | name = model_dict["model"]["arch"] 10 | 11 | model = _get_model_instance(name) 12 | in_channels = 3 13 | if name == "Single_agent": 14 | model = model(n_classes=n_classes, in_channels=in_channels, 15 | enc_backbone=model_dict["model"]['enc_backbone'], 16 | dec_backbone=model_dict["model"]['dec_backbone'], 17 | feat_squeezer=model_dict["model"]['feat_squeezer'], 18 | feat_channel=model_dict["model"]['feat_channel']) 19 | elif name == "All_agents": 20 | model = model(n_classes=n_classes, in_channels=in_channels, 21 | aux_agent_num=model_dict["model"]['agent_num'], 22 | shuffle_flag=model_dict["model"]['shuffle_features'], 23 | enc_backbone=model_dict["model"]['enc_backbone'], 24 | dec_backbone=model_dict["model"]['dec_backbone'], 25 | feat_squeezer=model_dict["model"]['feat_squeezer'], 26 | feat_channel=model_dict["model"]['feat_channel']) 27 | elif name == "MIMO_All_agents": 28 | model = model(n_classes=n_classes, 
in_channels=in_channels, 29 | aux_agent_num=model_dict["model"]['agent_num'], 30 | shuffle_flag=model_dict["model"]['shuffle_features'], 31 | enc_backbone=model_dict["model"]['enc_backbone'], 32 | dec_backbone=model_dict["model"]['dec_backbone'], 33 | feat_squeezer=model_dict["model"]['feat_squeezer'], 34 | feat_channel=model_dict["model"]['feat_channel']) 35 | 36 | elif name == "LearnWho2Com": 37 | model = model(n_classes=n_classes, in_channels=in_channels, 38 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'], 39 | sparse=model_dict["model"]['sparse'], 40 | aux_agent_num=model_dict["model"]['agent_num'], 41 | shared_img_encoder=model_dict["model"]["shared_img_encoder"], 42 | image_size=model_dict["data"]["img_rows"], 43 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"], 44 | enc_backbone=model_dict["model"]['enc_backbone'], 45 | dec_backbone=model_dict["model"]['dec_backbone'] 46 | ) 47 | elif name == "LearnWhen2Com": 48 | model = model(n_classes=n_classes, in_channels=in_channels, 49 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'], 50 | sparse=model_dict["model"]['sparse'], 51 | aux_agent_num=model_dict["model"]['agent_num'], 52 | shared_img_encoder=model_dict["model"]["shared_img_encoder"], 53 | image_size=model_dict["data"]["img_rows"], 54 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"], 55 | enc_backbone=model_dict["model"]['enc_backbone'], 56 | dec_backbone=model_dict["model"]['dec_backbone'] 57 | ) 58 | elif name == "MIMOcom": 59 | model = model(n_classes=n_classes, in_channels=in_channels, 60 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'], 61 | sparse=model_dict["model"]['sparse'], 62 | agent_num=model_dict["model"]['agent_num'], 63 | shared_img_encoder=model_dict["model"]["shared_img_encoder"], 64 | image_size=model_dict["data"]["img_rows"], 65 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"], 66 | enc_backbone=model_dict["model"]['enc_backbone'], 67 | dec_backbone=model_dict["model"]['dec_backbone'] 68 | ) 69 | elif name == "MIMOcomWho": 70 | model = model(n_classes=n_classes, in_channels=in_channels, 71 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'], 72 | sparse=model_dict["model"]['sparse'], 73 | agent_num=model_dict["model"]['agent_num'], 74 | shared_img_encoder=model_dict["model"]["shared_img_encoder"], 75 | image_size=model_dict["data"]["img_rows"], 76 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"], 77 | enc_backbone=model_dict["model"]['enc_backbone'], 78 | dec_backbone=model_dict["model"]['dec_backbone'] 79 | ) 80 | 81 | else: 82 | model = model(n_classes=n_classes, in_channels=in_channels, 83 | enc_backbone=model_dict["model"]['enc_backbone'], 84 | dec_backbone=model_dict["model"]['dec_backbone']) 85 | 86 | return model 87 | 88 | 89 | def _get_model_instance(name): 90 | try: 91 | return { 92 | "Single_agent": Single_agent, 93 | "All_agents": All_agents, 94 | "MIMO_All_agents": MIMO_All_agents, 95 | 'LearnWho2Com':LearnWho2Com, 96 | 'LearnWhen2Com': LearnWhen2Com, 97 | 'MIMOcom': MIMOcom, 98 | 'MIMOcomWho': MIMOcomWho, 99 | }[name] 100 | except: 101 | raise ("Model {} not available".format(name)) 102 | -------------------------------------------------------------------------------- /ptsemseg/models/agent.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | import pretrainedmodels 6 | import copy 7 | 8 | from ptsemseg.models.utils import conv2DBatchNormRelu, deconv2DBatchNormRelu, Sparsemax 9 | import random 10 | from torch.distributions.categorical import Categorical 11 | 12 | from ptsemseg.models.backbone import n_segnet_encoder, resnet_encoder, n_segnet_decoder, simple_decoder, FCN_decoder 13 | import numpy as np 14 | 15 | 16 | def get_encoder(name): 17 | try: 18 | return { 19 | "n_segnet_encoder": n_segnet_encoder, 20 | "resnet_encoder": resnet_encoder, 21 | }[name] 22 | except: 23 | raise ("Encoder {} not available".format(name)) 24 | 25 | 26 | def get_decoder(name): 27 | try: 28 | return { 29 | "n_segnet_decoder": n_segnet_decoder, 30 | "simple_decoder": simple_decoder, 31 | "FCN_decoder": FCN_decoder 32 | 33 | }[name] 34 | except: 35 | raise ("Decoder {} not available".format(name)) 36 | 37 | 38 | ### ============= Modules ============= ### 39 | class img_encoder(nn.Module): 40 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, 41 | enc_backbone='n_segnet_encoder'): 42 | super(img_encoder, self).__init__() 43 | feat_chn = 256 44 | 45 | self.feature_backbone = get_encoder(enc_backbone)(n_classes=n_classes, in_channels=in_channels) 46 | self.feat_squeezer = feat_squeezer 47 | 48 | # squeeze the feature map size 49 | if feat_squeezer == 2: # resolution/2 50 | self.squeezer = conv2DBatchNormRelu(512, feat_channel, k_size=3, stride=2, padding=1) 51 | elif feat_squeezer == 4: # resolution/4 52 | self.squeezer = conv2DBatchNormRelu(512, feat_channel, k_size=3, stride=4, padding=1) 53 | else: 54 | self.squeezer = conv2DBatchNormRelu(512, feat_channel, k_size=3, stride=1, padding=1) 55 | 56 | def forward(self, inputs): 57 | outputs = self.feature_backbone(inputs) 58 | outputs = self.squeezer(outputs) 59 | 60 | return outputs 61 | 62 | 63 | class img_decoder(nn.Module): 64 | def __init__(self, n_classes=21, in_channels=512, agent_num=5, feat_squeezer=-1, dec_backbone='n_segnet_decoder'): 65 | super(img_decoder, self).__init__() 66 | 67 | self.feat_squeezer = feat_squeezer 68 | if feat_squeezer == 2: # resolution/2 69 | self.desqueezer = deconv2DBatchNormRelu(in_channels, in_channels, k_size=3, stride=2, padding=1, 70 | output_padding=1) 71 | self.output_decoder = get_decoder(dec_backbone)(n_classes=n_classes, in_channels=in_channels) 72 | 73 | elif feat_squeezer == 4: # resolution/4 74 | self.desqueezer1 = deconv2DBatchNormRelu(in_channels, 512, k_size=3, stride=2, padding=1, output_padding=1) 75 | self.desqueezer2 = deconv2DBatchNormRelu(512, 512, k_size=3, stride=2, padding=1, output_padding=1) 76 | self.output_decoder = get_decoder(dec_backbone)(n_classes=n_classes, in_channels=512) 77 | else: 78 | self.output_decoder = get_decoder(dec_backbone)(n_classes=n_classes, in_channels=in_channels) 79 | 80 | def forward(self, inputs): 81 | if self.feat_squeezer == 2: # resolution/2 82 | inputs = self.desqueezer(inputs) 83 | 84 | elif self.feat_squeezer == 4: # resolution/4 85 | inputs = self.desqueezer1(inputs) 86 | inputs = self.desqueezer2(inputs) 87 | 88 | outputs = self.output_decoder(inputs) 89 | return outputs 90 | 91 | 92 | class msg_generator(nn.Module): 93 | def __init__(self, in_channels=512, message_size=32): 94 | super(msg_generator, self).__init__() 95 | self.in_channels = in_channels 96 | 97 | # 
Encoder 98 | # down 1 99 | self.conv1 = conv2DBatchNormRelu(self.in_channels, 256, k_size=3, stride=1, padding=1) 100 | self.conv2 = conv2DBatchNormRelu(256, 128, k_size=3, stride=1, padding=1) 101 | self.conv3 = conv2DBatchNormRelu(128, 64, k_size=3, stride=1, padding=1) 102 | self.conv4 = conv2DBatchNormRelu(64, 64, k_size=3, stride=1, padding=1) 103 | self.conv5 = conv2DBatchNormRelu(64, message_size, k_size=3, stride=1, padding=1) 104 | 105 | def forward(self, inputs): 106 | outputs = self.conv1(inputs) 107 | outputs = self.conv2(outputs) 108 | outputs = self.conv3(outputs) 109 | outputs = self.conv4(outputs) 110 | outputs = self.conv5(outputs) 111 | return outputs 112 | 113 | 114 | class policy_net4(nn.Module): 115 | def __init__(self, n_classes=21, in_channels=512, input_feat_sz=32, enc_backbone='n_segnet_encoder'): 116 | super(policy_net4, self).__init__() 117 | self.in_channels = in_channels 118 | 119 | feat_map_sz = input_feat_sz // 4 120 | self.n_feat = int(256 * feat_map_sz * feat_map_sz) 121 | 122 | self.img_encoder = img_encoder(n_classes=n_classes, in_channels=self.in_channels, enc_backbone=enc_backbone) 123 | 124 | # Encoder 125 | # down 1 126 | self.conv1 = conv2DBatchNormRelu(512, 512, k_size=3, stride=1, padding=1) 127 | self.conv2 = conv2DBatchNormRelu(512, 256, k_size=3, stride=1, padding=1) 128 | self.conv3 = conv2DBatchNormRelu(256, 256, k_size=3, stride=2, padding=1) 129 | 130 | # down 2 131 | self.conv4 = conv2DBatchNormRelu(256, 256, k_size=3, stride=1, padding=1) 132 | self.conv5 = conv2DBatchNormRelu(256, 256, k_size=3, stride=2, padding=1) 133 | 134 | def forward(self, features_map): 135 | outputs1 = self.img_encoder(features_map) 136 | 137 | outputs = self.conv1(outputs1) 138 | outputs = self.conv2(outputs) 139 | outputs = self.conv3(outputs) 140 | outputs = self.conv4(outputs) 141 | outputs = self.conv5(outputs) 142 | return outputs 143 | 144 | 145 | class km_generator(nn.Module): 146 | def __init__(self, out_size=128, input_feat_sz=32): 147 | super(km_generator, self).__init__() 148 | feat_map_sz = input_feat_sz // 4 149 | self.n_feat = int(256 * feat_map_sz * feat_map_sz) 150 | self.fc = nn.Sequential( 151 | nn.Linear(self.n_feat, 256), # 152 | nn.ReLU(inplace=True), 153 | nn.Linear(256, 128), # 154 | nn.ReLU(inplace=True), 155 | nn.Linear(128, out_size)) # 156 | 157 | def forward(self, features_map): 158 | outputs = self.fc(features_map.view(-1, self.n_feat)) 159 | return outputs 160 | 161 | 162 | class linear(nn.Module): 163 | def __init__(self, out_size=128, input_feat_sz=32): 164 | super(linear, self).__init__() 165 | feat_map_sz = input_feat_sz // 4 166 | self.n_feat = int(256 * feat_map_sz * feat_map_sz) 167 | 168 | self.fc = nn.Sequential( 169 | nn.Linear(self.n_feat, 256), 170 | nn.ReLU(inplace=True), 171 | nn.Linear(256, 128), 172 | nn.ReLU(inplace=True), 173 | nn.Linear(128, out_size) 174 | ) 175 | 176 | def forward(self, features_map): 177 | outputs = self.fc(features_map.view(-1, self.n_feat)) 178 | return outputs 179 | 180 | 181 | class conv(nn.Module): 182 | def __init__(self, out_size=128): 183 | super(conv, self).__init__() 184 | feat_map_sz = input_feat_sz // 4 185 | self.conv = conv2DBatchNormRelu(256, out_size, k_size=1, stride=1, padding=1) 186 | 187 | def forward(self, features_map): 188 | outputs = self.conv(features_map) 189 | return outputs 190 | 191 | 192 | 193 | # <------ Attention ------> # 194 | class ScaledDotProductAttention(nn.Module): 195 | ''' Scaled Dot-Product Attention ''' 196 | 197 | def __init__(self, temperature, 
attn_dropout=0.1): 198 | super().__init__() 199 | self.temperature = temperature 200 | self.sparsemax = Sparsemax(dim=1) 201 | self.softmax = nn.Softmax(dim=1) 202 | 203 | def forward(self, q, k, v, sparse=True): 204 | attn_orig = torch.bmm(k, q.transpose(2, 1)) 205 | attn_orig = attn_orig / self.temperature 206 | if sparse: 207 | attn_orig = self.sparsemax(attn_orig) 208 | else: 209 | attn_orig = self.softmax(attn_orig) 210 | attn = torch.unsqueeze(torch.unsqueeze(attn_orig, 3), 4) 211 | output = attn * v # (batch,4,channel,size,size) 212 | output = output.sum(1) # (batch,1,channel,size,size) 213 | return output, attn_orig.transpose(2, 1) 214 | 215 | class AdditiveAttentin(nn.Module): 216 | def __init__(self): 217 | super().__init__() 218 | # self.dropout = nn.Dropout(attn_dropout) 219 | self.softmax = nn.Softmax(dim=1) 220 | self.sparsemax = Sparsemax(dim=1) 221 | self.linear_feat = nn.Linear(128, 128) 222 | self.linear_context = nn.Linear(128, 128) 223 | self.linear_out = nn.Linear(128, 1) 224 | 225 | def forward(self, q, k, v, sparse=True): 226 | # q (batch,1,128) 227 | # k (batch,4,128) 228 | # v (batch,4,channel,size,size) 229 | temp1 = self.linear_feat(k) # (batch,4,128) 230 | temp2 = self.linear_context(q) # (batch,1,128) 231 | attn_orig = self.linear_out(temp1 + temp2) # (batch,4,1) 232 | if sparse: 233 | attn_orig = self.sparsemax(attn_orig) # (batch,4,1) 234 | else: 235 | attn_orig = self.softmax(attn_orig) # (batch,4,1) 236 | attn = torch.unsqueeze(torch.unsqueeze(attn_orig, 3), 4) # (batch,4,1,1,1) 237 | output = attn * v 238 | output = output.sum(1) # (batch,1,channel,size,size) 239 | return output, attn_orig.transpose(2, 1) 240 | 241 | # MIMO (non warp) 242 | class MIMOGeneralDotProductAttention(nn.Module): 243 | ''' Scaled Dot-Product Attention ''' 244 | 245 | def __init__(self, query_size, key_size, attn_dropout=0.1): 246 | super().__init__() 247 | self.sparsemax = Sparsemax(dim=1) 248 | self.softmax = nn.Softmax(dim=1) 249 | self.linear = nn.Linear(query_size, key_size) 250 | print('Msg size: ',query_size,' Key size: ', key_size) 251 | 252 | def forward(self, qu, k, v, sparse=True): 253 | # qu (batch,5,32) 254 | # k (batch,5,1024) 255 | # v (batch,5,channel,size,size) 256 | query = self.linear(qu) # (batch,5,key_size) 257 | 258 | # normalization 259 | # query_norm = query.norm(p=2,dim=2).unsqueeze(2).expand_as(query) 260 | # query = query.div(query_norm + 1e-9) 261 | 262 | # k_norm = k.norm(p=2,dim=2).unsqueeze(2).expand_as(k) 263 | # k = k.div(k_norm + 1e-9) 264 | 265 | 266 | 267 | # generate the 268 | attn_orig = torch.bmm(k, query.transpose(2, 1)) # (batch,5,5) column: differnt keys and the same query 269 | 270 | # scaling [not sure] 271 | # scaling = torch.sqrt(torch.tensor(k.shape[2],dtype=torch.float32)).cuda() 272 | # attn_orig = attn_orig/ scaling # (batch,5,5) column: differnt keys and the same query 273 | 274 | attn_orig_softmax = self.softmax(attn_orig) # (batch,5,5) 275 | 276 | attn_shape = attn_orig_softmax.shape 277 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2] 278 | attn_orig_softmax_exp = attn_orig_softmax.view(bats, key_num, query_num, 1, 1, 1) 279 | 280 | v_exp = torch.unsqueeze(v, 2) 281 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1) 282 | 283 | output = attn_orig_softmax_exp * v_exp # (batch,4,channel,size,size) 284 | output_sum = output.sum(1) # (batch,1,channel,size,size) 285 | 286 | return output_sum, attn_orig_softmax 287 | 288 | # MIMO always com 289 | class MIMOWhoGeneralDotProductAttention(nn.Module): 290 | ''' 
Scaled Dot-Product Attention ''' 291 | 292 | def __init__(self, query_size, key_size, attn_dropout=0.1): 293 | super().__init__() 294 | self.sparsemax = Sparsemax(dim=1) 295 | self.softmax = nn.Softmax(dim=1) 296 | self.linear = nn.Linear(query_size, key_size) 297 | print('Msg size: ',query_size,' Key size: ', key_size) 298 | 299 | def forward(self, qu, k, v, sparse=True): 300 | # qu (batch,5,32) 301 | # k (batch,5,1024) 302 | # v (batch,5,channel,size,size) 303 | query = self.linear(qu) # (batch,5,key_size) 304 | 305 | 306 | attn_orig = torch.bmm(k, query.transpose(2, 1)) # (batch,5,5) column: differnt keys and the same query 307 | 308 | 309 | # remove the diagonal and softmax 310 | del_diag_att_orig = [] 311 | for bi in range(attn_orig.shape[0]): 312 | up = torch.triu(attn_orig[bi],diagonal=1,out=None)[:-1,] 313 | dow = torch.tril(attn_orig[bi],diagonal=-1,out=None)[1:,] 314 | del_diag_att_orig_per_sample = torch.unsqueeze((up+dow),dim=0) 315 | del_diag_att_orig.append(del_diag_att_orig_per_sample) 316 | del_diag_att_orig = torch.cat(tuple(del_diag_att_orig), dim=0) 317 | 318 | attn_orig_softmax = self.softmax(del_diag_att_orig) # (batch,5,5) 319 | 320 | append_att_orig = [] 321 | for bi in range(attn_orig_softmax.shape[0]): 322 | up = torch.triu(attn_orig_softmax[bi],diagonal=1,out=None) 323 | up_ext = torch.cat((up, torch.zeros((1, up.shape[1])).cuda())) 324 | dow = torch.tril(attn_orig_softmax[bi],diagonal=0,out=None) 325 | dow_ext = torch.cat((torch.zeros((1, dow.shape[1])).cuda(), dow)) 326 | 327 | append_att_orig_per_sample = torch.unsqueeze((up_ext + dow_ext),dim=0) 328 | append_att_orig.append(append_att_orig_per_sample) 329 | append_att_orig = torch.cat(tuple(append_att_orig), dim=0) 330 | 331 | 332 | 333 | attn_shape = append_att_orig.shape 334 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2] 335 | attn_orig_softmax_exp = append_att_orig.view(bats, key_num, query_num, 1, 1, 1) 336 | 337 | v_exp = torch.unsqueeze(v, 2) 338 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1) 339 | 340 | output = attn_orig_softmax_exp * v_exp # (batch,4,channel,size,size) 341 | output_sum = output.sum(1) # (batch,1,channel,size,size) 342 | 343 | return output_sum, append_att_orig 344 | 345 | class GeneralDotProductAttention(nn.Module): 346 | ''' Scaled Dot-Product Attention ''' 347 | 348 | def __init__(self, query_size, key_size, attn_dropout=0.1): 349 | super().__init__() 350 | self.sparsemax = Sparsemax(dim=1) 351 | self.softmax = nn.Softmax(dim=1) 352 | self.linear = nn.Linear(query_size, key_size) 353 | print('Msg size: ',query_size,' Key size: ', key_size) 354 | 355 | def forward(self, q, k, v, sparse=True): 356 | # q (batch,1,128) 357 | # k (batch,4,128) 358 | # v (batch,4,channel*size*size) 359 | query = self.linear(q) # (batch,1,key_size) 360 | attn_orig = torch.bmm(k, query.transpose(2, 1)) # (batch,4,1) 361 | if sparse: 362 | attn_orig = self.sparsemax(attn_orig) # (batch,4,1) 363 | else: 364 | attn_orig = self.softmax(attn_orig) # (batch,4,1) 365 | attn = torch.unsqueeze(torch.unsqueeze(attn_orig, 3), 4) # (batch,4,1,1,1) 366 | output = attn * v # (batch,4,channel,size,size) 367 | output = output.sum(1) # (batch,1,channel,size,size) 368 | return output, attn_orig.transpose(2, 1) 369 | 370 | # ============= Single normal and Single degarded ============= # 371 | 372 | 373 | # ======================= Model ========================= 374 | 375 | class Single_agent(nn.Module): 376 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, 
enc_backbone='n_segnet_encoder', 377 | dec_backbone='n_segnet_decoder', feat_squeezer=-1): 378 | ''' 379 | feat_squeezer: -1 (No squeeze), 380 | ''' 381 | super(Single_agent, self).__init__() 382 | self.in_channels = in_channels 383 | 384 | # Encoder 385 | self.encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 386 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 387 | 388 | # Decoder 389 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 1, feat_squeezer=feat_squeezer, 390 | dec_backbone=dec_backbone) 391 | 392 | def forward(self, inputs): 393 | feature_map = self.encoder(inputs) 394 | pred = self.decoder(feature_map) 395 | return pred 396 | 397 | 398 | # Randomly selection baseline and Concatenation of all observations 399 | class All_agents(nn.Module): 400 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, aux_agent_num=4, shuffle_flag=False, 401 | enc_backbone='n_segnet_encoder', dec_backbone='n_segnet_decoder', feat_squeezer=-1): 402 | super(All_agents, self).__init__() 403 | 404 | self.agent_num = aux_agent_num 405 | self.in_channels = in_channels 406 | self.shuffle_flag = shuffle_flag 407 | 408 | # Encoder 409 | self.encoder1 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 410 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 411 | self.encoder2 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 412 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 413 | self.encoder3 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 414 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 415 | self.encoder4 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 416 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 417 | self.encoder5 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 418 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 419 | 420 | # Decoder for interested agent 421 | if self.shuffle_flag == 'selection': # random selection 422 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 2 , 423 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 424 | else: # catall 425 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * self.agent_num, 426 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 427 | 428 | def divide_inputs(self, inputs): 429 | ''' 430 | Divide the input into a list of several images 431 | ''' 432 | input_list = [] 433 | divide_num = 5 434 | for i in range(divide_num): 435 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :]) 436 | 437 | return input_list 438 | 439 | def forward(self, inputs): 440 | # agent_num = 5 441 | 442 | input_list = self.divide_inputs(inputs) 443 | feat_map1 = self.encoder1(input_list[0]) 444 | feat_map2 = self.encoder2(input_list[1]) 445 | feat_map3 = self.encoder3(input_list[2]) 446 | feat_map4 = self.encoder4(input_list[3]) 447 | feat_map5 = self.encoder5(input_list[4]) 448 | 449 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers 450 | aux_view_feats = [feat_map1,feat_map2, feat_map3, feat_map4, feat_map5] 451 | aux_id = random.randint(0, 4) 452 | aux_view_feats = torch.unsqueeze(aux_view_feats[aux_id], 0) 453 | feat_map_list = (feat_map1,) + tuple(aux_view_feats) 454 | argmax_action = torch.ones(feat_map1.shape[0], 
dtype=torch.long)*aux_id 455 | 456 | elif self.shuffle_flag == 'fixed2': 457 | feat_map_list = (feat_map1, feat_map2) 458 | 459 | else: 460 | feat_map_list = (feat_map1, feat_map2, feat_map3, feat_map4, feat_map5) 461 | 462 | # combine the feat maps 463 | concat_featmaps = torch.cat(feat_map_list, 1) 464 | pred = self.decoder(concat_featmaps) 465 | 466 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers 467 | return pred, argmax_action.cuda() 468 | else: 469 | return pred 470 | 471 | 472 | class LearnWho2Com(nn.Module): 473 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive', 474 | has_query=True, sparse=False, aux_agent_num=4, shuffle_flag=False, image_size=512, 475 | shared_img_encoder=False \ 476 | , key_size=128, query_size=128, enc_backbone='n_segnet_encoder', dec_backbone='n_segnet_decoder'): 477 | super(LearnWho2Com, self).__init__() 478 | # agent_num = 2 479 | self.aux_agent_num = aux_agent_num 480 | self.in_channels = in_channels 481 | self.shuffle_flag = shuffle_flag 482 | self.feature_map_channel = 512 483 | self.key_size = key_size 484 | self.query_size = query_size 485 | self.shared_img_encoder = shared_img_encoder 486 | self.has_query = has_query 487 | self.sparse = sparse 488 | # Encoder 489 | # Non-shared 490 | 491 | if self.shared_img_encoder == 'unified': 492 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 493 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 494 | elif self.shared_img_encoder == 'only_normal_agents': 495 | self.degarded_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 496 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 497 | self.normal_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 498 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 499 | 500 | else: 501 | self.encoder1 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 502 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 503 | self.encoder2 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 504 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 505 | self.encoder3 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 506 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 507 | self.encoder4 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 508 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 509 | self.encoder5 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 510 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 511 | 512 | # # Message generator 513 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone) 514 | if self.has_query: 515 | self.query_net = linear(out_size=self.query_size, input_feat_sz=image_size / 32) 516 | 517 | self.key_net = linear(out_size=self.key_size, input_feat_sz=image_size / 32) 518 | if attention == 'additive': 519 | self.attention_net = AdditiveAttentin() 520 | elif attention == 'general': 521 | self.attention_net = GeneralDotProductAttention(self.query_size, self.key_size) 522 | else: 523 | self.attention_net = ScaledDotProductAttention(128 ** 0.5) 524 | 525 | # Segmentation decoder 526 | 527 | self.decoder = 
img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel * 2, 528 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 529 | # List the parameters of each modules 530 | self.attention_paras = list(self.attention_net.parameters()) 531 | if self.shared_img_encoder == 'unified': 532 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters()) 533 | elif self.shared_img_encoder == 'only_normal_agents': 534 | self.img_net_paras = list(self.degarded_encoder.parameters()) + list( 535 | self.normal_encoder.parameters()) + list(self.decoder.parameters()) 536 | else: 537 | self.img_net_paras = list(self.encoder1.parameters()) + \ 538 | list(self.encoder2.parameters()) + \ 539 | list(self.encoder3.parameters()) + \ 540 | list(self.encoder4.parameters()) + \ 541 | list(self.encoder5.parameters()) + \ 542 | list(self.decoder.parameters()) 543 | 544 | self.policy_net_paras = list(self.query_key_net.parameters()) + list( 545 | self.key_net.parameters()) + self.attention_paras 546 | if self.has_query: 547 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters()) 548 | 549 | self.all_paras = self.img_net_paras + self.policy_net_paras 550 | 551 | def divide_inputs(self, inputs): 552 | ''' 553 | Divide the input into a list of several images 554 | ''' 555 | input_list = [] 556 | divide_num = 5 557 | for i in range(divide_num): 558 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :]) 559 | 560 | return input_list 561 | 562 | def shape_feat_map(self, feat, size): 563 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1) 564 | 565 | def forward(self, inputs, training=True, inference='argmax'): 566 | 567 | batch_size, _, _, _ = inputs.size() 568 | input_list = self.divide_inputs(inputs) 569 | if self.shared_img_encoder == 'unified': 570 | # vectorize 571 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0) 572 | feat_map = self.u_encoder(unified_feat_map) 573 | feat_map1 = feat_map[0:batch_size * 1] 574 | feat_map2 = feat_map[batch_size * 1:batch_size * 2] 575 | feat_map3 = feat_map[batch_size * 2:batch_size * 3] 576 | feat_map4 = feat_map[batch_size * 3:batch_size * 4] 577 | feat_map5 = feat_map[batch_size * 4:batch_size * 5] 578 | 579 | elif self.shared_img_encoder == 'only_normal_agents': 580 | feat_map1 = self.degarded_encoder(input_list[0]) 581 | # vectorize 582 | unified_normal_feat_map = torch.cat((input_list[1], input_list[2], input_list[3], input_list[4]), 0) 583 | feat_map = self.normal_encoder(unified_normal_feat_map) 584 | feat_map2 = feat_map[0:batch_size * 1] 585 | feat_map3 = feat_map[batch_size * 1:batch_size * 2] 586 | feat_map4 = feat_map[batch_size * 2:batch_size * 3] 587 | feat_map5 = feat_map[batch_size * 3:batch_size * 4] 588 | else: 589 | feat_map1 = self.encoder1(input_list[0]) 590 | feat_map2 = self.encoder2(input_list[1]) 591 | feat_map3 = self.encoder3(input_list[2]) 592 | feat_map4 = self.encoder4(input_list[3]) 593 | feat_map5 = self.encoder5(input_list[4]) 594 | 595 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0) 596 | query_key_map = self.query_key_net(unified_feat_map) 597 | query_key_map1 = query_key_map[0:batch_size * 1] 598 | query_key_map2 = query_key_map[batch_size * 1:batch_size * 2] 599 | query_key_map3 = query_key_map[batch_size * 2:batch_size * 3] 600 | query_key_map4 = query_key_map[batch_size * 3:batch_size * 4] 601 | query_key_map5 = query_key_map[batch_size 
* 4:batch_size * 5] 602 | 603 | key2 = torch.unsqueeze(self.key_net(query_key_map2), 1) 604 | key3 = torch.unsqueeze(self.key_net(query_key_map3), 1) 605 | key4 = torch.unsqueeze(self.key_net(query_key_map4), 1) 606 | key5 = torch.unsqueeze(self.key_net(query_key_map5), 1) 607 | 608 | if self.has_query: 609 | query = torch.unsqueeze(self.query_net(query_key_map1), 1) # (batch,1,128) 610 | else: 611 | query = torch.ones(batch_size, 1, self.query_size).to('cuda') 612 | 613 | feat_map2 = torch.unsqueeze(feat_map2, 1) # (batch,1,channel,size,size) 614 | feat_map3 = torch.unsqueeze(feat_map3, 1) # (batch,1,channel,size,size) 615 | feat_map4 = torch.unsqueeze(feat_map4, 1) # (batch,1,channel,size,size) 616 | feat_map5 = torch.unsqueeze(feat_map5, 1) # (batch,1,channel,size,size) 617 | 618 | keys = torch.cat((key2, key3, key4, key5), 1) # (batch,4,128) 619 | vals = torch.cat((feat_map2, feat_map3, feat_map4, feat_map5), 1) # (batch,4,channel,size,size) 620 | aux_feat, prob_action = self.attention_net(query, keys, vals, 621 | sparse=self.sparse) # (batch,1,#channel*size*size),(batch,1,4) 622 | # print('Action: ', prob_action) 623 | concat_featmaps = torch.cat((feat_map1, aux_feat), 1) 624 | pred = self.decoder(concat_featmaps) 625 | if training: 626 | action = torch.argmax(prob_action, dim=2) 627 | return pred, prob_action, action 628 | else: 629 | if inference == 'softmax': 630 | action = torch.argmax(prob_action, dim=2) 631 | return pred, prob_action, action 632 | elif inference == 'argmax_test': 633 | action = torch.argmax(prob_action, dim=2) 634 | aux_feat_list = [] 635 | for k in range(batch_size): 636 | if action[k] == 0: 637 | aux_feat_list.append(torch.unsqueeze(feat_map2[k], 0)) 638 | elif action[k] == 1: 639 | aux_feat_list.append(torch.unsqueeze(feat_map3[k], 0)) 640 | elif action[k] == 2: 641 | aux_feat_list.append(torch.unsqueeze(feat_map4[k], 0)) 642 | elif action[k] == 3: 643 | aux_feat_list.append(torch.unsqueeze(feat_map5[k], 0)) 644 | else: 645 | raise ValueError('Incorrect action') 646 | aux_feat_argmax = tuple(aux_feat_list) 647 | aux_feat_argmax = torch.cat(aux_feat_argmax, 0) 648 | aux_feat_argmax = torch.squeeze(aux_feat_argmax, 1) 649 | concat_featmaps_argmax = torch.cat((feat_map1.detach(), aux_feat_argmax.detach()), 1) 650 | pred_argmax = self.decoder(concat_featmaps_argmax) 651 | return pred_argmax, prob_action, action 652 | elif inference == 'argmax_train': 653 | action = torch.argmax(prob_action, dim=2) 654 | aux_feat_list = [] 655 | for k in range(batch_size): 656 | if action[k] == 0: 657 | aux_feat_list.append(torch.unsqueeze(feat_map2[k], 0)) 658 | elif action[k] == 1: 659 | aux_feat_list.append(torch.unsqueeze(feat_map3[k], 0)) 660 | elif action[k] == 2: 661 | aux_feat_list.append(torch.unsqueeze(feat_map4[k], 0)) 662 | elif action[k] == 3: 663 | aux_feat_list.append(torch.unsqueeze(feat_map5[k], 0)) 664 | else: 665 | raise ValueError('Incorrect action') 666 | aux_feat_argmax = tuple(aux_feat_list) 667 | aux_feat_argmax = torch.cat(aux_feat_argmax, 0) 668 | aux_feat_argmax = torch.squeeze(aux_feat_argmax, 1) 669 | concat_featmaps_argmax = torch.cat((feat_map1.detach(), aux_feat_argmax.detach()), 1) 670 | pred_argmax = self.argmax_decoder(concat_featmaps_argmax) 671 | return pred_argmax, prob_action, action 672 | else: 673 | raise ValueError('Incorrect inference mode') 674 | 675 | 676 | class LearnWhen2Com(nn.Module): 677 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive', 678 | has_query=True, sparse=False, 
aux_agent_num=4, shuffle_flag=False, image_size=512, 679 | shared_img_encoder=False, key_size=128, query_size=128, enc_backbone='n_segnet_encoder', 680 | dec_backbone='n_segnet_decoder'): 681 | super(LearnWhen2Com, self).__init__() 682 | # agent_num = 2 683 | self.aux_agent_num = aux_agent_num 684 | self.in_channels = in_channels 685 | self.shuffle_flag = shuffle_flag 686 | self.feature_map_channel = 512 687 | self.key_size = key_size 688 | self.query_size = query_size 689 | self.shared_img_encoder = shared_img_encoder 690 | self.has_query = has_query 691 | self.sparse = sparse 692 | # Encoder 693 | # Non-shared 694 | if self.shared_img_encoder == 'unified': 695 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 696 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 697 | elif self.shared_img_encoder == 'only_normal_agents': 698 | self.degarded_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 699 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 700 | self.normal_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 701 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 702 | 703 | else: 704 | self.encoder1 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 705 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 706 | self.encoder2 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 707 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 708 | self.encoder3 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 709 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 710 | self.encoder4 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 711 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 712 | self.encoder5 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 713 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 714 | 715 | # # Message generator 716 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone) 717 | if self.has_query: 718 | self.query_net = linear(out_size=self.query_size,input_feat_sz=image_size / 32) 719 | self.key_net = linear(out_size=self.key_size,input_feat_sz=image_size / 32) 720 | if attention == 'additive': 721 | self.attention_net = AdditiveAttentin() 722 | elif attention == 'general': 723 | self.attention_net = GeneralDotProductAttention(self.query_size, self.key_size) 724 | else: 725 | self.attention_net = ScaledDotProductAttention(128 ** 0.5) 726 | 727 | # Segmentation decoder 728 | self.argmax_decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel, 729 | agent_num=self.aux_agent_num + 1, dec_backbone=dec_backbone) 730 | 731 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel, 732 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 733 | # List the parameters of each modules 734 | self.attention_paras = list(self.attention_net.parameters()) 735 | if self.shared_img_encoder == 'unified': 736 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters()) 737 | elif self.shared_img_encoder == 'only_normal_agents': 738 | self.img_net_paras = list(self.degarded_encoder.parameters()) + list( 739 | self.normal_encoder.parameters()) + list(self.decoder.parameters()) 
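        # NOTE: the branches of this if/elif/else gather the encoder/decoder weights into
        # self.img_net_paras, while the query/key/attention weights are collected into
        # self.policy_net_paras a few lines further down. Keeping the two groups separate
        # lets a trainer update the perception network and the communication policy
        # independently, e.g. with two optimizers. A minimal, hypothetical sketch (the
        # names `model` and the learning rates are assumptions, not part of this file):
        #   opt_img    = torch.optim.Adam(model.img_net_paras,    lr=1e-4)
        #   opt_policy = torch.optim.Adam(model.policy_net_paras, lr=1e-4)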
740 | else: 741 | self.img_net_paras = list(self.encoder1.parameters()) + \ 742 | list(self.encoder2.parameters()) + \ 743 | list(self.encoder3.parameters()) + \ 744 | list(self.encoder4.parameters()) + \ 745 | list(self.encoder5.parameters()) + \ 746 | list(self.decoder.parameters()) 747 | 748 | self.img_net_paras = self.img_net_paras + list(self.argmax_decoder.parameters()) 749 | 750 | self.policy_net_paras = list(self.query_key_net.parameters()) + list( 751 | self.key_net.parameters()) + self.attention_paras 752 | if self.has_query: 753 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters()) 754 | 755 | 756 | self.all_paras = self.img_net_paras + self.policy_net_paras 757 | 758 | def divide_inputs(self, inputs): 759 | ''' 760 | Divide the input into a list of several images 761 | ''' 762 | input_list = [] 763 | divide_num = 5 764 | for i in range(divide_num): 765 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :]) 766 | 767 | return input_list 768 | 769 | def shape_feat_map(self, feat, size): 770 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1) 771 | 772 | def argmax_select(self, feat_map1, feat_map2, feat_map3, feat_map4, feat_map5, action, batch_size): 773 | 774 | num_connect = 0 775 | feat_list = [] 776 | for k in range(batch_size): 777 | if action[k] == 0: 778 | feat_list.append(torch.unsqueeze(feat_map1[k], 0)) 779 | elif action[k] == 1: 780 | feat_list.append(torch.unsqueeze(feat_map2[k], 0)) 781 | num_connect = num_connect + 1 782 | elif action[k] == 2: 783 | feat_list.append(torch.unsqueeze(feat_map3[k], 0)) 784 | num_connect = num_connect + 1 785 | elif action[k] == 3: 786 | feat_list.append(torch.unsqueeze(feat_map4[k], 0)) 787 | num_connect = num_connect + 1 788 | elif action[k] == 4: 789 | feat_list.append(torch.unsqueeze(feat_map5[k], 0)) 790 | num_connect = num_connect + 1 791 | else: 792 | raise ValueError('Incorrect action') 793 | num_connect = num_connect / batch_size 794 | 795 | feat_argmax = tuple(feat_list) 796 | feat_argmax = torch.cat(feat_argmax, 0) 797 | feat_argmax = torch.squeeze(feat_argmax, 1) 798 | return feat_argmax, num_connect 799 | 800 | def activated_select(self, vals, W_mat, thres=0.2): 801 | action = torch.mul(W_mat, (W_mat > thres).float()) 802 | attn = action.view(action.shape[0], action.shape[2], 1, 1, 1) # (batch,5,1,1,1) 803 | output = attn * vals 804 | feat_fuse = output.sum(1) # (batch,1,channel,size,size) 805 | 806 | batch_size = action.shape[0] 807 | num_connect = torch.nonzero(action[:, :, 1:]).shape[0] / batch_size 808 | 809 | return feat_fuse, action, num_connect 810 | 811 | def forward(self, inputs, training=True, inference='argmax'): 812 | batch_size, _, _, _ = inputs.size() 813 | input_list = self.divide_inputs(inputs) 814 | if self.shared_img_encoder == 'unified': 815 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0) 816 | feat_map = self.u_encoder(unified_feat_map) 817 | feat_map1 = feat_map[0:batch_size * 1] 818 | feat_map2 = feat_map[batch_size * 1:batch_size * 2] 819 | feat_map3 = feat_map[batch_size * 2:batch_size * 3] 820 | feat_map4 = feat_map[batch_size * 3:batch_size * 4] 821 | feat_map5 = feat_map[batch_size * 4:batch_size * 5] 822 | 823 | elif self.shared_img_encoder == 'only_normal_agents': 824 | feat_map1 = self.degarded_encoder(input_list[0]) 825 | unified_normal_feat_map = torch.cat((input_list[1], input_list[2], input_list[3], input_list[4]), 0) 826 | feat_map = self.normal_encoder(unified_normal_feat_map) 
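            # The four normal-agent views were concatenated along dim 0 (the batch
            # dimension) before the single encoder call above, so feat_map has shape
            # (4 * batch_size, C, H, W); the slices below recover one
            # (batch_size, C, H, W) map per agent. With batch_size = 2, for example,
            # the rows are ordered [a2_s0, a2_s1, a3_s0, a3_s1, a4_s0, a4_s1, a5_s0, a5_s1].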
827 | feat_map2 = feat_map[0:batch_size * 1] 828 | feat_map3 = feat_map[batch_size * 1:batch_size * 2] 829 | feat_map4 = feat_map[batch_size * 2:batch_size * 3] 830 | feat_map5 = feat_map[batch_size * 3:batch_size * 4] 831 | else: 832 | feat_map1 = self.encoder1(input_list[0]) 833 | feat_map2 = self.encoder2(input_list[1]) 834 | feat_map3 = self.encoder3(input_list[2]) 835 | feat_map4 = self.encoder4(input_list[3]) 836 | feat_map5 = self.encoder5(input_list[4]) 837 | 838 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0) 839 | query_key_map = self.query_key_net(unified_feat_map) 840 | query_key_map1 = query_key_map[0:batch_size * 1] 841 | 842 | keys = self.key_net(query_key_map) 843 | key1 = torch.unsqueeze(keys[0:batch_size * 1], 1) 844 | key2 = torch.unsqueeze(keys[batch_size * 1:batch_size * 2], 1) 845 | key3 = torch.unsqueeze(keys[batch_size * 2:batch_size * 3], 1) 846 | key4 = torch.unsqueeze(keys[batch_size * 3:batch_size * 4], 1) 847 | key5 = torch.unsqueeze(keys[batch_size * 4:batch_size * 5], 1) 848 | 849 | 850 | if self.has_query: 851 | querys = self.query_net(query_key_map) 852 | query = torch.unsqueeze(querys[0:batch_size * 1], 1) 853 | else: 854 | query = torch.ones(batch_size, 1, self.query_size).to('cuda') 855 | 856 | feat_map1 = torch.unsqueeze(feat_map1, 1) # (batch,1,channel,size,size) 857 | feat_map2 = torch.unsqueeze(feat_map2, 1) # (batch,1,channel,size,size) 858 | feat_map3 = torch.unsqueeze(feat_map3, 1) # (batch,1,channel,size,size) 859 | feat_map4 = torch.unsqueeze(feat_map4, 1) # (batch,1,channel,size,size) 860 | feat_map5 = torch.unsqueeze(feat_map5, 1) # (batch,1,channel,size,size) 861 | 862 | keys = torch.cat((key1, key2, key3, key4, key5), 1) # (batch,4,128) 863 | vals = torch.cat((feat_map1, feat_map2, feat_map3, feat_map4, feat_map5), 1) # (batch,4,channel,size,size) 864 | aux_feat, prob_action = self.attention_net(query, keys, vals, 865 | sparse=self.sparse) # (batch,1,#channel*size*size),(batch,1,4) 866 | pred = self.decoder(aux_feat) 867 | 868 | if training: 869 | action = torch.argmax(prob_action, dim=2) 870 | return pred, prob_action, action 871 | else: 872 | if inference == 'softmax': 873 | action = torch.argmax(prob_action, dim=2) 874 | num_connect = 4 875 | return pred, prob_action, action, num_connect 876 | elif inference == 'argmax_test': 877 | action = torch.argmax(prob_action, dim=2) 878 | feat_argmax, num_connect = self.argmax_select(feat_map1, feat_map2, feat_map3, feat_map4, feat_map5, 879 | action, batch_size) 880 | featmaps_argmax = feat_argmax.detach() 881 | pred_argmax = self.decoder(featmaps_argmax) 882 | return pred_argmax, prob_action, action, num_connect 883 | elif inference == 'activated': 884 | feat_act, action, num_connect = self.activated_select(vals, prob_action) 885 | featmaps_act = feat_act.detach() 886 | pred_act = self.decoder(featmaps_act) 887 | return pred_act, prob_action, action, num_connect 888 | else: 889 | raise ValueError('Incorrect inference mode') 890 | 891 | 892 | class MIMO_All_agents(nn.Module): 893 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, aux_agent_num=4, shuffle_flag=False, 894 | enc_backbone='n_segnet_encoder', dec_backbone='n_segnet_decoder', feat_squeezer=-1): 895 | super(MIMO_All_agents, self).__init__() 896 | 897 | self.agent_num = aux_agent_num # include the target agent 898 | self.in_channels = in_channels 899 | self.shuffle_flag = shuffle_flag 900 | 901 | self.encoder = img_encoder(n_classes=n_classes, 
in_channels=in_channels, feat_channel=feat_channel, 902 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 903 | 904 | 905 | 906 | # Decoder for interested agent 907 | if self.shuffle_flag == 'selection': # random selection 908 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 2 , 909 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 910 | elif self.shuffle_flag == 'ComNet': # random selection 911 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 2 , 912 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 913 | else: # Catall 914 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * self.agent_num, 915 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 916 | 917 | def divide_inputs(self, inputs): 918 | ''' 919 | Divide the input into a list of several images 920 | ''' 921 | input_list = [] 922 | divide_num = self.agent_num 923 | for i in range(divide_num): 924 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :]) 925 | return input_list 926 | 927 | def forward(self, inputs): 928 | input_list = self.divide_inputs(inputs) 929 | 930 | feat_map = [] 931 | for i in range(self.agent_num): 932 | fmp = self.encoder(input_list[i]) 933 | feat_map.append(fmp) 934 | 935 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers 936 | # randomly select 937 | concat_fmp_list = [] 938 | rand_action_list = [] 939 | for i in range(self.agent_num): 940 | randidx = random.randint(0, self.agent_num-1) 941 | rand_action_list.append( torch.ones((feat_map[0].shape[0],1), dtype=torch.long)*randidx) 942 | 943 | fmp_per_agent = torch.cat((feat_map[i], feat_map[randidx]), 1) # fmp_list = [bs, channel*2, h, w] 944 | concat_fmp_list.append(fmp_per_agent) 945 | concat_featmaps = torch.cat(tuple(concat_fmp_list), 0) # concat_featmaps = [bs*agent_num, channel*2, h, w] 946 | pred = self.decoder(concat_featmaps) 947 | argmax_action = torch.cat(tuple(rand_action_list), 1) 948 | elif self.shuffle_flag == 'ComNet': # use randomly picked feature and only specific numbers 949 | 950 | # randomly select 951 | concat_fmp_list = [] 952 | for i in range(self.agent_num): 953 | sum_other_feat = torch.zeros((feat_map[0].shape[0], feat_map[0].shape[1],feat_map[0].shape[2],feat_map[0].shape[3])).cuda() 954 | for j in range(self.agent_num): 955 | if j!= i: # other agents 956 | sum_other_feat += feat_map[j] 957 | avg_other_feat = sum_other_feat/(self.agent_num-1) 958 | 959 | fmp_per_agent = torch.cat((feat_map[i], avg_other_feat), 1) # fmp_list = [bs, channel*2, h, w] 960 | concat_fmp_list.append(fmp_per_agent) 961 | concat_featmaps = torch.cat(tuple(concat_fmp_list), 0) # concat_featmaps = [bs*agent_num, channel*2, h, w] 962 | pred = self.decoder(concat_featmaps) 963 | 964 | else: 965 | concat_fmp_list = [] 966 | for i in range(self.agent_num): 967 | fmp_list = [] 968 | for j in range(self.agent_num): 969 | fmp_list.append(feat_map[ (i+j)%self.agent_num ]) 970 | fmp_per_agent = torch.cat(tuple(fmp_list), 1) # fmp_list = [bs, channel*agent_num, h, w] 971 | concat_fmp_list.append(fmp_per_agent) 972 | concat_featmaps = torch.cat(tuple(concat_fmp_list), 0) # concat_featmaps = [bs*agent_num, channel*agent_num, h, w] 973 | 974 | pred = self.decoder(concat_featmaps) 975 | 976 | 977 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers 978 | return pred, argmax_action.cuda() 979 | else: 980 | return pred 981 | 982 | # our model (no warping) 983 | class 
MIMOcom(nn.Module): 984 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive', 985 | has_query=True, sparse=False, agent_num=5, shuffle_flag=False, image_size=512, 986 | shared_img_encoder=False, key_size=128, query_size=128, enc_backbone='n_segnet_encoder', 987 | dec_backbone='n_segnet_decoder'): 988 | super(MIMOcom, self).__init__() 989 | 990 | self.agent_num = agent_num 991 | self.in_channels = in_channels 992 | self.shuffle_flag = shuffle_flag 993 | self.feature_map_channel = 512 994 | self.key_size = key_size 995 | self.query_size = query_size 996 | self.shared_img_encoder = shared_img_encoder 997 | self.has_query = has_query 998 | self.sparse = sparse 999 | 1000 | 1001 | print('When2com') # our model: detach the learning of values and keys 1002 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 1003 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 1004 | 1005 | self.key_net = km_generator(out_size=self.key_size, input_feat_sz=image_size / 32) 1006 | self.attention_net = MIMOGeneralDotProductAttention(self.query_size, self.key_size) 1007 | 1008 | # # Message generator 1009 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone) 1010 | if self.has_query: 1011 | self.query_net = km_generator(out_size=self.query_size, input_feat_sz=image_size / 32) 1012 | 1013 | # Segmentation decoder 1014 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel, 1015 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 1016 | 1017 | 1018 | # List the parameters of each modules 1019 | self.attention_paras = list(self.attention_net.parameters()) 1020 | if self.shared_img_encoder == 'unified': 1021 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters()) 1022 | 1023 | 1024 | 1025 | self.policy_net_paras = list(self.query_key_net.parameters()) + list( 1026 | self.key_net.parameters()) + self.attention_paras 1027 | if self.has_query: 1028 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters()) 1029 | 1030 | self.all_paras = self.img_net_paras + self.policy_net_paras 1031 | 1032 | 1033 | def shape_feat_map(self, feat, size): 1034 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1) 1035 | 1036 | def argmax_select(self, val_mat, prob_action): 1037 | # v(batch, query_num, channel, size, size) 1038 | cls_num = prob_action.shape[1] 1039 | 1040 | coef_argmax = F.one_hot(prob_action.max(dim=1)[1], num_classes=cls_num).type(torch.cuda.FloatTensor) 1041 | coef_argmax = coef_argmax.transpose(1, 2) 1042 | attn_shape = coef_argmax.shape 1043 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2] 1044 | coef_argmax_exp = coef_argmax.view(bats, key_num, query_num, 1, 1, 1) 1045 | 1046 | v_exp = torch.unsqueeze(val_mat, 2) 1047 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1) 1048 | 1049 | output = coef_argmax_exp * v_exp # (batch,4,channel,size,size) 1050 | feat_argmax = output.sum(1) # (batch,1,channel,size,size) 1051 | 1052 | # compute connect 1053 | count_coef = copy.deepcopy(coef_argmax) 1054 | ind = np.diag_indices(self.agent_num) 1055 | count_coef[:, ind[0], ind[1]] = 0 1056 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0]) 1057 | 1058 | return feat_argmax, coef_argmax, num_connect 1059 | 1060 | def activated_select(self, val_mat, prob_action, thres=0.2): 1061 | 1062 | 
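        # Sketch of the steps below: attention weights at or below `thres` are zeroed
        # (the surviving weights are not renormalised), the remaining weights re-weight
        # each agent's value map before summing over the key dimension, and the
        # off-diagonal non-zeros (self links do not count as transmissions) are averaged
        # over agents and batch to report the number of links actually used.
        # E.g. a column of weights [0.75, 0.22, 0.03] with thres = 0.2 becomes
        # [0.75, 0.22, 0.0], i.e. one transmitted feature map if the 0.75 entry
        # corresponds to the agent's own view.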
coef_act = torch.mul(prob_action, (prob_action > thres).float()) 1063 | attn_shape = coef_act.shape 1064 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2] 1065 | coef_act_exp = coef_act.view(bats, key_num, query_num, 1, 1, 1) 1066 | 1067 | v_exp = torch.unsqueeze(val_mat, 2) 1068 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1) 1069 | 1070 | output = coef_act_exp * v_exp # (batch,4,channel,size,size) 1071 | feat_act = output.sum(1) # (batch,1,channel,size,size) 1072 | 1073 | # compute connect 1074 | count_coef = coef_act.clone() 1075 | ind = np.diag_indices(self.agent_num) 1076 | count_coef[:, ind[0], ind[1]] = 0 1077 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0]) 1078 | return feat_act, coef_act, num_connect 1079 | 1080 | def agents2batch(self, feats): 1081 | agent_num = feats.shape[1] 1082 | feat_list = [] 1083 | for i in range(agent_num): 1084 | feat_list.append(feats[:, i, :, :, :]) 1085 | feat_mat = torch.cat(tuple(feat_list), 0) 1086 | return feat_mat 1087 | 1088 | def divide_inputs(self, inputs): 1089 | ''' 1090 | Divide the input into a list of several images 1091 | ''' 1092 | input_list = [] 1093 | for i in range(self.agent_num): 1094 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :]) 1095 | 1096 | return input_list 1097 | 1098 | def forward(self, inputs, training=True, MO_flag=False , inference='argmax'): 1099 | batch_size, _, _, _ = inputs.size() 1100 | 1101 | input_list = self.divide_inputs(inputs) 1102 | 1103 | if self.shared_img_encoder == 'unified': 1104 | # vectorize input list 1105 | img_list = [] 1106 | for i in range(self.agent_num): 1107 | img_list.append(input_list[i]) 1108 | unified_img_list = torch.cat(tuple(img_list), 0) 1109 | 1110 | # pass encoder 1111 | feat_maps = self.u_encoder(unified_img_list) 1112 | 1113 | # get feat maps for each image 1114 | feat_map = {} 1115 | feat_list = [] 1116 | for i in range(self.agent_num): 1117 | feat_map[i] = torch.unsqueeze(feat_maps[batch_size * i:batch_size * (i + 1)], 1) 1118 | feat_list.append(feat_map[i]) 1119 | val_mat = torch.cat(tuple(feat_list), 1) 1120 | else: 1121 | raise ValueError('Incorrect encoder') 1122 | 1123 | # pass feature maps through key and query generator 1124 | query_key_maps = self.query_key_net(unified_img_list) 1125 | 1126 | keys = self.key_net(query_key_maps) 1127 | 1128 | if self.has_query: 1129 | querys = self.query_net(query_key_maps) 1130 | 1131 | # get key and query 1132 | key = {} 1133 | query = {} 1134 | key_list = [] 1135 | query_list = [] 1136 | 1137 | for i in range(self.agent_num): 1138 | key[i] = torch.unsqueeze(keys[batch_size * i:batch_size * (i + 1)], 1) 1139 | key_list.append(key[i]) 1140 | if self.has_query: 1141 | query[i] = torch.unsqueeze(querys[batch_size * i:batch_size * (i + 1)], 1) 1142 | else: 1143 | query[i] = torch.ones(batch_size, 1, self.query_size).to('cuda') 1144 | query_list.append(query[i]) 1145 | 1146 | 1147 | key_mat = torch.cat(tuple(key_list), 1) 1148 | query_mat = torch.cat(tuple(query_list), 1) 1149 | 1150 | if MO_flag: 1151 | query_mat = query_mat 1152 | else: 1153 | query_mat = torch.unsqueeze(query_mat[:,0,:],1) 1154 | 1155 | feat_fuse, prob_action = self.attention_net(query_mat, key_mat, val_mat, sparse=self.sparse) 1156 | 1157 | 1158 | # weighted feature maps is passed to decoder 1159 | feat_fuse_mat = self.agents2batch(feat_fuse) 1160 | 1161 | pred = self.decoder(feat_fuse_mat) 1162 | 1163 | # not related to how we combine the feature (prefer to use the agnets' own frames: to 
reduce the bandwidth) 1164 | small_bis = torch.eye(prob_action.shape[1])*0.001 1165 | small_bis = small_bis.reshape((1, prob_action.shape[1], prob_action.shape[2])) 1166 | small_bis = small_bis.repeat(prob_action.shape[0], 1, 1).cuda() 1167 | prob_action = prob_action + small_bis 1168 | 1169 | 1170 | if training: 1171 | action = torch.argmax(prob_action, dim=1) 1172 | num_connect = self.agent_num - 1 1173 | 1174 | return pred, prob_action, action, num_connect 1175 | else: 1176 | if inference == 'softmax': 1177 | action = torch.argmax(prob_action, dim=1) 1178 | num_connect = self.agent_num - 1 1179 | 1180 | return pred, prob_action, action, num_connect 1181 | 1182 | elif inference == 'argmax_test': 1183 | 1184 | feat_argmax, connect_mat, num_connect = self.argmax_select(val_mat, prob_action) 1185 | 1186 | feat_argmax_mat = self.agents2batch(feat_argmax) # (batchsize*agent_num, channel, size, size) 1187 | feat_argmax_mat = feat_argmax_mat.detach() 1188 | pred_argmax = self.decoder(feat_argmax_mat) 1189 | action = torch.argmax(connect_mat, dim=1) 1190 | return pred_argmax, prob_action, action, num_connect 1191 | 1192 | elif inference == 'activated': 1193 | feat_act, connect_mat, num_connect = self.activated_select(val_mat, prob_action) 1194 | 1195 | feat_act_mat = self.agents2batch(feat_act) # (batchsize*agent_num, channel, size, size) 1196 | feat_act_mat = feat_act_mat.detach() 1197 | 1198 | pred_act = self.decoder(feat_act_mat) 1199 | 1200 | action = torch.argmax(connect_mat, dim=1) 1201 | 1202 | return pred_act, prob_action, action, num_connect 1203 | else: 1204 | raise ValueError('Incorrect inference mode') 1205 | 1206 | # AuxAtt, 1207 | class MIMOcomWho(nn.Module): 1208 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive', 1209 | has_query=True, sparse=False, agent_num=5, shuffle_flag=False, image_size=512, 1210 | shared_img_encoder=False, key_size=128, query_size=128, enc_backbone='n_segnet_encoder', 1211 | dec_backbone='n_segnet_decoder'): 1212 | super(MIMOcomWho, self).__init__() 1213 | 1214 | self.agent_num = agent_num 1215 | self.in_channels = in_channels 1216 | self.shuffle_flag = shuffle_flag 1217 | self.feature_map_channel = 512 1218 | self.key_size = key_size 1219 | self.query_size = query_size 1220 | self.shared_img_encoder = shared_img_encoder 1221 | self.has_query = has_query 1222 | self.sparse = sparse 1223 | # Encoder 1224 | 1225 | # Non-shared 1226 | if self.shared_img_encoder == 'unified': 1227 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel, 1228 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone) 1229 | else: 1230 | raise ValueError('Incorrect shared_img_encoder flag') 1231 | 1232 | # # Message generator 1233 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone) 1234 | if self.has_query: 1235 | self.query_net = linear(out_size=self.query_size, input_feat_sz=image_size / 32) 1236 | 1237 | self.key_net = linear(out_size=self.key_size, input_feat_sz=image_size / 32) 1238 | 1239 | self.attention_net = MIMOWhoGeneralDotProductAttention(self.query_size, self.key_size) 1240 | 1241 | # Segmentation decoder 1242 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel*2, 1243 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone) 1244 | 1245 | 1246 | # List the parameters of each modules 1247 | self.attention_paras = list(self.attention_net.parameters()) 1248 | if 
self.shared_img_encoder == 'unified': 1249 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters()) 1250 | 1251 | 1252 | 1253 | self.policy_net_paras = list(self.query_key_net.parameters()) + list( 1254 | self.key_net.parameters()) + self.attention_paras 1255 | if self.has_query: 1256 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters()) 1257 | 1258 | self.all_paras = self.img_net_paras + self.policy_net_paras 1259 | 1260 | 1261 | def shape_feat_map(self, feat, size): 1262 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1) 1263 | 1264 | def argmax_select(self, val_mat, prob_action): 1265 | # v(batch, query_num, channel, size, size) 1266 | cls_num = prob_action.shape[1] 1267 | 1268 | coef_argmax = F.one_hot(prob_action.max(dim=1)[1], num_classes=cls_num).type(torch.cuda.FloatTensor) 1269 | coef_argmax = coef_argmax.transpose(1, 2) 1270 | attn_shape = coef_argmax.shape 1271 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2] 1272 | coef_argmax_exp = coef_argmax.view(bats, key_num, query_num, 1, 1, 1) 1273 | 1274 | v_exp = torch.unsqueeze(val_mat, 2) 1275 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1) 1276 | 1277 | output = coef_argmax_exp * v_exp # (batch,4,channel,size,size) 1278 | feat_argmax = output.sum(1) # (batch,1,channel,size,size) 1279 | 1280 | # compute connect 1281 | count_coef = coef_argmax.clone() 1282 | ind = np.diag_indices(self.agent_num) 1283 | count_coef[:, ind[0], ind[1]] = 0 1284 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0]) 1285 | 1286 | return feat_argmax, coef_argmax, num_connect 1287 | 1288 | def activated_select(self, val_mat, prob_action, thres=0.2): 1289 | 1290 | coef_act = torch.mul(prob_action, (prob_action > thres).float()) 1291 | attn_shape = coef_act.shape 1292 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2] 1293 | coef_act_exp = coef_act.view(bats, key_num, query_num, 1, 1, 1) 1294 | 1295 | v_exp = torch.unsqueeze(val_mat, 2) 1296 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1) 1297 | 1298 | output = coef_act_exp * v_exp # (batch,4,channel,size,size) 1299 | feat_act = output.sum(1) # (batch,1,channel,size,size) 1300 | 1301 | # compute connect 1302 | count_coef = coef_act.clone() 1303 | ind = np.diag_indices(self.agent_num) 1304 | count_coef[:, ind[0], ind[1]] = 0 1305 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0]) 1306 | 1307 | return feat_act, coef_act, num_connect 1308 | 1309 | def agents2batch(self, feats): 1310 | agent_num = feats.shape[1] 1311 | feat_list = [] 1312 | for i in range(agent_num): 1313 | feat_list.append(feats[:, i, :, :, :]) 1314 | feat_mat = torch.cat(tuple(feat_list), 0) 1315 | return feat_mat 1316 | 1317 | def divide_inputs(self, inputs): 1318 | ''' 1319 | Divide the input into a list of several images 1320 | ''' 1321 | input_list = [] 1322 | for i in range(self.agent_num): 1323 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :]) 1324 | 1325 | return input_list 1326 | 1327 | def forward(self, inputs, training=True, MO_flag=False , inference='argmax'): 1328 | batch_size, _, _, _ = inputs.size() 1329 | 1330 | input_list = self.divide_inputs(inputs) 1331 | 1332 | if self.shared_img_encoder == 'unified': 1333 | # vectorize input list 1334 | img_list = [] 1335 | for i in range(self.agent_num): 1336 | img_list.append(input_list[i]) 1337 | unified_img_list = torch.cat(tuple(img_list), 0) 1338 | 1339 | # pass 
encoder 1340 | feat_maps = self.u_encoder(unified_img_list) 1341 | 1342 | # get feat maps for each image 1343 | feat_map = {} 1344 | feat_list = [] 1345 | for i in range(self.agent_num): 1346 | feat_map[i] = torch.unsqueeze(feat_maps[batch_size * i:batch_size * (i + 1)], 1) 1347 | feat_list.append(feat_map[i]) 1348 | val_mat = torch.cat(tuple(feat_list), 1) 1349 | else: 1350 | raise ValueError('Incorrect encoder') 1351 | 1352 | # pass feature maps through key and query generator 1353 | query_key_maps = self.query_key_net(unified_img_list) 1354 | keys = self.key_net(query_key_maps) 1355 | if self.has_query: 1356 | querys = self.query_net(query_key_maps) 1357 | 1358 | # get key and query 1359 | key = {} 1360 | query = {} 1361 | key_list = [] 1362 | query_list = [] 1363 | 1364 | for i in range(self.agent_num): 1365 | key[i] = torch.unsqueeze(keys[batch_size * i:batch_size * (i + 1)], 1) 1366 | key_list.append(key[i]) 1367 | if self.has_query: 1368 | query[i] = torch.unsqueeze(querys[batch_size * i:batch_size * (i + 1)], 1) 1369 | else: 1370 | query[i] = torch.ones(batch_size, 1, self.query_size).to('cuda') 1371 | query_list.append(query[i]) 1372 | 1373 | key_mat = torch.cat(tuple(key_list), 1) 1374 | query_mat = torch.cat(tuple(query_list), 1) 1375 | 1376 | if MO_flag: 1377 | query_mat = query_mat 1378 | else: 1379 | query_mat = torch.unsqueeze(query_mat[:,0,:],1) 1380 | 1381 | feat_fuse, prob_action = self.attention_net(query_mat, key_mat, val_mat, sparse=self.sparse) 1382 | fuse_map = torch.cat( (feat_fuse, val_mat), dim=2) 1383 | 1384 | 1385 | feat_fuse_mat = self.agents2batch(fuse_map) 1386 | pred = self.decoder(feat_fuse_mat) 1387 | 1388 | if training: 1389 | action = torch.argmax(prob_action, dim=1) 1390 | num_connect = self.agent_num - 1 1391 | 1392 | return pred, prob_action, action, num_connect 1393 | else: 1394 | if inference == 'softmax': 1395 | action = torch.argmax(prob_action, dim=1) 1396 | num_connect = self.agent_num - 1 1397 | 1398 | return pred, prob_action, action, num_connect 1399 | 1400 | elif inference == 'argmax_test': 1401 | feat_argmax, connect_mat, num_connect = self.argmax_select(val_mat, prob_action) 1402 | fuse_map = torch.cat((feat_argmax, val_mat), dim=2) 1403 | 1404 | # argmax feature map is passed to decoder 1405 | feat_argmax_mat = self.agents2batch(fuse_map) # (batchsize*agent_num, channel, size, size) 1406 | feat_argmax_mat = feat_argmax_mat.detach() 1407 | pred_argmax = self.decoder(feat_argmax_mat) 1408 | action = torch.argmax(prob_action, dim=1) 1409 | 1410 | return pred_argmax, prob_action, action, num_connect 1411 | 1412 | elif inference == 'activated': 1413 | feat_act, connect_mat, num_connect = self.activated_select(val_mat, prob_action) 1414 | fuse_map = torch.cat((feat_act, val_mat), dim=2) 1415 | 1416 | feat_act_mat = self.agents2batch(fuse_map) # (batchsize*agent_num, channel, size, size) 1417 | feat_act_mat = feat_act_mat.detach() 1418 | pred_act = self.decoder(feat_act_mat) 1419 | action = torch.argmax(prob_action, dim=1) 1420 | 1421 | return pred_act, prob_action, action, num_connect 1422 | else: 1423 | raise ValueError('Incorrect inference mode') 1424 | -------------------------------------------------------------------------------- /ptsemseg/models/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torchvision.models as models 4 | 5 | import pretrainedmodels 6 | 7 | from ptsemseg.models.utils import conv2DBatchNormRelu, deconv2DBatchNormRelu, 
Sparsemax 8 | import random 9 | 10 | 11 | 12 | class n_segnet_encoder(nn.Module): 13 | def __init__(self, n_classes=21, in_channels=3): 14 | super(n_segnet_encoder, self).__init__() 15 | self.in_channels = in_channels 16 | 17 | # Encoder 18 | # down 1 19 | self.conv1 = conv2DBatchNormRelu(self.in_channels, 64, k_size=3, stride=1, padding=1) 20 | self.conv2 = conv2DBatchNormRelu(64 , 64, k_size=3, stride=2, padding=1) 21 | 22 | # down 2 23 | self.conv3 = conv2DBatchNormRelu(64 ,128, k_size=3, stride=1, padding=1) 24 | self.conv4 = conv2DBatchNormRelu(128 ,128, k_size=3, stride=2, padding=1) 25 | 26 | # down 3 27 | self.conv5 = conv2DBatchNormRelu(128 , 256, k_size=3, stride=1, padding=1) 28 | self.conv6 = conv2DBatchNormRelu(256 , 256, k_size=3, stride=1, padding=1) 29 | self.conv7 = conv2DBatchNormRelu(256 , 256, k_size=3, stride=2, padding=1) 30 | 31 | # down 4 32 | self.conv8 = conv2DBatchNormRelu(256 , 512, k_size=3, stride=1, padding=1) 33 | self.conv9 = conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1) 34 | self.conv10= conv2DBatchNormRelu(512 , 512, k_size=3, stride=2, padding=1) 35 | 36 | # down 5 37 | self.conv11= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1) 38 | self.conv12= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1) 39 | self.conv13= conv2DBatchNormRelu(512 , 512, k_size=3, stride=2, padding=1) 40 | 41 | def forward(self, inputs): 42 | outputs = self.conv1(inputs) 43 | outputs = self.conv2(outputs) 44 | outputs = self.conv3(outputs) 45 | outputs = self.conv4(outputs) 46 | outputs = self.conv5(outputs) 47 | outputs = self.conv6(outputs) 48 | outputs = self.conv7(outputs) 49 | outputs = self.conv8(outputs) 50 | outputs = self.conv9(outputs) 51 | outputs = self.conv10(outputs) 52 | outputs = self.conv11(outputs) 53 | outputs = self.conv12(outputs) 54 | outputs = self.conv13(outputs) 55 | return outputs 56 | 57 | 58 | class resnet_encoder(nn.Module): 59 | def __init__(self, n_classes=21, in_channels=3): 60 | super(resnet_encoder, self).__init__() 61 | feat_chn = 256 62 | #self.feature_backbone = n_segnet_encoder(n_classes=n_classes, in_channels=in_channels) 63 | self.feature_backbone = pretrainedmodels.__dict__['resnet18'](num_classes=1000, pretrained=None) 64 | 65 | self.backbone_0 = self.feature_backbone.conv1 66 | self.backbone_1 = nn.Sequential(self.feature_backbone.bn1, self.feature_backbone.relu, self.feature_backbone.maxpool, self.feature_backbone.layer1) 67 | self.backbone_2 = self.feature_backbone.layer2 68 | self.backbone_3 = self.feature_backbone.layer3 69 | self.backbone_4 = self.feature_backbone.layer4 70 | 71 | 72 | def forward(self, inputs): 73 | # print('input:') 74 | # print(inputs.size()) 75 | # import pdb; pdb.set_trace() 76 | outputs = self.backbone_0(inputs) 77 | # print('base_0 size: ') 78 | # print(base_0.size()) 79 | 80 | outputs = self.backbone_1(outputs) 81 | # print('base_1 size: ') 82 | # print(base_1.size()) 83 | 84 | outputs = self.backbone_2(outputs) 85 | # print('base_2 size: ') 86 | # print(base_2.size()) 87 | 88 | outputs = self.backbone_3(outputs) 89 | # print('base_3 size: ') 90 | # print(base_3.size()) 91 | 92 | outputs = self.backbone_4(outputs) 93 | # print('base_4 size: ') 94 | # print(base_4.size()) 95 | 96 | return outputs 97 | 98 | ### ============= Decoder Backbone ============= ### 99 | class n_segnet_decoder(nn.Module): 100 | def __init__(self, n_classes=21, in_channels=512): 101 | #def __init__(self, n_classes=21, in_channels=512,agent_num=5): 102 | super(n_segnet_decoder, self).__init__() 
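        # The decoder mirrors the five stride-2 stages of n_segnet_encoder above:
        # deconv1/4/7/9/11 each upsample by 2x (deconv2DBatchNormRelu), interleaved
        # with stride-1 conv blocks, so an (in_channels, H/32, W/32) feature map is
        # restored to an (n_classes, H, W) prediction map.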
103 | self.in_channels = in_channels 104 | # Decoder 105 | self.deconv1= deconv2DBatchNormRelu(self.in_channels, 512, k_size=3, stride=2, padding=1,output_padding=1) 106 | self.deconv2= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1) 107 | self.deconv3= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1) 108 | 109 | # up 4 110 | self.deconv4= deconv2DBatchNormRelu(512 , 512, k_size=3, stride=2, padding=1,output_padding=1) 111 | self.deconv5= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1) 112 | self.deconv6= conv2DBatchNormRelu(512 , 256, k_size=3, stride=1, padding=1) 113 | 114 | # up 3 115 | self.deconv7= deconv2DBatchNormRelu(256 , 256, k_size=3, stride=2, padding=1,output_padding=1) 116 | self.deconv8= conv2DBatchNormRelu(256 , 128, k_size=3, stride=1, padding=1) 117 | 118 | # up 2 119 | self.deconv9= deconv2DBatchNormRelu(128 , 128, k_size=3, stride=2, padding=1,output_padding=1) 120 | self.deconv10= conv2DBatchNormRelu(128 , 64, k_size=3, stride=1, padding=1) 121 | 122 | # up 1 123 | self.deconv11= deconv2DBatchNormRelu(64 , 64, k_size=3, stride=2, padding=1,output_padding=1) 124 | self.deconv12= conv2DBatchNormRelu(64 , n_classes, k_size=3, stride=1, padding=1) 125 | 126 | def forward(self, inputs): 127 | outputs = self.deconv1(inputs) 128 | outputs = self.deconv2(outputs) 129 | outputs = self.deconv3(outputs) 130 | 131 | outputs = self.deconv4(outputs) 132 | outputs = self.deconv5(outputs) 133 | outputs = self.deconv6(outputs) 134 | outputs = self.deconv7(outputs) 135 | outputs = self.deconv8(outputs) 136 | outputs = self.deconv9(outputs) 137 | outputs = self.deconv10(outputs) 138 | outputs = self.deconv11(outputs) 139 | outputs = self.deconv12(outputs) 140 | return outputs 141 | 142 | 143 | class simple_decoder(nn.Module): 144 | def __init__(self, n_classes=21, in_channels=512): 145 | super(simple_decoder, self).__init__() 146 | self.in_channels = in_channels 147 | 148 | feat_chn = 256 149 | 150 | self.pred = nn.Sequential( 151 | nn.Conv2d(self.in_channels, feat_chn, kernel_size=3, padding=1), 152 | nn.ReLU(inplace=True), 153 | nn.Conv2d(feat_chn, n_classes, kernel_size=3, padding=1) 154 | ) 155 | 156 | def forward(self, inputs): 157 | pred = self.pred(inputs) 158 | # print('pred size: ') 159 | # print(pred.size()) 160 | pred = nn.functional.interpolate(pred, size=torch.Size([inputs.size()[2]*32,inputs.size()[3]*32]), mode='bilinear', align_corners=False) 161 | # print('pred size: ') 162 | #rint(pred.size()) 163 | 164 | return pred 165 | 166 | 167 | class FCN_decoder(nn.Module): 168 | def __init__(self, n_classes=21, in_channels=512): 169 | super(FCN_decoder, self).__init__() 170 | feat_chn = 256 171 | 172 | self.pred = nn.Sequential( 173 | nn.Conv2d(in_channels, feat_chn, kernel_size=3, padding=1), 174 | nn.ReLU(inplace=True), 175 | nn.Conv2d(feat_chn, n_classes, kernel_size=3, padding=1) 176 | ) 177 | 178 | def forward(self, inputs): 179 | pred = self.pred(base_4) 180 | print('pred size: ') 181 | print(pred.size()) 182 | pred = nn.functional.interpolate(pred, size=inputs.size()[-2:], mode='bilinear', align_corners=False) 183 | print('pred size: ') 184 | print(pred.size()) 185 | 186 | return pred 187 | 188 | -------------------------------------------------------------------------------- /ptsemseg/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 
| 9 | class conv2DBatchNorm(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels, 13 | n_filters, 14 | k_size, 15 | stride, 16 | padding, 17 | bias=True, 18 | dilation=1, 19 | is_batchnorm=True, 20 | ): 21 | super(conv2DBatchNorm, self).__init__() 22 | 23 | conv_mod = nn.Conv2d( 24 | int(in_channels), 25 | int(n_filters), 26 | kernel_size=k_size, 27 | padding=padding, 28 | stride=stride, 29 | bias=bias, 30 | dilation=dilation, 31 | ) 32 | 33 | if is_batchnorm: 34 | self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters))) 35 | else: 36 | self.cb_unit = nn.Sequential(conv_mod) 37 | 38 | def forward(self, inputs): 39 | outputs = self.cb_unit(inputs) 40 | return outputs 41 | 42 | 43 | class conv2DGroupNorm(nn.Module): 44 | def __init__( 45 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 46 | ): 47 | super(conv2DGroupNorm, self).__init__() 48 | 49 | conv_mod = nn.Conv2d( 50 | int(in_channels), 51 | int(n_filters), 52 | kernel_size=k_size, 53 | padding=padding, 54 | stride=stride, 55 | bias=bias, 56 | dilation=dilation, 57 | ) 58 | 59 | self.cg_unit = nn.Sequential(conv_mod, nn.GroupNorm(n_groups, int(n_filters))) 60 | 61 | def forward(self, inputs): 62 | outputs = self.cg_unit(inputs) 63 | return outputs 64 | 65 | 66 | class deconv2DBatchNorm(nn.Module): 67 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): 68 | super(deconv2DBatchNorm, self).__init__() 69 | 70 | self.dcb_unit = nn.Sequential( 71 | nn.ConvTranspose2d( 72 | int(in_channels), 73 | int(n_filters), 74 | kernel_size=k_size, 75 | padding=padding, 76 | stride=stride, 77 | bias=bias, 78 | ), 79 | nn.BatchNorm2d(int(n_filters)), 80 | ) 81 | 82 | def forward(self, inputs): 83 | outputs = self.dcb_unit(inputs) 84 | return outputs 85 | 86 | 87 | class conv2DBatchNormRelu(nn.Module): 88 | def __init__( 89 | self, 90 | in_channels, 91 | n_filters, 92 | k_size, 93 | stride, 94 | padding, 95 | bias=True, 96 | dilation=1, 97 | is_batchnorm=True, 98 | ): 99 | super(conv2DBatchNormRelu, self).__init__() 100 | 101 | conv_mod = nn.Conv2d( 102 | int(in_channels), 103 | int(n_filters), 104 | kernel_size=k_size, 105 | padding=padding, 106 | stride=stride, 107 | bias=bias, 108 | dilation=dilation, 109 | ) 110 | 111 | if is_batchnorm: 112 | self.cbr_unit = nn.Sequential( 113 | conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True) 114 | ) 115 | else: 116 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True)) 117 | 118 | def forward(self, inputs): 119 | outputs = self.cbr_unit(inputs) 120 | return outputs 121 | 122 | 123 | class conv2DGroupNormRelu(nn.Module): 124 | def __init__( 125 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16 126 | ): 127 | super(conv2DGroupNormRelu, self).__init__() 128 | 129 | conv_mod = nn.Conv2d( 130 | int(in_channels), 131 | int(n_filters), 132 | kernel_size=k_size, 133 | padding=padding, 134 | stride=stride, 135 | bias=bias, 136 | dilation=dilation, 137 | ) 138 | 139 | self.cgr_unit = nn.Sequential( 140 | conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True) 141 | ) 142 | 143 | def forward(self, inputs): 144 | outputs = self.cgr_unit(inputs) 145 | return outputs 146 | 147 | 148 | class deconv2DBatchNormRelu(nn.Module): 149 | def __init__(self, in_channels, n_filters, k_size, stride, padding, output_padding,bias=True): 150 | super(deconv2DBatchNormRelu, self).__init__() 151 | 152 | self.dcbr_unit = nn.Sequential( 153 | nn.ConvTranspose2d( 154 | 
int(in_channels), 155 | int(n_filters), 156 | kernel_size=k_size, 157 | padding=padding, 158 | output_padding=output_padding, 159 | stride=stride, 160 | bias=bias, 161 | ), 162 | nn.BatchNorm2d(int(n_filters)), 163 | nn.ReLU(inplace=True), 164 | ) 165 | 166 | def forward(self, inputs): 167 | outputs = self.dcbr_unit(inputs) 168 | return outputs 169 | 170 | 171 | class unetConv2(nn.Module): 172 | def __init__(self, in_size, out_size, is_batchnorm): 173 | super(unetConv2, self).__init__() 174 | 175 | if is_batchnorm: 176 | self.conv1 = nn.Sequential( 177 | nn.Conv2d(in_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU() 178 | ) 179 | self.conv2 = nn.Sequential( 180 | nn.Conv2d(out_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU() 181 | ) 182 | else: 183 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), nn.ReLU()) 184 | self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), nn.ReLU()) 185 | 186 | def forward(self, inputs): 187 | outputs = self.conv1(inputs) 188 | outputs = self.conv2(outputs) 189 | return outputs 190 | 191 | 192 | class unetUp(nn.Module): 193 | def __init__(self, in_size, out_size, is_deconv): 194 | super(unetUp, self).__init__() 195 | self.conv = unetConv2(in_size, out_size, False) 196 | if is_deconv: 197 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) 198 | else: 199 | self.up = nn.UpsamplingBilinear2d(scale_factor=2) 200 | 201 | def forward(self, inputs1, inputs2): 202 | outputs2 = self.up(inputs2) 203 | offset = outputs2.size()[2] - inputs1.size()[2] 204 | padding = 2 * [offset // 2, offset // 2] 205 | outputs1 = F.pad(inputs1, padding) 206 | return self.conv(torch.cat([outputs1, outputs2], 1)) 207 | 208 | 209 | class segnetDown2(nn.Module): 210 | def __init__(self, in_size, out_size): 211 | super(segnetDown2, self).__init__() 212 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, k_size=3, stride=1, padding=1) 213 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, k_size=3, stride=1, padding=1) 214 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) 215 | 216 | def forward(self, inputs): 217 | outputs = self.conv1(inputs) 218 | outputs = self.conv2(outputs) 219 | unpooled_shape = outputs.size() 220 | outputs, indices = self.maxpool_with_argmax(outputs) 221 | return outputs, indices, unpooled_shape 222 | 223 | 224 | class segnetDown3(nn.Module): 225 | def __init__(self, in_size, out_size): 226 | super(segnetDown3, self).__init__() 227 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 228 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 229 | self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) 230 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) 231 | 232 | def forward(self, inputs): 233 | outputs = self.conv1(inputs) 234 | outputs = self.conv2(outputs) 235 | outputs = self.conv3(outputs) 236 | unpooled_shape = outputs.size() 237 | outputs, indices = self.maxpool_with_argmax(outputs) 238 | return outputs, indices, unpooled_shape 239 | 240 | 241 | class segnetUp2(nn.Module): 242 | def __init__(self, in_size, out_size): 243 | super(segnetUp2, self).__init__() 244 | self.unpool = nn.MaxUnpool2d(2, 2) 245 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 246 | self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 247 | 248 | def forward(self, inputs, indices, output_shape): 249 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) 250 | outputs = self.conv1(outputs) 251 | 
outputs = self.conv2(outputs) 252 | return outputs 253 | 254 | 255 | class segnetUp3(nn.Module): 256 | def __init__(self, in_size, out_size): 257 | super(segnetUp3, self).__init__() 258 | self.unpool = nn.MaxUnpool2d(2, 2) 259 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 260 | self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) 261 | self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) 262 | 263 | def forward(self, inputs, indices, output_shape): 264 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) 265 | outputs = self.conv1(outputs) 266 | outputs = self.conv2(outputs) 267 | outputs = self.conv3(outputs) 268 | return outputs 269 | 270 | 271 | 272 | 273 | 274 | 275 | class residualBlock(nn.Module): 276 | expansion = 1 277 | 278 | def __init__(self, in_channels, n_filters, stride=1, downsample=None): 279 | super(residualBlock, self).__init__() 280 | 281 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False) 282 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False) 283 | self.downsample = downsample 284 | self.stride = stride 285 | self.relu = nn.ReLU(inplace=True) 286 | 287 | def forward(self, x): 288 | residual = x 289 | 290 | out = self.convbnrelu1(x) 291 | out = self.convbn2(out) 292 | 293 | if self.downsample is not None: 294 | residual = self.downsample(x) 295 | 296 | out += residual 297 | out = self.relu(out) 298 | return out 299 | 300 | 301 | class residualBottleneck(nn.Module): 302 | expansion = 4 303 | 304 | def __init__(self, in_channels, n_filters, stride=1, downsample=None): 305 | super(residualBottleneck, self).__init__() 306 | self.convbn1 = nn.Conv2DBatchNorm(in_channels, n_filters, k_size=1, bias=False) 307 | self.convbn2 = nn.Conv2DBatchNorm( 308 | n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False 309 | ) 310 | self.convbn3 = nn.Conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, bias=False) 311 | self.relu = nn.ReLU(inplace=True) 312 | self.downsample = downsample 313 | self.stride = stride 314 | 315 | def forward(self, x): 316 | residual = x 317 | 318 | out = self.convbn1(x) 319 | out = self.convbn2(out) 320 | out = self.convbn3(out) 321 | 322 | if self.downsample is not None: 323 | residual = self.downsample(x) 324 | 325 | out += residual 326 | out = self.relu(out) 327 | 328 | return out 329 | 330 | 331 | class linknetUp(nn.Module): 332 | def __init__(self, in_channels, n_filters): 333 | super(linknetUp, self).__init__() 334 | 335 | # B, 2C, H, W -> B, C/2, H, W 336 | self.convbnrelu1 = conv2DBatchNormRelu( 337 | in_channels, n_filters / 2, k_size=1, stride=1, padding=1 338 | ) 339 | 340 | # B, C/2, H, W -> B, C/2, H, W 341 | self.deconvbnrelu2 = nn.deconv2DBatchNormRelu( 342 | n_filters / 2, n_filters / 2, k_size=3, stride=2, padding=0 343 | ) 344 | 345 | # B, C/2, H, W -> B, C, H, W 346 | self.convbnrelu3 = conv2DBatchNormRelu( 347 | n_filters / 2, n_filters, k_size=1, stride=1, padding=1 348 | ) 349 | 350 | def forward(self, x): 351 | x = self.convbnrelu1(x) 352 | x = self.deconvbnrelu2(x) 353 | x = self.convbnrelu3(x) 354 | return x 355 | 356 | 357 | class FRRU(nn.Module): 358 | """ 359 | Full Resolution Residual Unit for FRRN 360 | """ 361 | 362 | def __init__(self, prev_channels, out_channels, scale, group_norm=False, n_groups=None): 363 | super(FRRU, self).__init__() 364 | self.scale = scale 365 | self.prev_channels = prev_channels 366 | self.out_channels = out_channels 367 | self.group_norm = group_norm 368 | self.n_groups = 
n_groups 369 | 370 | if self.group_norm: 371 | conv_unit = conv2DGroupNormRelu 372 | self.conv1 = conv_unit( 373 | prev_channels + 32, 374 | out_channels, 375 | k_size=3, 376 | stride=1, 377 | padding=1, 378 | bias=False, 379 | n_groups=self.n_groups, 380 | ) 381 | self.conv2 = conv_unit( 382 | out_channels, 383 | out_channels, 384 | k_size=3, 385 | stride=1, 386 | padding=1, 387 | bias=False, 388 | n_groups=self.n_groups, 389 | ) 390 | 391 | else: 392 | conv_unit = conv2DBatchNormRelu 393 | self.conv1 = conv_unit( 394 | prev_channels + 32, out_channels, k_size=3, stride=1, padding=1, bias=False 395 | ) 396 | self.conv2 = conv_unit( 397 | out_channels, out_channels, k_size=3, stride=1, padding=1, bias=False 398 | ) 399 | 400 | self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0) 401 | 402 | def forward(self, y, z): 403 | x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1) 404 | y_prime = self.conv1(x) 405 | y_prime = self.conv2(y_prime) 406 | 407 | x = self.conv_res(y_prime) 408 | upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]]) 409 | x = F.upsample(x, size=upsample_size, mode="nearest") 410 | z_prime = z + x 411 | 412 | return y_prime, z_prime 413 | 414 | 415 | class RU(nn.Module): 416 | """ 417 | Residual Unit for FRRN 418 | """ 419 | 420 | def __init__(self, channels, kernel_size=3, strides=1, group_norm=False, n_groups=None): 421 | super(RU, self).__init__() 422 | self.group_norm = group_norm 423 | self.n_groups = n_groups 424 | 425 | if self.group_norm: 426 | self.conv1 = conv2DGroupNormRelu( 427 | channels, 428 | channels, 429 | k_size=kernel_size, 430 | stride=strides, 431 | padding=1, 432 | bias=False, 433 | n_groups=self.n_groups, 434 | ) 435 | self.conv2 = conv2DGroupNorm( 436 | channels, 437 | channels, 438 | k_size=kernel_size, 439 | stride=strides, 440 | padding=1, 441 | bias=False, 442 | n_groups=self.n_groups, 443 | ) 444 | 445 | else: 446 | self.conv1 = conv2DBatchNormRelu( 447 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False 448 | ) 449 | self.conv2 = conv2DBatchNorm( 450 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False 451 | ) 452 | 453 | def forward(self, x): 454 | incoming = x 455 | x = self.conv1(x) 456 | x = self.conv2(x) 457 | return x + incoming 458 | 459 | 460 | class residualConvUnit(nn.Module): 461 | def __init__(self, channels, kernel_size=3): 462 | super(residualConvUnit, self).__init__() 463 | 464 | self.residual_conv_unit = nn.Sequential( 465 | nn.ReLU(inplace=True), 466 | nn.Conv2d(channels, channels, kernel_size=kernel_size), 467 | nn.ReLU(inplace=True), 468 | nn.Conv2d(channels, channels, kernel_size=kernel_size), 469 | ) 470 | 471 | def forward(self, x): 472 | input = x 473 | x = self.residual_conv_unit(x) 474 | return x + input 475 | 476 | 477 | class multiResolutionFusion(nn.Module): 478 | def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape): 479 | super(multiResolutionFusion, self).__init__() 480 | 481 | self.up_scale_high = up_scale_high 482 | self.up_scale_low = up_scale_low 483 | 484 | self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3) 485 | 486 | if low_shape is not None: 487 | self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3) 488 | 489 | def forward(self, x_high, x_low): 490 | high_upsampled = F.upsample( 491 | self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear" 492 | ) 493 | 494 | if x_low is None: 495 | return high_upsampled 496 | 497 | 
low_upsampled = F.upsample( 498 | self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear" 499 | ) 500 | 501 | return low_upsampled + high_upsampled 502 | 503 | 504 | class chainedResidualPooling(nn.Module): 505 | def __init__(self, channels, input_shape): 506 | super(chainedResidualPooling, self).__init__() 507 | 508 | self.chained_residual_pooling = nn.Sequential( 509 | nn.ReLU(inplace=True), 510 | nn.MaxPool2d(5, 1, 2), 511 | nn.Conv2d(input_shape[1], channels, kernel_size=3), 512 | ) 513 | 514 | def forward(self, x): 515 | input = x 516 | x = self.chained_residual_pooling(x) 517 | return x + input 518 | 519 | 520 | class pyramidPooling(nn.Module): 521 | def __init__( 522 | self, in_channels, pool_sizes, model_name="pspnet", fusion_mode="cat", is_batchnorm=True 523 | ): 524 | super(pyramidPooling, self).__init__() 525 | 526 | bias = not is_batchnorm 527 | 528 | self.paths = [] 529 | for i in range(len(pool_sizes)): 530 | self.paths.append( 531 | conv2DBatchNormRelu( 532 | in_channels, 533 | int(in_channels / len(pool_sizes)), 534 | 1, 535 | 1, 536 | 0, 537 | bias=bias, 538 | is_batchnorm=is_batchnorm, 539 | ) 540 | ) 541 | 542 | self.path_module_list = nn.ModuleList(self.paths) 543 | self.pool_sizes = pool_sizes 544 | self.model_name = model_name 545 | self.fusion_mode = fusion_mode 546 | 547 | def forward(self, x): 548 | h, w = x.shape[2:] 549 | 550 | if self.training or self.model_name != "icnet": # general settings or pspnet 551 | k_sizes = [] 552 | strides = [] 553 | for pool_size in self.pool_sizes: 554 | k_sizes.append((int(h / pool_size), int(w / pool_size))) 555 | strides.append((int(h / pool_size), int(w / pool_size))) 556 | else: # eval mode and icnet: pre-trained for 1025 x 2049 557 | k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)] 558 | strides = [(5, 10), (10, 20), (16, 32), (33, 65)] 559 | 560 | if self.fusion_mode == "cat": # pspnet: concat (including x) 561 | output_slices = [x] 562 | 563 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)): 564 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0) 565 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size)) 566 | if self.model_name != "icnet": 567 | out = module(out) 568 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True) 569 | output_slices.append(out) 570 | 571 | return torch.cat(output_slices, dim=1) 572 | else: # icnet: element-wise sum (including x) 573 | pp_sum = x 574 | 575 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)): 576 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0) 577 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size)) 578 | if self.model_name != "icnet": 579 | out = module(out) 580 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True) 581 | pp_sum = pp_sum + out 582 | 583 | return pp_sum 584 | 585 | 586 | class bottleNeckPSP(nn.Module): 587 | def __init__( 588 | self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True 589 | ): 590 | super(bottleNeckPSP, self).__init__() 591 | 592 | bias = not is_batchnorm 593 | 594 | self.cbr1 = conv2DBatchNormRelu( 595 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 596 | ) 597 | if dilation > 1: 598 | self.cbr2 = conv2DBatchNormRelu( 599 | mid_channels, 600 | mid_channels, 601 | 3, 602 | stride=stride, 603 | padding=dilation, 604 | bias=bias, 605 | dilation=dilation, 606 | is_batchnorm=is_batchnorm, 607 
| ) 608 | else: 609 | self.cbr2 = conv2DBatchNormRelu( 610 | mid_channels, 611 | mid_channels, 612 | 3, 613 | stride=stride, 614 | padding=1, 615 | bias=bias, 616 | dilation=1, 617 | is_batchnorm=is_batchnorm, 618 | ) 619 | self.cb3 = conv2DBatchNorm( 620 | mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 621 | ) 622 | self.cb4 = conv2DBatchNorm( 623 | in_channels, 624 | out_channels, 625 | 1, 626 | stride=stride, 627 | padding=0, 628 | bias=bias, 629 | is_batchnorm=is_batchnorm, 630 | ) 631 | 632 | def forward(self, x): 633 | conv = self.cb3(self.cbr2(self.cbr1(x))) 634 | residual = self.cb4(x) 635 | return F.relu(conv + residual, inplace=True) 636 | 637 | 638 | class bottleNeckIdentifyPSP(nn.Module): 639 | def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True): 640 | super(bottleNeckIdentifyPSP, self).__init__() 641 | 642 | bias = not is_batchnorm 643 | 644 | self.cbr1 = conv2DBatchNormRelu( 645 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 646 | ) 647 | if dilation > 1: 648 | self.cbr2 = conv2DBatchNormRelu( 649 | mid_channels, 650 | mid_channels, 651 | 3, 652 | stride=1, 653 | padding=dilation, 654 | bias=bias, 655 | dilation=dilation, 656 | is_batchnorm=is_batchnorm, 657 | ) 658 | else: 659 | self.cbr2 = conv2DBatchNormRelu( 660 | mid_channels, 661 | mid_channels, 662 | 3, 663 | stride=1, 664 | padding=1, 665 | bias=bias, 666 | dilation=1, 667 | is_batchnorm=is_batchnorm, 668 | ) 669 | self.cb3 = conv2DBatchNorm( 670 | mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm 671 | ) 672 | 673 | def forward(self, x): 674 | residual = x 675 | x = self.cb3(self.cbr2(self.cbr1(x))) 676 | return F.relu(x + residual, inplace=True) 677 | 678 | 679 | class residualBlockPSP(nn.Module): 680 | def __init__( 681 | self, 682 | n_blocks, 683 | in_channels, 684 | mid_channels, 685 | out_channels, 686 | stride, 687 | dilation=1, 688 | include_range="all", 689 | is_batchnorm=True, 690 | ): 691 | super(residualBlockPSP, self).__init__() 692 | 693 | if dilation > 1: 694 | stride = 1 695 | 696 | # residualBlockPSP = convBlockPSP + identityBlockPSPs 697 | layers = [] 698 | if include_range in ["all", "conv"]: 699 | layers.append( 700 | bottleNeckPSP( 701 | in_channels, 702 | mid_channels, 703 | out_channels, 704 | stride, 705 | dilation, 706 | is_batchnorm=is_batchnorm, 707 | ) 708 | ) 709 | if include_range in ["all", "identity"]: 710 | for i in range(n_blocks - 1): 711 | layers.append( 712 | bottleNeckIdentifyPSP( 713 | out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm 714 | ) 715 | ) 716 | 717 | self.layers = nn.Sequential(*layers) 718 | 719 | def forward(self, x): 720 | return self.layers(x) 721 | 722 | 723 | class cascadeFeatureFusion(nn.Module): 724 | def __init__( 725 | self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True 726 | ): 727 | super(cascadeFeatureFusion, self).__init__() 728 | 729 | bias = not is_batchnorm 730 | 731 | self.low_dilated_conv_bn = conv2DBatchNorm( 732 | low_in_channels, 733 | out_channels, 734 | 3, 735 | stride=1, 736 | padding=2, 737 | bias=bias, 738 | dilation=2, 739 | is_batchnorm=is_batchnorm, 740 | ) 741 | self.low_classifier_conv = nn.Conv2d( 742 | int(low_in_channels), 743 | int(n_classes), 744 | kernel_size=1, 745 | padding=0, 746 | stride=1, 747 | bias=True, 748 | dilation=1, 749 | ) # Train only 750 | self.high_proj_conv_bn = conv2DBatchNorm( 751 | high_in_channels, 752 
| out_channels, 753 | 1, 754 | stride=1, 755 | padding=0, 756 | bias=bias, 757 | is_batchnorm=is_batchnorm, 758 | ) 759 | 760 | def forward(self, x_low, x_high): 761 | x_low_upsampled = F.interpolate( 762 | x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True 763 | ) 764 | 765 | low_cls = self.low_classifier_conv(x_low_upsampled) 766 | 767 | low_fm = self.low_dilated_conv_bn(x_low_upsampled) 768 | high_fm = self.high_proj_conv_bn(x_high) 769 | high_fused_fm = F.relu(low_fm + high_fm, inplace=True) 770 | 771 | return high_fused_fm, low_cls 772 | 773 | 774 | def get_interp_size(input, s_factor=1, z_factor=1): # for caffe 775 | ori_h, ori_w = input.shape[2:] 776 | 777 | # shrink (s_factor >= 1) 778 | ori_h = (ori_h - 1) / s_factor + 1 779 | ori_w = (ori_w - 1) / s_factor + 1 780 | 781 | # zoom (z_factor >= 1) 782 | ori_h = ori_h + (ori_h - 1) * (z_factor - 1) 783 | ori_w = ori_w + (ori_w - 1) * (z_factor - 1) 784 | 785 | resize_shape = (int(ori_h), int(ori_w)) 786 | return resize_shape 787 | 788 | 789 | def interp(input, output_size, mode="bilinear"): 790 | n, c, ih, iw = input.shape 791 | oh, ow = output_size 792 | 793 | # normalize to [-1, 1] 794 | h = torch.arange(0, oh, dtype=torch.float, device=input.device) / (oh - 1) * 2 - 1 795 | w = torch.arange(0, ow, dtype=torch.float, device=input.device) / (ow - 1) * 2 - 1 796 | 797 | grid = torch.zeros(oh, ow, 2, dtype=torch.float, device=input.device) 798 | grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1) 799 | grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1) 800 | grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2] 801 | grid = Variable(grid) 802 | if input.is_cuda: 803 | grid = grid.cuda() 804 | 805 | return F.grid_sample(input, grid, mode=mode) 806 | 807 | 808 | def get_upsampling_weight(in_channels, out_channels, kernel_size): 809 | """Make a 2D bilinear kernel suitable for upsampling""" 810 | factor = (kernel_size + 1) // 2 811 | if kernel_size % 2 == 1: 812 | center = factor - 1 813 | else: 814 | center = factor - 0.5 815 | og = np.ogrid[:kernel_size, :kernel_size] 816 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) 817 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) 818 | weight[range(in_channels), range(out_channels), :, :] = filt 819 | return torch.from_numpy(weight).float() 820 | 821 | class Sparsemax(nn.Module): 822 | """Sparsemax function.""" 823 | 824 | def __init__(self, dim=None): 825 | """Initialize sparsemax activation 826 | 827 | Args: 828 | dim (int, optional): The dimension over which to apply the sparsemax function. 829 | """ 830 | super(Sparsemax, self).__init__() 831 | 832 | self.dim = -1 if dim is None else dim 833 | 834 | def forward(self, input): 835 | """Forward function. 836 | Args: 837 | input (torch.Tensor): Input tensor. First dimension should be the batch size 838 | Returns: 839 | torch.Tensor: [batch_size x number_of_logits] Output tensor 840 | """ 841 | # Sparsemax currently only handles 2-dim tensors, 842 | # so we reshape and reshape back after sparsemax 843 | original_size = input.size() 844 | input = input.view(-1, input.size(self.dim)) 845 | 846 | dim = 1 847 | number_of_logits = input.size(dim) 848 | 849 | # Translate input by max for numerical stability 850 | input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input) 851 | 852 | # Sort input in descending order. 
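        # Sparsemax projects the shifted logits z onto the probability simplex
        # (Martins & Astudillo, 2016):
        #   k(z)   = max{ k : 1 + k * z_(k) > sum_{j<=k} z_(j) }   (support size)
        #   tau(z) = (sum_{j<=k(z)} z_(j) - 1) / k(z)              (threshold)
        #   p_i    = max(z_i - tau(z), 0)
        # The sorted values zs, the mask is_gt, k and taus computed below implement
        # exactly these quantities.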
853 | # (NOTE: Can be replaced with linear time selection method described here: 854 | # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html) 855 | zs = torch.sort(input=input, dim=dim, descending=True)[0] 856 | range = torch.range(start=1, end=number_of_logits, device=input.device).view(1, -1) 857 | range = range.expand_as(zs) 858 | 859 | # Determine sparsity of projection 860 | bound = 1 + range * zs 861 | cumulative_sum_zs = torch.cumsum(zs, dim) 862 | is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type()) 863 | k = torch.max(is_gt * range, dim, keepdim=True)[0] 864 | 865 | # Compute threshold function 866 | zs_sparse = is_gt * zs 867 | 868 | # Compute taus 869 | taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k 870 | taus = taus.expand_as(input) 871 | 872 | # Sparsemax 873 | self.output = torch.max(torch.zeros_like(input), input - taus) 874 | 875 | output = self.output.view(original_size) 876 | 877 | return output 878 | 879 | def backward(self, grad_output): 880 | """Backward function.""" 881 | dim = 1 882 | 883 | nonzeros = torch.ne(self.output, 0) 884 | sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim) 885 | self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output)) 886 | 887 | return self.grad_input 888 | -------------------------------------------------------------------------------- /ptsemseg/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop 4 | 5 | logger = logging.getLogger("ptsemseg") 6 | 7 | key2opt = { 8 | "sgd": SGD, 9 | "adam": Adam, 10 | "asgd": ASGD, 11 | "adamax": Adamax, 12 | "adadelta": Adadelta, 13 | "adagrad": Adagrad, 14 | "rmsprop": RMSprop, 15 | } 16 | 17 | 18 | def get_optimizer(cfg): 19 | if cfg["training"]["optimizer"] is None: 20 | logger.info("Using SGD optimizer") 21 | return SGD 22 | 23 | else: 24 | opt_name = cfg["training"]["optimizer"]["name"] 25 | if opt_name not in key2opt: 26 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name)) 27 | 28 | logger.info("Using {} optimizer".format(opt_name)) 29 | return key2opt[opt_name] 30 | -------------------------------------------------------------------------------- /ptsemseg/probe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def get_vectorize_grad(model_parameter): 4 | grad_vec = None 5 | for i, para in enumerate(model_parameter): 6 | if para.requires_grad == True: 7 | grad = para.grad.view(para.grad.numel()) 8 | if i == 0: 9 | grad_vec = grad 10 | else: 11 | grad_vec = torch.cat((grad_vec,grad), 0) 12 | return grad_vec 13 | -------------------------------------------------------------------------------- /ptsemseg/process_img.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.misc as m 3 | import time 4 | import cv2 5 | from torch.autograd import Variable 6 | def generate_noise(img,noisy_type=None): 7 | ''' 8 | Input: img: RGB image, noisy_type: string of noisy type 9 | generate noisy image 10 | * image must be RGB 11 | ''' 12 | image_batch = img.shape[0] 13 | img_ch = img.shape[1] 14 | img_row = img.shape[2] 15 | img_col = img.shape[3] 16 | 17 | if noisy_type == 'occlusion': 18 | #print('Noisy_type: Occlusion') 19 | img[:,:,int(img_row/5):(img_row),:] = 0 20 | elif noisy_type == 'random_noisy': 21 | noise = 
Variable(img.data.new(img.size()).normal_(0, 0.8)) 22 | img = img + noise 23 | img_np = img.data.numpy() 24 | #m.imsave('noisy_image.png',img_np[0].transpose(1,2,0)) 25 | elif noisy_type == 'grayscale': 26 | #print('Noisy_type: Grayscale') 27 | img = np.dot(img[...,:3], [0.299, 0.587, 0.114]) 28 | elif noisy_type == 'low_resolution': 29 | #print('Noisy_type: Low resolution (but now is original image)') 30 | pass 31 | else: 32 | # print('Noisy_type: original image)') 33 | pass 34 | 35 | return img 36 | 37 | def recursive_glob(rootdir=".", suffix=""): 38 | """Performs recursive glob with given suffix and rootdir 39 | :param rootdir is the root directory 40 | :param suffix is the suffix to be searched 41 | """ 42 | return [ 43 | os.path.join(looproot, filename) 44 | for looproot, _, filenames in os.walk(rootdir) 45 | for filename in filenames 46 | if filename.endswith(suffix) 47 | ] 48 | 49 | 50 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 51 | """Alpha Blending utility to overlay RGB masks on RBG images 52 | :param input_image is a np.ndarray with 3 channels 53 | :param segmentation_mask is a np.ndarray with 3 channels 54 | :param alpha is a float value 55 | """ 56 | blended = np.zeros(input_image.size, dtype=np.float32) 57 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 58 | return blended 59 | 60 | 61 | def convert_state_dict(state_dict): 62 | """Converts a state dict saved from a dataParallel module to normal 63 | module state_dict inplace 64 | :param state_dict is the loaded DataParallel model_state 65 | """ 66 | new_state_dict = OrderedDict() 67 | for k, v in state_dict.items(): 68 | name = k[7:] # remove `module.` 69 | new_state_dict[name] = v 70 | return new_state_dict 71 | 72 | 73 | def get_logger(logdir): 74 | logger = logging.getLogger("ptsemseg") 75 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 76 | ts = ts.replace(":", "_").replace("-", "_") 77 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 78 | hdlr = logging.FileHandler(file_path) 79 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 80 | hdlr.setFormatter(formatter) 81 | logger.addHandler(hdlr) 82 | logger.setLevel(logging.INFO) 83 | return logger 84 | -------------------------------------------------------------------------------- /ptsemseg/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CosineAnnealingLR 4 | 5 | from ptsemseg.schedulers.schedulers import WarmUpLR, ConstantLR, PolynomialLR 6 | 7 | logger = logging.getLogger("ptsemseg") 8 | 9 | key2scheduler = { 10 | "constant_lr": ConstantLR, 11 | "poly_lr": PolynomialLR, 12 | "multi_step": MultiStepLR, 13 | "cosine_annealing": CosineAnnealingLR, 14 | "exp_lr": ExponentialLR, 15 | } 16 | 17 | 18 | def get_scheduler(optimizer, scheduler_dict): 19 | if scheduler_dict is None: 20 | logger.info("Using No LR Scheduling") 21 | return ConstantLR(optimizer) 22 | 23 | s_type = scheduler_dict["name"] 24 | scheduler_dict.pop("name") 25 | 26 | logging.info("Using {} scheduler with {} params".format(s_type, scheduler_dict)) 27 | 28 | warmup_dict = {} 29 | if "warmup_iters" in scheduler_dict: 30 | # This can be done in a more pythonic way... 
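        # Illustrative example (values are not taken from the shipped configs): a
        # cfg["training"]["lr_schedule"] entry such as
        #   {"name": "poly_lr", "max_iter": 100000,
        #    "warmup_iters": 100, "warmup_mode": "linear", "warmup_factor": 0.2}
        # reaches this branch; the warmup_* keys are split off here and the
        # remaining keys are forwarded to the base scheduler constructor below.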
31 | warmup_dict["warmup_iters"] = scheduler_dict.get("warmup_iters", 100) 32 | warmup_dict["mode"] = scheduler_dict.get("warmup_mode", "linear") 33 | warmup_dict["gamma"] = scheduler_dict.get("warmup_factor", 0.2) 34 | 35 | logger.info( 36 | "Using Warmup with {} iters {} gamma and {} mode".format( 37 | warmup_dict["warmup_iters"], warmup_dict["gamma"], warmup_dict["mode"] 38 | ) 39 | ) 40 | 41 | scheduler_dict.pop("warmup_iters", None) 42 | scheduler_dict.pop("warmup_mode", None) 43 | scheduler_dict.pop("warmup_factor", None) 44 | 45 | base_scheduler = key2scheduler[s_type](optimizer, **scheduler_dict) 46 | return WarmUpLR(optimizer, base_scheduler, **warmup_dict) 47 | 48 | return key2scheduler[s_type](optimizer, **scheduler_dict) 49 | -------------------------------------------------------------------------------- /ptsemseg/schedulers/schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class ConstantLR(_LRScheduler): 5 | def __init__(self, optimizer, last_epoch=-1): 6 | super(ConstantLR, self).__init__(optimizer, last_epoch) 7 | 8 | def get_lr(self): 9 | return [base_lr for base_lr in self.base_lrs] 10 | 11 | 12 | class PolynomialLR(_LRScheduler): 13 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1): 14 | self.decay_iter = decay_iter 15 | self.max_iter = max_iter 16 | self.gamma = gamma 17 | super(PolynomialLR, self).__init__(optimizer, last_epoch) 18 | 19 | def get_lr(self): 20 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter: 21 | return [base_lr for base_lr in self.base_lrs] 22 | else: 23 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma 24 | return [base_lr * factor for base_lr in self.base_lrs] 25 | 26 | 27 | class WarmUpLR(_LRScheduler): 28 | def __init__( 29 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1 30 | ): 31 | self.mode = mode 32 | self.scheduler = scheduler 33 | self.warmup_iters = warmup_iters 34 | self.gamma = gamma 35 | super(WarmUpLR, self).__init__(optimizer, last_epoch) 36 | 37 | def get_lr(self): 38 | cold_lrs = self.scheduler.get_lr() 39 | 40 | if self.last_epoch < self.warmup_iters: 41 | if self.mode == "linear": 42 | alpha = self.last_epoch / float(self.warmup_iters) 43 | factor = self.gamma * (1 - alpha) + alpha 44 | 45 | elif self.mode == "constant": 46 | factor = self.gamma 47 | else: 48 | raise KeyError("WarmUp type {} not implemented".format(self.mode)) 49 | 50 | return [factor * base_lr for base_lr in cold_lrs] 51 | 52 | return cold_lrs 53 | -------------------------------------------------------------------------------- /ptsemseg/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc Utility functions 3 | """ 4 | import os 5 | import logging 6 | import datetime 7 | import numpy as np 8 | 9 | from collections import OrderedDict 10 | from torch import nn 11 | import torch.nn.init as init 12 | 13 | def init_weights(m): 14 | if isinstance(m, nn.Conv1d): 15 | init.normal_(m.weight.data) 16 | if m.bias is not None: 17 | init.normal_(m.bias.data) 18 | elif isinstance(m, nn.Conv2d): 19 | init.xavier_normal_(m.weight.data) 20 | if m.bias is not None: 21 | init.normal_(m.bias.data) 22 | elif isinstance(m, nn.Conv3d): 23 | init.xavier_normal_(m.weight.data) 24 | if m.bias is not None: 25 | init.normal_(m.bias.data) 26 | elif isinstance(m, nn.ConvTranspose1d): 27 | 
init.normal_(m.weight.data) 28 | if m.bias is not None: 29 | init.normal_(m.bias.data) 30 | elif isinstance(m, nn.ConvTranspose2d): 31 | init.xavier_normal_(m.weight.data) 32 | if m.bias is not None: 33 | init.normal_(m.bias.data) 34 | elif isinstance(m, nn.ConvTranspose3d): 35 | init.xavier_normal_(m.weight.data) 36 | if m.bias is not None: 37 | init.normal_(m.bias.data) 38 | elif isinstance(m, nn.BatchNorm1d): 39 | init.normal_(m.weight.data, mean=1, std=0.02) 40 | init.constant_(m.bias.data, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.normal_(m.weight.data, mean=1, std=0.02) 43 | init.constant_(m.bias.data, 0) 44 | elif isinstance(m, nn.BatchNorm3d): 45 | init.normal_(m.weight.data, mean=1, std=0.02) 46 | init.constant_(m.bias.data, 0) 47 | elif isinstance(m, nn.Linear): 48 | init.xavier_normal_(m.weight.data) 49 | init.normal_(m.bias.data) 50 | elif isinstance(m, nn.LSTM): 51 | for param in m.parameters(): 52 | if len(param.shape) >= 2: 53 | init.orthogonal_(param.data) 54 | else: 55 | init.normal_(param.data) 56 | elif isinstance(m, nn.LSTMCell): 57 | for param in m.parameters(): 58 | if len(param.shape) >= 2: 59 | init.orthogonal_(param.data) 60 | else: 61 | init.normal_(param.data) 62 | elif isinstance(m, nn.GRU): 63 | for param in m.parameters(): 64 | if len(param.shape) >= 2: 65 | init.orthogonal_(param.data) 66 | else: 67 | init.normal_(param.data) 68 | elif isinstance(m, nn.GRUCell): 69 | for param in m.parameters(): 70 | if len(param.shape) >= 2: 71 | init.orthogonal_(param.data) 72 | else: 73 | init.normal_(param.data) 74 | 75 | 76 | def recursive_glob(rootdir=".", suffix=""): 77 | """Performs recursive glob with given suffix and rootdir 78 | :param rootdir is the root directory 79 | :param suffix is the suffix to be searched 80 | """ 81 | return [ 82 | os.path.join(looproot, filename) 83 | for looproot, _, filenames in os.walk(rootdir) 84 | for filename in filenames 85 | if filename.endswith(suffix) 86 | ] 87 | 88 | 89 | def alpha_blend(input_image, segmentation_mask, alpha=0.5): 90 | """Alpha Blending utility to overlay RGB masks on RBG images 91 | :param input_image is a np.ndarray with 3 channels 92 | :param segmentation_mask is a np.ndarray with 3 channels 93 | :param alpha is a float value 94 | """ 95 | blended = np.zeros(input_image.size, dtype=np.float32) 96 | blended = input_image * alpha + segmentation_mask * (1 - alpha) 97 | return blended 98 | 99 | 100 | def convert_state_dict(state_dict): 101 | """Converts a state dict saved from a dataParallel module to normal 102 | module state_dict inplace 103 | :param state_dict is the loaded DataParallel model_state 104 | """ 105 | new_state_dict = OrderedDict() 106 | for k, v in state_dict.items(): 107 | name = k[7:] # remove `module.` 108 | new_state_dict[name] = v 109 | return new_state_dict 110 | 111 | 112 | def get_logger(logdir): 113 | logger = logging.getLogger("ptsemseg") 114 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_") 115 | ts = ts.replace(":", "_").replace("-", "_") 116 | file_path = os.path.join(logdir, "run_{}.log".format(ts)) 117 | hdlr = logging.FileHandler(file_path) 118 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s") 119 | hdlr.setFormatter(formatter) 120 | logger.addHandler(hdlr) 121 | logger.setLevel(logging.INFO) 122 | return logger 123 | 124 | 125 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.0.0 
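# Install these dependencies with: pip install -r requirements.txt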
2 | numpy==1.22.0 3 | scipy==0.19.0 4 | torch==0.4.1 5 | torchvision==0.2.0 6 | tqdm==4.11.2 7 | pydensecrf 8 | protobuf 9 | tensorboardX 10 | pyyaml 11 | pretrainedmodels 12 | opencv-python 13 | -------------------------------------------------------------------------------- /teaser/1359-teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-RIPL/MultiAgentPerception/4ef300547a7f7af2676a034f7cf742b009f57d99/teaser/1359-teaser.gif -------------------------------------------------------------------------------- /teaser/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GT-RIPL/MultiAgentPerception/4ef300547a7f7af2676a034f7cf742b009f57d99/teaser/pytorch-logo-dark.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import torch 3 | import argparse 4 | import timeit 5 | import numpy as np 6 | import cv2 7 | import os 8 | 9 | from torch.utils import data 10 | from ptsemseg.models import get_model 11 | from ptsemseg.loader import get_loader 12 | from ptsemseg.metrics import runningScore 13 | from ptsemseg.utils import convert_state_dict 14 | from ptsemseg.visual import draw_bounding 15 | from ptsemseg.trainer import * 16 | 17 | torch.backends.cudnn.benchmark = True 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser(description="config") 22 | parser.add_argument( 23 | "--config", 24 | nargs="?", 25 | type=str, 26 | default="configs/your_configs.yml", 27 | help="Configuration file to use", 28 | ) 29 | 30 | parser.add_argument( 31 | "--gpu", 32 | nargs="?", 33 | type=str, 34 | default="0", 35 | help="Used GPUs", 36 | ) 37 | 38 | parser.add_argument( 39 | "--run_time", 40 | nargs="?", 41 | type=int, 42 | default=1, 43 | help="run_time", 44 | ) 45 | 46 | 47 | parser.add_argument( 48 | "--model_path", 49 | nargs="?", 50 | type=str, 51 | default="Single_Agent.pkl", 52 | help="Path to the saved model", 53 | ) 54 | 55 | 56 | args = parser.parse_args() 57 | 58 | # Set the gpu 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 60 | run_times = args.run_time 61 | 62 | with open(args.config) as fp: 63 | cfg = yaml.load(fp) 64 | 65 | # ============= Testing ============= 66 | 67 | # Setup device 68 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 69 | 70 | # Setup Dataloader 71 | data_loader = get_loader(cfg["data"]["dataset"]) 72 | data_path = cfg["data"]["path"] 73 | 74 | # Load communication label (note that some datasets do not provide this) 75 | if 'commun_label' in cfg["data"]: 76 | if_commun_label = cfg["data"]['commun_label'] 77 | else: 78 | if_commun_label = 'None' 79 | 80 | # test data loadeer 81 | te_loader = data_loader( 82 | data_path, 83 | split=cfg["data"]['test_split'], 84 | is_transform=True, 85 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), 86 | target_view=cfg["data"]["target_view"], 87 | commun_label=if_commun_label) 88 | 89 | testloader = data.DataLoader(te_loader, batch_size=cfg["training"]["batch_size"], num_workers=8) 90 | 91 | 92 | # Setup Model 93 | model = get_model(cfg, te_loader.n_classes).to(device) 94 | 95 | # set up the model 96 | if cfg['model']['arch'] == 'LearnWhen2Com': # Our when2com 97 | trainer = Trainer_LearnWhen2Com(cfg, None, None, model, None, None, None, None, None, device) 98 | elif 
cfg['model']['arch'] == 'LearnWho2Com': # Our who2com 99 | trainer = Trainer_LearnWho2Com(cfg, None, None, model, None, None, None, None, None, device) 100 | elif cfg['model']['arch'] == 'MIMOcom': # 101 | trainer = Trainer_MIMOcom(cfg, None, None, model, None, None, None, None, None, device) 102 | elif cfg['model']['arch'] == 'MIMOcomMultiWarp': 103 | trainer = Trainer_MIMOcomMultiWarp(cfg, None, None, None, None, None, None, None, None, device) 104 | elif cfg['model']['arch'] == 'MIMOcomWho': 105 | trainer = Trainer_MIMOcomWho(cfg, None, None, model, None, None, None, None, None, device) 106 | elif cfg['model']['arch'] == 'Single_agent': 107 | trainer = Trainer_Single_agent(cfg, None, None, model, None, None, None, None, None, device) 108 | elif cfg['model']['arch'] == 'All_agents': 109 | trainer = Trainer_All_agents(cfg, None, None, model, None, None, None, None, None, device) 110 | elif cfg['model']['arch'] == 'MIMO_All_agents': 111 | trainer = Trainer_MIMO_All_agents(cfg, None, None, model, None, None, None, None, None, device) 112 | else: 113 | raise ValueError('Unknown arch name for testing') 114 | 115 | 116 | print(args.model_path) 117 | # load best weight 118 | trainer.load_weight(args.model_path) 119 | 120 | # if you would like to obtain qual results or other stats, just change the output 121 | _ = trainer.evaluate(testloader) 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import time 4 | import shutil 5 | import torch 6 | import random 7 | import argparse 8 | import numpy as np 9 | import copy 10 | import timeit 11 | import statistics 12 | import datetime 13 | from torch.utils import data 14 | from tqdm import tqdm 15 | import cv2 16 | 17 | from ptsemseg.process_img import generate_noise 18 | from ptsemseg.models import get_model 19 | from ptsemseg.loss import get_loss_function 20 | from ptsemseg.loader import get_loader 21 | from ptsemseg.utils import get_logger, init_weights 22 | from ptsemseg.metrics import runningScore 23 | from ptsemseg.augmentations import get_composed_augmentations 24 | from ptsemseg.schedulers import get_scheduler 25 | from ptsemseg.optimizers import get_optimizer 26 | from ptsemseg.utils import convert_state_dict 27 | 28 | from ptsemseg.trainer import * 29 | 30 | from tensorboardX import SummaryWriter 31 | 32 | 33 | # main function 34 | if __name__ == "__main__": 35 | 36 | parser = argparse.ArgumentParser(description="config") 37 | parser.add_argument( 38 | "--config", 39 | nargs="?", 40 | type=str, 41 | default="configs/your_configs.yml", 42 | help="Configuration file to use", 43 | ) 44 | 45 | parser.add_argument( 46 | "--gpu", 47 | nargs="?", 48 | type=str, 49 | default="0", 50 | help="Used GPUs", 51 | ) 52 | 53 | parser.add_argument( 54 | "--run_time", 55 | nargs="?", 56 | type=int, 57 | default=1, 58 | help="run_time", 59 | ) 60 | 61 | args = parser.parse_args() 62 | 63 | # Set the gpu 64 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 65 | run_times = args.run_time 66 | 67 | with open(args.config) as fp: 68 | cfg = yaml.load(fp) 69 | 70 | data_splits = ['val_split', 'test_split'] 71 | 72 | # initialize for results stats 73 | score_list = {} 74 | class_iou_list = {} 75 | acc_list = {} 76 | if cfg['model']['arch'] == 'LearnWho2Com': 77 | for infer in ['softmax', 'argmax_test']: 78 | score_list[infer] = {} 79 | class_iou_list[infer] = {} 80 | 
acc_list[infer] = {} 81 | for data_sp in data_splits: 82 | score_list[infer][data_sp] = [] 83 | class_iou_list[infer][data_sp] = [] 84 | acc_list[infer][data_sp] = [] 85 | elif cfg['model']['arch'] == 'LearnWhen2Com' or \ 86 | cfg['model']['arch'] == 'MIMOcom' or \ 87 | cfg['model']['arch'] == 'MIMOcomMultiWarp' or \ 88 | cfg['model']['arch'] == 'MIMOcomWho' : 89 | for infer in ['softmax', 'argmax_test', 'activated']: 90 | score_list[infer] = {} 91 | class_iou_list[infer] = {} 92 | acc_list[infer] = {} 93 | for data_sp in data_splits: 94 | score_list[infer][data_sp] = [] 95 | class_iou_list[infer][data_sp] = [] 96 | acc_list[infer][data_sp] = [] 97 | elif cfg['model']['arch'] == 'Single_agent' or cfg['model']['arch'] == 'All_agents' or cfg['model']['arch'] == 'MIMO_All_agents': 98 | for infer in ['default']: 99 | score_list[infer] = {} 100 | class_iou_list[infer] = {} 101 | acc_list[infer] = {} 102 | for data_sp in data_splits: 103 | score_list[infer][data_sp] = [] 104 | class_iou_list[infer][data_sp] = [] 105 | acc_list[infer][data_sp] = [] 106 | 107 | for _ in range(run_times): 108 | run_id = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 109 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], str(run_id)) 110 | writer = SummaryWriter(logdir=logdir) 111 | 112 | print("RUNDIR: {}".format(logdir)) 113 | shutil.copy(args.config, logdir) 114 | 115 | 116 | # ============= Training ============= 117 | # logger 118 | logger = get_logger(logdir) 119 | logger.info("Begin") 120 | 121 | # Setup seeds 122 | torch.manual_seed(cfg.get("seed", 1337)) 123 | torch.cuda.manual_seed(cfg.get("seed", 1337)) 124 | np.random.seed(cfg.get("seed", 1337)) 125 | random.seed(cfg.get("seed", 1337)) 126 | 127 | # Setup device 128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 129 | 130 | # Setup Dataloader 131 | data_loader = get_loader(cfg["data"]["dataset"]) 132 | data_path = cfg["data"]["path"] 133 | 134 | # Load communication label (note that some datasets do not provide this) 135 | if 'commun_label' in cfg["data"]: 136 | if_commun_label = cfg["data"]['commun_label'] 137 | else: 138 | if_commun_label = 'None' 139 | 140 | 141 | # dataloaders 142 | t_loader = data_loader( 143 | data_path, 144 | is_transform=True, 145 | split=cfg["data"]["train_split"], 146 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), 147 | augmentations=get_composed_augmentations(cfg["training"].get("augmentations", None)), 148 | target_view=cfg["data"]["target_view"], 149 | commun_label=if_commun_label 150 | ) 151 | 152 | v_loader = data_loader( 153 | data_path, 154 | is_transform=True, 155 | split=cfg["data"]["val_split"], 156 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), 157 | target_view=cfg["data"]["target_view"], 158 | commun_label=if_commun_label 159 | ) 160 | 161 | trainloader = data.DataLoader( 162 | t_loader, 163 | batch_size=cfg["training"]["batch_size"], 164 | num_workers=cfg["training"]["n_workers"], 165 | shuffle=True, 166 | drop_last=True 167 | ) 168 | 169 | valloader = data.DataLoader( 170 | v_loader, 171 | batch_size=cfg["training"]["batch_size"], 172 | num_workers=cfg["training"]["n_workers"] 173 | ) 174 | 175 | # Setup Model 176 | model = get_model(cfg, t_loader.n_classes).to(device) 177 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count())) 178 | # import pdb; pdb.set_trace() 179 | 180 | # Setup optimizer 181 | optimizer_cls = get_optimizer(cfg) 182 | optimizer_params = {k: v for k, v in 
cfg["training"]["optimizer"].items() if k != "name"} 183 | optimizer = optimizer_cls(model.parameters(), **optimizer_params) 184 | logger.info("Using optimizer {}".format(optimizer)) 185 | 186 | # Setup scheduler 187 | scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"]) 188 | 189 | # Setup loss 190 | loss_fn = get_loss_function(cfg) 191 | logger.info("Using loss {}".format(loss_fn)) 192 | 193 | 194 | # ================== TRAINING ================== 195 | if cfg['model']['arch'] == 'LearnWhen2Com': # Our when2com 196 | trainer = Trainer_LearnWhen2Com(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 197 | elif cfg['model']['arch'] == 'LearnWho2Com': # Our who2com 198 | trainer = Trainer_LearnWho2Com(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 199 | elif cfg['model']['arch'] == 'MIMOcom': # 200 | trainer = Trainer_MIMOcom(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 201 | elif cfg['model']['arch'] == 'MIMOcomMultiWarp': 202 | trainer = Trainer_MIMOcomMultiWarp(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 203 | elif cfg['model']['arch'] == 'MIMOcomWho': 204 | trainer = Trainer_MIMOcomWho(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 205 | elif cfg['model']['arch'] == 'Single_agent': 206 | trainer = Trainer_Single_agent(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 207 | elif cfg['model']['arch'] == 'All_agents': 208 | trainer = Trainer_All_agents(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 209 | elif cfg['model']['arch'] == 'MIMO_All_agents': 210 | trainer = Trainer_MIMO_All_agents(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device) 211 | else: 212 | raise ValueError('Unknown arch name for training') 213 | 214 | model_path = trainer.train() 215 | 216 | 217 | # ================ Val + Test ================ 218 | 219 | te_loader = data_loader( 220 | data_path, 221 | split=cfg["data"]['test_split'], 222 | is_transform=True, 223 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]), 224 | target_view=cfg["data"]["target_view"], 225 | commun_label=if_commun_label) 226 | 227 | n_classes = te_loader.n_classes 228 | testloader = data.DataLoader(te_loader, batch_size=cfg["training"]["batch_size"], num_workers=8) 229 | 230 | # load best weight 231 | trainer.load_weight(model_path) 232 | _ = trainer.evaluate(testloader) 233 | 234 | 235 | 236 | --------------------------------------------------------------------------------