├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── multi-request-multi-support
│   │   ├── mrms_allnorm.yml
│   │   ├── mrms_occdeg.yml
│   │   ├── mrms_randcom.yml
│   │   ├── mrms_when2com.yml
│   │   └── mrms_who2com.yml
│   └── single-request-multiple-support
│       ├── srms_allnorm.yml
│       ├── srms_occdeg.yml
│       ├── srms_randcom.yml
│       ├── srms_when2com.yml
│       └── srms_who2com.yml
├── ptsemseg
│   ├── __init__.py
│   ├── augmentations
│   │   ├── __init__.py
│   │   └── augmentations.py
│   ├── loader
│   │   ├── __init__.py
│   │   └── airsim_loader.py
│   ├── loss
│   │   ├── __init__.py
│   │   └── loss.py
│   ├── metrics.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── agent.py
│   │   ├── backbone.py
│   │   └── utils.py
│   ├── optimizers
│   │   └── __init__.py
│   ├── probe.py
│   ├── process_img.py
│   ├── schedulers
│   │   ├── __init__.py
│   │   └── schedulers.py
│   ├── trainer.py
│   └── utils.py
├── requirements.txt
├── teaser
│   ├── 1359-teaser.gif
│   └── pytorch-logo-dark.png
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Torch Models
7 | *.pkl
8 | *.pth
9 | current_train.py
10 | video_test*.py
11 | *.swp
12 | data
13 | ckpt
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | local_test.py
21 | .DS_STORE
22 | .idea/
23 | .vscode/
24 | env/
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *,cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # IPython Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # dotenv
92 | .env
93 |
94 | # virtualenv
95 | venv/
96 | ENV/
97 |
98 | # Spyder project settings
99 | .spyderproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | runs
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 GT-RIPL, Yen-Cheng Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## When2com: Multi-Agent Perception via Communication Graph Grouping
2 | [MIT License](https://opensource.org/licenses/MIT)
3 |
4 | This is the PyTorch implementation of our paper:
5 | **When2com: Multi-Agent Perception via Communication Graph Grouping**
6 | [__***Yen-Cheng Liu***__](https://ycliu93.github.io/), [Junjiao Tian](https://www.linkedin.com/in/junjiao-tian-42b9758a/), [Nathaniel Glaser](https://sites.google.com/view/nathanglaser/), [Zsolt Kira](https://www.cc.gatech.edu/~zk15/)
7 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020
8 |
9 |
10 | [[Paper](http://openaccess.thecvf.com/content_CVPR_2020/papers/Liu_When2com_Multi-Agent_Perception_via_Communication_Graph_Grouping_CVPR_2020_paper.pdf)] [[GitHub](https://github.gatech.edu/RIPL/multi-agent-perception)] [[Project](https://ycliu93.github.io/projects/multi-agent-perception.html)]
11 |
12 |
13 |
14 |
15 |
16 | ## Prerequisites
17 | - Python 3.6
18 | - PyTorch 0.4.1
19 | - Other required packages in `requirements.txt`
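
A quick way to confirm the environment matches these versions once it is set up (a minimal check, assuming the conda environment from the next section is active):
```
python -c "import sys, torch; print(sys.version.split()[0], torch.__version__)"
# expected output: 3.6.x 0.4.1
```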
20 |
21 |
22 | ## Getting started
23 | ### Download and install miniconda
24 | ```
25 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
26 | bash Miniconda3-latest-Linux-x86_64.sh
27 | ```
28 |
29 | ### Create conda environment
30 | ```
31 | conda create -n semseg python=3.6
32 | source activate semseg
33 | ```
34 |
35 | ### Install the required packages
36 | ```
37 | pip install -r requirements.txt
38 | ```
39 |
40 | ### Download the AirSim-MAP dataset and unzip it
41 | - Download the zip file for the setting you would like to run
42 |
43 | [AirSim-MAP dataset (download)](https://gtvault-my.sharepoint.com/:f:/g/personal/yliu3133_gatech_edu/Ett0G1_5YYdBpgojk0uWESgBi95dO79LkbYaKRhlBIkVJQ?e=vdjklb/)
44 |
45 |
46 | ### Move the datasets to the dataset path
47 | ```
48 | mkdir dataset
49 | mv (dataset folder name) dataset/
50 | ```
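
The folder name should match the `path` field of the config you plan to run. For reference, this is the layout the provided configs expect (taken from the `data.path` entries under `configs/`):
```
dataset/
├── airsim-mrms-data          # mrms_allnorm.yml
├── airsim-mrms-noise-data    # remaining multi-request-multi-support configs
├── airsim-srms-data          # srms_allnorm.yml
└── airsim-srms-noise-data    # remaining single-request-multiple-support configs
```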
51 |
52 | ### Training
53 | ```
54 | # [Single-request multi-support] All norm
55 | python train.py --config configs/single-request-multiple-support/srms_allnorm.yml --gpu=0
56 |
57 | # [Multi-request multi-support] when2com model
58 | python train.py --config configs/multi-request-multi-support/mrms_when2com.yml --gpu=0
59 |
60 | ```
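
To continue training from a saved checkpoint, set the `resume` field under `training` in the corresponding config (it is `None` by default). A minimal sketch, assuming the checkpoint was saved under `ckpt/` (the file name below is hypothetical):
```
# configs/multi-request-multi-support/mrms_when2com.yml (excerpt)
training:
  resume: ckpt/mrms_when2com_model.pkl   # hypothetical path; keep None to train from scratch
```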
61 |
62 | ### Testing
63 | ```
64 | # [Single-request multi-support] All norm
65 | python test.py --config configs/single-request-multiple-support/srms_allnorm.yml --model_path (checkpoint path) --gpu=0
66 |
67 | # [Multi-request multi-support] when2com model
68 | python test.py --config configs/multi-request-multi-support/mrms_when2com.yml --model_path (checkpoint path) --gpu=0
69 | ```
70 |
71 | ## Acknowledgments
72 | - This work was supported by ONR grant N00014-18-1-2829.
73 | - This code is built upon the implementation from [Pytorch-semseg](https://github.com/meetshah1995/pytorch-semseg).
74 |
75 | ## Citation
76 | If you find this repository useful, please cite our paper:
77 |
78 | ```
79 | @inproceedings{liu2020when2com,
80 | title={When2com: Multi-Agent Perception via Communication Graph Grouping},
81 | author={Yen-Cheng Liu and Junjiao Tian and Nathaniel Glaser and Zsolt Kira},
82 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
83 | year={2020}
84 | }
85 | ```
86 |
--------------------------------------------------------------------------------
/configs/multi-request-multi-support/mrms_allnorm.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: Single_agent
3 | shuffle_features: None
4 | agent_num: 6
5 | enc_backbone: resnet_encoder
6 | dec_backbone: simple_decoder
7 | feat_squeezer: -1
8 | feat_channel: 512
9 | multiple_output: True
10 | data:
11 | dataset: airsim
12 | train_split: train
13 | val_split: val
14 | test_split: test
15 | img_rows: 512
16 | img_cols: 512
17 | path: dataset/airsim-mrms-data
18 | noisy_type: None
19 | target_view: '6agent'
20 | training:
21 | train_iters: 12
22 | batch_size: 2
23 | val_interval: 6
24 | n_workers: 4
25 | print_interval: 2
26 | optimizer:
27 | name: 'adam'
28 | lr: 1.0e-5
29 | loss:
30 | name: 'cross_entropy'
31 | size_average: True
32 | lr_schedule:
33 | resume: None
34 |
--------------------------------------------------------------------------------
/configs/multi-request-multi-support/mrms_occdeg.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: Single_agent
3 | shuffle_features: None
4 | agent_num: 6
5 | enc_backbone: resnet_encoder
6 | dec_backbone: simple_decoder
7 | feat_squeezer: -1
8 | feat_channel: 512
9 | multiple_output: True
10 | data:
11 | dataset: airsim
12 | train_split: train
13 | val_split: val
14 | test_split: test
15 | img_rows: 512
16 | img_cols: 512
17 | path: dataset/airsim-mrms-noise-data
18 | noisy_type: None
19 | target_view: '6agent'
20 | training:
21 | train_iters: 200000
22 | batch_size: 2
23 | val_interval: 1000
24 | n_workers: 4
25 | print_interval: 50
26 | optimizer:
27 | name: 'adam'
28 | lr: 1.0e-5
29 | loss:
30 | name: 'cross_entropy'
31 | size_average: True
32 | lr_schedule:
33 | resume: None
34 |
--------------------------------------------------------------------------------
/configs/multi-request-multi-support/mrms_randcom.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: MIMO_All_agents
3 | shuffle_features: 'selection'
4 | agent_num: 6
5 | enc_backbone: resnet_encoder
6 | dec_backbone: simple_decoder
7 | feat_squeezer: -1
8 | feat_channel: 512
9 | multiple_output: True
10 | data:
11 | dataset: airsim
12 | train_split: train
13 | val_split: val
14 | test_split: test
15 | img_rows: 512
16 | img_cols: 512
17 | path: dataset/airsim-mrms-noise-data
18 | noisy_type: None
19 | target_view: '6agent'
20 | commun_label: 'mimo'
21 | training:
22 | train_iters: 200000
23 | batch_size: 1
24 | val_interval: 1000
25 | n_workers: 4
26 | print_interval: 50
27 | optimizer:
28 | name: 'adam'
29 | lr: 1.0e-5
30 | loss:
31 | name: 'cross_entropy'
32 | size_average: True
33 | lr_schedule:
34 | resume: None
35 |
--------------------------------------------------------------------------------
/configs/multi-request-multi-support/mrms_when2com.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: MIMOcom
3 | agent_num: 6
4 | shared_policy: True
5 | shared_img_encoder: 'unified'
6 | attention: 'general'
7 | sparse: False
8 | query: True
9 | query_size: 32
10 | key_size: 1024
11 | enc_backbone: resnet_encoder
12 | dec_backbone: simple_decoder
13 | feat_squeezer: -1
14 | feat_channel: 512
15 | multiple_output: True
16 | data:
17 | dataset: airsim
18 | train_split: train
19 | val_split: val
20 | test_split: test
21 | img_rows: 512
22 | img_cols: 512
23 | path: dataset/airsim-mrms-noise-data
24 | noisy_type: None
25 | target_view: '6agent'
26 | commun_label: 'mimo'
27 | training:
28 | train_iters: 200000
29 | batch_size: 2
30 | val_interval: 1000
31 | n_workers: 8
32 | print_interval: 50
33 | optimizer:
34 | name: 'adam'
35 | lr: 1.0e-5
36 | loss:
37 | name: 'cross_entropy'
38 | size_average: True
39 | lr_schedule:
40 | resume: None
41 |
--------------------------------------------------------------------------------
/configs/multi-request-multi-support/mrms_who2com.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: MIMOcomWho
3 | agent_num: 6
4 | shared_policy: True
5 | shared_img_encoder: 'unified'
6 | attention: 'general'
7 | sparse: False
8 | query: False
9 | query_size: 32
10 | key_size: 1024
11 | enc_backbone: resnet_encoder
12 | dec_backbone: simple_decoder
13 | feat_squeezer: -1
14 | feat_channel: 512
15 | multiple_output: True
16 | data:
17 | dataset: airsim
18 | train_split: train
19 | val_split: val
20 | test_split: test
21 | img_rows: 512
22 | img_cols: 512
23 | path: dataset/airsim-mrms-noise-data
24 | noisy_type: None
25 | target_view: '6agent'
26 | commun_label: 'mimo'
27 | training:
28 | train_iters: 200000
29 | batch_size: 2
30 | val_interval: 1000
31 | n_workers: 8
32 | print_interval: 50
33 | optimizer:
34 | name: 'adam'
35 | lr: 1.0e-5
36 | loss:
37 | name: 'cross_entropy'
38 | size_average: True
39 | lr_schedule:
40 | resume: None
41 |
--------------------------------------------------------------------------------
/configs/single-request-multiple-support/srms_allnorm.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: Single_agent
3 | shuffle_features: None
4 | agent_num: 5
5 | enc_backbone: resnet_encoder
6 | dec_backbone: simple_decoder
7 | feat_squeezer: -1
8 | feat_channel: 512
9 | multiple_output: False
10 | data:
11 | dataset: airsim
12 | train_split: train
13 | val_split: val
14 | test_split: test
15 | img_rows: 512
16 | img_cols: 512
17 | path: dataset/airsim-srms-data
18 | noisy_type: None
19 | target_view: 'target'
20 | commun_label: 'None'
21 | training:
22 | train_iters: 200000
23 | batch_size: 2
24 | val_interval: 1000
25 | n_workers: 8
26 | print_interval: 50
27 | optimizer:
28 | name: 'adam'
29 | lr: 1.0e-5
30 | loss:
31 | name: 'cross_entropy'
32 | size_average: True
33 | lr_schedule:
34 | resume: None
35 |
--------------------------------------------------------------------------------
/configs/single-request-multiple-support/srms_occdeg.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: Single_agent
3 | shuffle_features: None
4 | agent_num: 5
5 | enc_backbone: resnet_encoder
6 | dec_backbone: simple_decoder
7 | feat_squeezer: -1
8 | feat_channel: 512
9 | multiple_output: False
10 | data:
11 | dataset: airsim
12 | train_split: train
13 | val_split: val
14 | test_split: test
15 | img_rows: 512
16 | img_cols: 512
17 | path: dataset/airsim-srms-noise-data
18 | noisy_type: None
19 | target_view: 'target'
20 | commun_label: 'None'
21 | training:
22 | train_iters: 200000
23 | batch_size: 2
24 | val_interval: 1000
25 | n_workers: 8
26 | print_interval: 50
27 | optimizer:
28 | name: 'adam'
29 | lr: 1.0e-5
30 | loss:
31 | name: 'cross_entropy'
32 | size_average: True
33 | lr_schedule:
34 | resume: None
35 |
--------------------------------------------------------------------------------
/configs/single-request-multiple-support/srms_randcom.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: All_agents
3 | shuffle_features: 'selection'
4 | agent_num: 5
5 | enc_backbone: resnet_encoder
6 | dec_backbone: simple_decoder
7 | feat_squeezer: -1
8 | feat_channel: 512
9 | multiple_output: False
10 | data:
11 | dataset: airsim
12 | train_split: train
13 | val_split: val
14 | test_split: test
15 | img_rows: 512
16 | img_cols: 512
17 | path: dataset/airsim-srms-noise-data
18 | noisy_type: None
19 | target_view: 'target'
20 | commun_label: 'when2com'
21 | training:
22 | train_iters: 200000
23 | batch_size: 2
24 | val_interval: 1000
25 | n_workers: 8
26 | print_interval: 50
27 | optimizer:
28 | name: 'adam'
29 | lr: 1.0e-5
30 | loss:
31 | name: 'cross_entropy'
32 | size_average: True
33 | lr_schedule:
34 | resume: None
35 |
--------------------------------------------------------------------------------
/configs/single-request-multiple-support/srms_when2com.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: LearnWhen2Com
3 | agent_num: 5
4 | shared_policy: True
5 | shared_img_encoder: 'unified'
6 | attention: 'general'
7 | sparse: False
8 | query: True
9 | query_size: 8
10 | key_size: 1024
11 | enc_backbone: resnet_encoder
12 | dec_backbone: simple_decoder
13 | feat_squeezer: -1
14 | feat_channel: 512
15 | multiple_output: False
16 | data:
17 | dataset: airsim
18 | train_split: train
19 | val_split: val
20 | test_split: test
21 | img_rows: 512
22 | img_cols: 512
23 | path: dataset/airsim-srms-noise-data
24 | noisy_type: None
25 | target_view: 'target'
26 | commun_label: 'when2com'
27 | training:
28 | train_iters: 200000
29 | batch_size: 2
30 | val_interval: 1000
31 | n_workers: 8
32 | print_interval: 50
33 | optimizer:
34 | name: 'adam'
35 | lr: 1.0e-5
36 | loss:
37 | name: 'cross_entropy'
38 | size_average: True
39 | lr_schedule:
40 | resume: None
41 |
--------------------------------------------------------------------------------
/configs/single-request-multiple-support/srms_who2com.yml:
--------------------------------------------------------------------------------
1 | model:
2 | arch: LearnWho2Com
3 | agent_num: 5
4 | shared_policy: True
5 | shared_img_encoder: 'only_normal_agents'
6 | attention: 'general'
7 | sparse: False
8 | query: True
9 | query_size: 8
10 | key_size: 1024
11 | enc_backbone: resnet_encoder
12 | dec_backbone: simple_decoder
13 | feat_squeezer: -1
14 | feat_channel: 512
15 | multiple_output: False
16 | data:
17 | dataset: airsim
18 | train_split: train
19 | val_split: val
20 | test_split: test
21 | img_rows: 512
22 | img_cols: 512
23 | path: dataset/airsim-srms-noise-data
24 | noisy_type: None
25 | target_view: 'target'
26 | commun_label: 'when2com'
27 | training:
28 | train_iters: 200000
29 | batch_size: 2
30 | val_interval: 1000
31 | n_workers: 8
32 | print_interval: 50
33 | optimizer:
34 | name: 'adam'
35 | lr: 1.0e-5
36 | loss:
37 | name: 'cross_entropy'
38 | size_average: True
39 | lr_schedule:
40 | resume: None
41 |
--------------------------------------------------------------------------------
/ptsemseg/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GT-RIPL/MultiAgentPerception/4ef300547a7f7af2676a034f7cf742b009f57d99/ptsemseg/__init__.py
--------------------------------------------------------------------------------
/ptsemseg/augmentations/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from ptsemseg.augmentations.augmentations import (
3 | AdjustContrast,
4 | AdjustGamma,
5 | AdjustBrightness,
6 | AdjustSaturation,
7 | AdjustHue,
8 | RandomCrop,
9 | RandomHorizontallyFlip,
10 | RandomVerticallyFlip,
11 | Scale,
12 | RandomSized,
13 | RandomSizedCrop,
14 | RandomRotate,
15 | RandomTranslate,
16 | CenterCrop,
17 | Compose,
18 | )
19 |
20 | logger = logging.getLogger("ptsemseg")
21 |
22 | key2aug = {
23 | "gamma": AdjustGamma,
24 | "hue": AdjustHue,
25 | "brightness": AdjustBrightness,
26 | "saturation": AdjustSaturation,
27 | "contrast": AdjustContrast,
28 | "rcrop": RandomCrop,
29 | "hflip": RandomHorizontallyFlip,
30 | "vflip": RandomVerticallyFlip,
31 | "scale": Scale,
32 | "rsize": RandomSized,
33 | "rsizecrop": RandomSizedCrop,
34 | "rotate": RandomRotate,
35 | "translate": RandomTranslate,
36 | "ccrop": CenterCrop,
37 | }
38 |
39 |
40 | def get_composed_augmentations(aug_dict):
41 | if aug_dict is None:
42 | logger.info("Using No Augmentations")
43 | return None
44 |
45 | augmentations = []
46 | for aug_key, aug_param in aug_dict.items():
47 | augmentations.append(key2aug[aug_key](aug_param))
48 | logger.info("Using {} aug with params {}".format(aug_key, aug_param))
49 | return Compose(augmentations)
50 |
--------------------------------------------------------------------------------
/ptsemseg/augmentations/augmentations.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numbers
3 | import random
4 | import numpy as np
5 | import torchvision.transforms.functional as tf
6 |
7 | from PIL import Image, ImageOps
8 |
9 |
10 | class Compose(object):
11 | def __init__(self, augmentations):
12 | self.augmentations = augmentations
13 | self.PIL2Numpy = False
14 |
15 | def __call__(self, img, mask):
16 | if isinstance(img, np.ndarray):
17 | img = Image.fromarray(img, mode="RGB")
18 | mask = Image.fromarray(mask, mode="L")
19 | self.PIL2Numpy = True
20 |
21 | assert img.size == mask.size
22 | for a in self.augmentations:
23 | img, mask = a(img, mask)
24 |
25 | if self.PIL2Numpy:
26 | img, mask = np.array(img), np.array(mask, dtype=np.uint8)
27 |
28 | return img, mask
29 |
30 |
31 | class RandomCrop(object):
32 | def __init__(self, size, padding=0):
33 | if isinstance(size, numbers.Number):
34 | self.size = (int(size), int(size))
35 | else:
36 | self.size = size
37 | self.padding = padding
38 |
39 | def __call__(self, img, mask):
40 | if self.padding > 0:
41 | img = ImageOps.expand(img, border=self.padding, fill=0)
42 | mask = ImageOps.expand(mask, border=self.padding, fill=0)
43 |
44 | assert img.size == mask.size
45 | w, h = img.size
46 | th, tw = self.size
47 | if w == tw and h == th:
48 | return img, mask
49 | if w < tw or h < th:
50 | return (img.resize((tw, th), Image.BILINEAR), mask.resize((tw, th), Image.NEAREST))
51 |
52 | x1 = random.randint(0, w - tw)
53 | y1 = random.randint(0, h - th)
54 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)))
55 |
56 |
57 | class AdjustGamma(object):
58 | def __init__(self, gamma):
59 | self.gamma = gamma
60 |
61 | def __call__(self, img, mask):
62 | assert img.size == mask.size
63 | return tf.adjust_gamma(img, random.uniform(1, 1 + self.gamma)), mask
64 |
65 |
66 | class AdjustSaturation(object):
67 | def __init__(self, saturation):
68 | self.saturation = saturation
69 |
70 | def __call__(self, img, mask):
71 | assert img.size == mask.size
72 | return (
73 | tf.adjust_saturation(img, random.uniform(1 - self.saturation, 1 + self.saturation)),
74 | mask,
75 | )
76 |
77 |
78 | class AdjustHue(object):
79 | def __init__(self, hue):
80 | self.hue = hue
81 |
82 | def __call__(self, img, mask):
83 | assert img.size == mask.size
84 | return tf.adjust_hue(img, random.uniform(-self.hue, self.hue)), mask
85 |
86 |
87 | class AdjustBrightness(object):
88 | def __init__(self, bf):
89 | self.bf = bf
90 |
91 | def __call__(self, img, mask):
92 | assert img.size == mask.size
93 | return tf.adjust_brightness(img, random.uniform(1 - self.bf, 1 + self.bf)), mask
94 |
95 |
96 | class AdjustContrast(object):
97 | def __init__(self, cf):
98 | self.cf = cf
99 |
100 | def __call__(self, img, mask):
101 | assert img.size == mask.size
102 | return tf.adjust_contrast(img, random.uniform(1 - self.cf, 1 + self.cf)), mask
103 |
104 |
105 | class CenterCrop(object):
106 | def __init__(self, size):
107 | if isinstance(size, numbers.Number):
108 | self.size = (int(size), int(size))
109 | else:
110 | self.size = size
111 |
112 | def __call__(self, img, mask):
113 | assert img.size == mask.size
114 | w, h = img.size
115 | th, tw = self.size
116 | x1 = int(round((w - tw) / 2.0))
117 | y1 = int(round((h - th) / 2.0))
118 | return (img.crop((x1, y1, x1 + tw, y1 + th)), mask.crop((x1, y1, x1 + tw, y1 + th)))
119 |
120 |
121 | class RandomHorizontallyFlip(object):
122 | def __init__(self, p):
123 | self.p = p
124 |
125 | def __call__(self, img, mask):
126 | if random.random() < self.p:
127 | return (img.transpose(Image.FLIP_LEFT_RIGHT), mask.transpose(Image.FLIP_LEFT_RIGHT))
128 | return img, mask
129 |
130 |
131 | class RandomVerticallyFlip(object):
132 | def __init__(self, p):
133 | self.p = p
134 |
135 | def __call__(self, img, mask):
136 | if random.random() < self.p:
137 | return (img.transpose(Image.FLIP_TOP_BOTTOM), mask.transpose(Image.FLIP_TOP_BOTTOM))
138 | return img, mask
139 |
140 |
141 | class FreeScale(object):
142 | def __init__(self, size):
143 | self.size = tuple(reversed(size)) # size: (h, w)
144 |
145 | def __call__(self, img, mask):
146 | assert img.size == mask.size
147 | return (img.resize(self.size, Image.BILINEAR), mask.resize(self.size, Image.NEAREST))
148 |
149 |
150 | class RandomTranslate(object):
151 | def __init__(self, offset):
152 | # tuple (delta_x, delta_y)
153 | self.offset = offset
154 |
155 | def __call__(self, img, mask):
156 | assert img.size == mask.size
157 | x_offset = int(2 * (random.random() - 0.5) * self.offset[0])
158 | y_offset = int(2 * (random.random() - 0.5) * self.offset[1])
159 |
160 | x_crop_offset = x_offset
161 | y_crop_offset = y_offset
162 | if x_offset < 0:
163 | x_crop_offset = 0
164 | if y_offset < 0:
165 | y_crop_offset = 0
166 |
167 | cropped_img = tf.crop(
168 | img,
169 | y_crop_offset,
170 | x_crop_offset,
171 | img.size[1] - abs(y_offset),
172 | img.size[0] - abs(x_offset),
173 | )
174 |
175 | if x_offset >= 0 and y_offset >= 0:
176 | padding_tuple = (0, 0, x_offset, y_offset)
177 |
178 | elif x_offset >= 0 and y_offset < 0:
179 | padding_tuple = (0, abs(y_offset), x_offset, 0)
180 |
181 | elif x_offset < 0 and y_offset >= 0:
182 | padding_tuple = (abs(x_offset), 0, 0, y_offset)
183 |
184 | elif x_offset < 0 and y_offset < 0:
185 | padding_tuple = (abs(x_offset), abs(y_offset), 0, 0)
186 |
187 | return (
188 | tf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
189 | tf.affine(
190 | mask,
191 | translate=(-x_offset, -y_offset),
192 | scale=1.0,
193 | angle=0.0,
194 | shear=0.0,
195 | fillcolor=250,
196 | ),
197 | )
198 |
199 |
200 | class RandomRotate(object):
201 | def __init__(self, degree):
202 | self.degree = degree
203 |
204 | def __call__(self, img, mask):
205 | rotate_degree = random.random() * 2 * self.degree - self.degree
206 | return (
207 | tf.affine(
208 | img,
209 | translate=(0, 0),
210 | scale=1.0,
211 | angle=rotate_degree,
212 | resample=Image.BILINEAR,
213 | fillcolor=(0, 0, 0),
214 | shear=0.0,
215 | ),
216 | tf.affine(
217 | mask,
218 | translate=(0, 0),
219 | scale=1.0,
220 | angle=rotate_degree,
221 | resample=Image.NEAREST,
222 | fillcolor=250,
223 | shear=0.0,
224 | ),
225 | )
226 |
227 |
228 | class Scale(object):
229 | def __init__(self, size):
230 | self.size = size
231 |
232 | def __call__(self, img, mask):
233 | assert img.size == mask.size
234 | w, h = img.size
235 | if (w >= h and w == self.size) or (h >= w and h == self.size):
236 | return img, mask
237 | if w > h:
238 | ow = self.size
239 | oh = int(self.size * h / w)
240 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))
241 | else:
242 | oh = self.size
243 | ow = int(self.size * w / h)
244 | return (img.resize((ow, oh), Image.BILINEAR), mask.resize((ow, oh), Image.NEAREST))
245 |
246 |
247 | class RandomSizedCrop(object):
248 | def __init__(self, size):
249 | self.size = size
250 |
251 | def __call__(self, img, mask):
252 | assert img.size == mask.size
253 | for attempt in range(10):
254 | area = img.size[0] * img.size[1]
255 | target_area = random.uniform(0.45, 1.0) * area
256 | aspect_ratio = random.uniform(0.5, 2)
257 |
258 | w = int(round(math.sqrt(target_area * aspect_ratio)))
259 | h = int(round(math.sqrt(target_area / aspect_ratio)))
260 |
261 | if random.random() < 0.5:
262 | w, h = h, w
263 |
264 | if w <= img.size[0] and h <= img.size[1]:
265 | x1 = random.randint(0, img.size[0] - w)
266 | y1 = random.randint(0, img.size[1] - h)
267 |
268 | img = img.crop((x1, y1, x1 + w, y1 + h))
269 | mask = mask.crop((x1, y1, x1 + w, y1 + h))
270 | assert img.size == (w, h)
271 |
272 | return (
273 | img.resize((self.size, self.size), Image.BILINEAR),
274 | mask.resize((self.size, self.size), Image.NEAREST),
275 | )
276 |
277 | # Fallback
278 | scale = Scale(self.size)
279 | crop = CenterCrop(self.size)
280 | return crop(*scale(img, mask))
281 |
282 |
283 | class RandomSized(object):
284 | def __init__(self, size):
285 | self.size = size
286 | self.scale = Scale(self.size)
287 | self.crop = RandomCrop(self.size)
288 |
289 | def __call__(self, img, mask):
290 | assert img.size == mask.size
291 |
292 | w = int(random.uniform(0.5, 2) * img.size[0])
293 | h = int(random.uniform(0.5, 2) * img.size[1])
294 |
295 | img, mask = (img.resize((w, h), Image.BILINEAR), mask.resize((w, h), Image.NEAREST))
296 |
297 | return self.crop(*self.scale(img, mask))
298 |
--------------------------------------------------------------------------------
/ptsemseg/loader/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from ptsemseg.loader.airsim_loader import airsimLoader
4 |
5 |
6 | def get_loader(name):
7 | """get_loader
8 |
9 | :param name:
10 | """
11 | return {
12 | "airsim": airsimLoader,
13 | }[name]
14 |
15 |
--------------------------------------------------------------------------------
/ptsemseg/loader/airsim_loader.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | matplotlib.use('Agg')
3 |
4 | import os
5 | import torch
6 | import numpy as np
7 | import glob
8 | import cv2
9 | import copy
10 | from random import shuffle
11 | import random
12 | from torch.utils import data
13 | from ast import literal_eval as make_tuple
14 | import matplotlib.pyplot as plt
15 |
16 |
17 |
18 |
19 | def label_region_n_compute_distance(i,path_tuple):
20 | begin = path_tuple[0]
21 | end = path_tuple[1]
22 |
23 | # compute distance
24 | distance = ((begin[0]-end[0])**2 +(begin[1]-end[1])**2)**(0.5)
25 |
26 |
27 | # label region
28 | if begin[0] <= -400 or end[0]< -400:
29 | region = 'suburban'
30 | else:
31 | if begin[1] >= 300 or end[1] >=300:
32 | region = 'shopping'
33 | else:
34 | region = 'skyscraper'
35 |
36 |
37 | # update tuple
38 | path_tuple = (i,)+path_tuple + (distance, region,)
39 |
40 | return path_tuple
41 |
42 |
43 |
44 |
45 | class airsimLoader(data.Dataset):
46 |
47 | # segmentation decoding colors
48 | name2color = {"person": [[135, 169, 180]],
49 | "sidewalk": [[242, 107, 146]],
50 | "road": [[156,198,23],[43,79,150]],
51 | "sky": [[209,247,202]],
52 | "pole": [[249,79,73],[72,137,21],[45,157,177],[67,266,253],[206,190,59]],
53 | "building": [[161,171,27],[61,212,54],[151,161,26]],
54 | "car": [[153,108,6]],
55 | "bus": [[190,225,64]],
56 | "truck": [[112,105,191]],
57 | "vegetation": [[29,26,199],[234,21,250],[145,71,201],[247,200,111]]
58 | }
59 |
60 |
61 | name2id = {"person": 1,
62 | "sidewalk": 2,
63 | "road": 3,
64 | "sky": 4,
65 | "pole": 5,
66 | "building": 6,
67 | "car": 7,
68 | "bus": 8,
69 | "truck": 9,
70 | "vegetation": 10 }
71 |
72 |
73 | id2name = {i:name for name,i in name2id.items()}
74 |
75 | splits = ['train', 'val', 'test']
76 | image_modes = ['scene', 'segmentation_decoded']
77 |
78 | weathers = ['async_rotate_fog_000_clear']
79 |
80 | # list of edges on the map (needed for loading images from the trajectory folders)
81 | all_edges = [
82 | ((0, 0), (16, -74)),
83 | ((16, -74), (-86, -78)),
84 | ((-86, -78), (-94, -58)),
85 | ((-94, -58), (-94, 24)),
86 | ((-94, 24), (-143, 24)),
87 | ((-143, 24), (-219, 24)),
88 | ((-219, 24), (-219, -68)),
89 | ((-219, -68), (-214, -127)),
90 | ((-214, -127), (-336, -132)),
91 | ((-336, -132), (-335, -180)),
92 | ((-335, -180), (-216, -205)),
93 | ((-216, -205), (-226, -241)),
94 | ((-226, -241), (-240, -252)),
95 | ((-240, -252), (-440, -260)),
96 | ((-440, -260), (-483, -253)),
97 | ((-483, -253), (-494, -223)),
98 | ((-494, -223), (-493, -127)),
99 | ((-493, -127), (-441, -129)),
100 | ((-441, -129), (-443, -222)),
101 | ((-443, -222), (-339, -221)),
102 | ((-339, -221), (-335, -180)),
103 | ((-219, 24), (-248, 24)),
104 | ((-248, 24), (-302, 24)),
105 | ((-302, 24), (-337, 24)),
106 | ((-337, 24), (-593, 25)),
107 | ((-593, 25), (-597, -128)),
108 | ((-597, -128), (-597, -220)),
109 | ((-597, -220), (-748, -222)),
110 | ((-748, -222), (-744, -128)),
111 | ((-744, -128), (-746, 24)),
112 | ((-744, -128), (-597, -128)),
113 | ((-593, 25), (-746, 24)),
114 | ((-746, 24), (-832, 27)),
115 | ((-832, 27), (-804, 176)),
116 | ((-804, 176), (-747, 178)),
117 | ((-747, 178), (-745, 103)),
118 | ((-745, 103), (-696, 104)),
119 | ((-696, 104), (-596, 102)),
120 | ((-596, 102), (-599, 177)),
121 | ((-599, 177), (-747, 178)),
122 | ((-599, 177), (-597, 253)),
123 | ((-596, 102), (-593, 25)),
124 | ((-337, 24), (-338, 172)),
125 | ((-337, 172), (-332, 251)),
126 | ((-337, 172), (-221, 172)),
127 | ((-221, 172), (-221, 264)),
128 | ((-221, 172), (-219, 90)),
129 | ((-219, 90), (-219, 24)),
130 | ((-221, 172), (-148, 172)),
131 | ((-148, 172), (-130, 172)),
132 | ((-130, 172), (-57, 172)),
133 | ((-57, 172), (-57, 194)),
134 | ((20, 192), (20, 92)),
135 | ((20, 92), (21, 76)),
136 | ((21, 76), (66, 22)),
137 | ((66, 22), (123, 28)),
138 | ((123, 28), (123, 106)),
139 | ((123, 106), (123, 135)),
140 | ((123, 135), (176, 135)),
141 | ((176, 135), (176, 179)),
142 | ((176, 179), (210, 180)),
143 | ((210, 180), (210, 107)),
144 | ((210, 107), (216, 26)),
145 | ((216, 26), (118, 21)),
146 | ((118, 21), (118, 2)),
147 | ((118, 2), (100, -62)),
148 | ((100, -62), (89, -70)),
149 | ((89, -70), (62, -76)),
150 | ((62, -76), (28, -76)),
151 | ((28, -76), (16, -74)),
152 | ((16, -74), (14, -17)),
153 | ((-494, -223), (-597, -220)),
154 | ((-597, -128), (-493, -127)),
155 | ((-493, -127), (-493, 25)),
156 | ((-336, -132), (-337, 24)),
157 | ((14, -17), (66, 22)),
158 | ((-597, 253), (-443, 253)),
159 | ((-443, 253), (-332, 251)),
160 | ((-332, 251), (-221, 264)),
161 | ((-221, 264), (-211, 493)),
162 | ((-211, 493), (-129, 493)),
163 | ((-129, 493), (23, 493)),
164 | ((23, 493), (20, 274)),
165 | ((176, 274), (176, 348)),
166 | ((176, 348), (180, 493)),
167 | ((180, 493), (175, 660)),
168 | ((175, 660), (23, 646)),
169 | ((23, 646), (-128, 646)),
170 | ((-128, 646), (-134, 795)),
171 | ((-134, 795), (-130, 871)),
172 | ((-130, 871), (20, 872)),
173 | ((175, 872), (175, 795)),
174 | ((252, 799), (175, 795)),
175 | ((175, 795), (23, 798)),
176 | ((23, 798), (-134, 795)),
177 | ((-134, 795), (-128, 676)),
178 | ((-128, 676), (-129, 493)),
179 | ((23, 493), (23, 646)),
180 | ((23, 646), (23, 798)),
181 | ((23, 798), (20, 872)),
182 | ((-338, 172), (-332, 251)),
183 | ((-57, 255), (20, 255)),
184 | ((-57, 194), (20, 192)),
185 | ((20, 255), (20, 274)),
186 | ((20, 274), (176, 267)),
187 | ((23, 493), (180, 493)),
188 | ((176, 267), (176, 348))]
189 | split_subdirs = {}
190 | ignore_index = 0
191 | mean_rgb = {"airsim": [103.939, 116.779, 123.68],}
192 |
193 | def __init__(
194 | self,
195 | root,
196 | split="train",
197 | subsplit=None,
198 | is_transform=False,
199 | img_size=(512, 512),
200 | augmentations=None,
201 | img_norm=True,
202 | commun_label='None',
203 | version="airsim",
204 | target_view="target"
205 |
206 | ):
207 |
208 | # dataloader parameters
209 | self.dataset_div = self.divide_region_n_train_val_test()
210 | self.split_subdirs = self.generate_image_path(self.dataset_div)
211 | self.commun_label = commun_label
212 | self.root = root
213 | self.split = split
214 | self.is_transform = is_transform
215 | self.augmentations = augmentations
216 | self.img_norm = img_norm
217 | self.n_classes = 11
218 | self.img_size = (img_size if isinstance(img_size, tuple) else (img_size, img_size))
219 | self.mean = np.array(self.mean_rgb[version])
220 |
221 | # Set the target view; first element of list is target view
222 | self.cam_pos = self.get_cam_pos(target_view)
223 |
224 | # load the communication label
225 | if self.commun_label != 'None':
226 | comm_label = self.read_selection_label(self.commun_label)
227 |
228 | # Pre-define the empty list for the images
229 | self.imgs = {s:{c:{image_mode:[] for image_mode in self.image_modes} for c in self.cam_pos} for s in self.splits}
230 | self.com_label = {s:[] for s in self.splits}
231 |
232 | k = 0
233 | for split in self.splits: # [train, val]
234 | for subdir in self.split_subdirs[split]: # [trajectory ]
235 |
236 | file_list = sorted(glob.glob(os.path.join(root, 'scene', 'async_rotate_fog_000_clear',subdir,self.cam_pos[0],'*.png'),recursive=True))
237 |
238 | for file_path in file_list:
239 | ext = file_path.replace(root+"/scene/",'')
240 | file_name = ext.split("/")[-1]
241 | path_dir = ext.split("/")[1]
242 |
243 | # Check that the image file exists in all views and all modalities
244 | list_of_all_cams_n_modal = [os.path.exists(os.path.join(root,modal,'async_rotate_fog_000_clear',path_dir, cam,file_name)) for modal in self.image_modes for cam in self.cam_pos]
245 |
246 | if all(list_of_all_cams_n_modal):
247 | k += 1
248 | # Add the file path to the self.imgs
249 | for comb_modal in self.image_modes:
250 | for comb_cam in self.cam_pos:
251 | file_path = os.path.join(root,comb_modal,'async_rotate_fog_000_clear', path_dir,comb_cam,file_name)
252 | self.imgs[split][comb_cam][comb_modal].append(file_path)
253 |
254 | if self.commun_label != 'None': # Load the communication label
255 | self.com_label[split].append(comm_label[path_dir+'/'+file_name])
256 |
257 | if not self.imgs[self.split][self.cam_pos[0]][self.image_modes[0]]:
258 | raise Exception(
259 | "No files for split=[%s] found in %s" % (self.split, self.root)
260 | )
261 | print("Found %d %s images" % (len(self.imgs[self.split][self.cam_pos[0]][self.image_modes[0]]), self.split))
262 |
263 | # <---- Functions for conversion of paths ---->
264 | def tuple_to_folder_name(self, path_tuple):
265 | start = path_tuple[1]
266 | end = path_tuple[2]
267 | path=str(start[0])+'_'+str(-start[1])+'__'+str(end[0])+'_'+str(-end[1])+'*'
268 | return path
269 | def generate_image_path(self, dataset_div):
270 |
271 | # Merge across regions
272 | train_path_list = []
273 | val_path_list = []
274 | test_path_list = []
275 | for region in ['skyscraper','suburban','shopping']:
276 | for train_one_path in dataset_div['train'][region][1]:
277 | train_path_list.append(self.tuple_to_folder_name(train_one_path))
278 |
279 | for val_one_path in dataset_div['val'][region][1]:
280 | val_path_list.append(self.tuple_to_folder_name(val_one_path))
281 |
282 | for test_one_path in dataset_div['test'][region][1]:
283 | test_path_list.append(self.tuple_to_folder_name(test_one_path))
284 |
285 |
286 | split_subdirs = {}
287 | split_subdirs['train'] = train_path_list
288 | split_subdirs['val'] = val_path_list
289 | split_subdirs['test'] = test_path_list
290 |
291 | return split_subdirs
292 | def divide_region_n_train_val_test(self):
293 |
294 | region_dict = {'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]}
295 | test_ratio = 0.25
296 | val_ratio = 0.25
297 |
298 | dataset_div= {'train':{'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]},
299 | 'val' :{'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]},
300 | 'test' :{'skyscraper':[0,[]],'suburban':[0,[]],'shopping':[0,[]]}}
301 |
302 | process_edges = []
303 | # label and compute distance
304 | for i, path in enumerate(self.all_edges):
305 | process_edges.append(label_region_n_compute_distance(i,path))
306 |
307 | region_dict[process_edges[i][4]][1].append(process_edges[i])
308 | region_dict[process_edges[i][4]][0] = region_dict[process_edges[i][4]][0] + process_edges[i][3]
309 |
310 |
311 | for region_type, distance_and_path_list in region_dict.items():
312 | total_distance = distance_and_path_list[0]
313 | test_distance = total_distance*test_ratio
314 | val_distance = total_distance*val_ratio
315 |
316 | path_list = distance_and_path_list[1]
317 | tem_list = copy.deepcopy(path_list)
318 |
319 | random.seed(2019)
320 | shuffle(tem_list)
321 |
322 | sum_distance = 0
323 |
324 | # Test Set
325 | while sum_distance < test_distance*0.8:
326 | path = tem_list.pop()
327 | sum_distance += path[3]
328 | dataset_div['test'][region_type][0] = dataset_div['test'][region_type][0] + path[3]
329 | dataset_div['test'][region_type][1].append(path)
330 |
331 | # Val Set
332 | while sum_distance < (test_distance + val_distance)*0.8:
333 | path = tem_list.pop()
334 | sum_distance += path[3]
335 | dataset_div['val'][region_type][0] = dataset_div['val'][region_type][0] + path[3]
336 | dataset_div['val'][region_type][1].append(path)
337 |
338 | # Train Set
339 | dataset_div['train'][region_type][0] = total_distance - sum_distance
340 | dataset_div['train'][region_type][1] = tem_list
341 |
342 | color=['red','green','blue']
343 | ## Visualization with respect to region
344 | fig, ax = plt.subplots(figsize=(30, 15))
345 | div_type = 'train'
346 |
347 | vis_txt_height = 800
348 | for div_type in ['train','val','test']:
349 | for region in ['skyscraper','suburban','shopping']:
350 | vis_path_list = dataset_div[div_type][region][1]
351 | for path in vis_path_list:
352 | x = [path[1][0],path[2][0]]
353 | y = [path[1][1],path[2][1]]
354 |
355 | if region == 'skyscraper':
356 | ax.plot(x, y, color='red', zorder=1, lw=3)
357 | elif region == 'suburban':
358 | ax.plot(x, y, color='blue', zorder=1, lw=3)
359 | elif region == 'shopping':
360 | ax.plot(x, y, color='green', zorder=1, lw=3)
361 |
362 | ax.scatter(x, y,color='black', s=120, zorder=2)
363 |
364 | # Visualize distance text
365 | distance = dataset_div[div_type][region][0]
366 | if region == 'skyscraper':
367 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='red')
368 | elif region == 'suburban':
369 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='blue')
370 | elif region == 'shopping':
371 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='green')
372 | vis_txt_height-=30
373 |
374 | plt.savefig('region.png', dpi=200)
375 | plt.close()
376 |
377 | ## Visualization with respect to train/val/test
378 | fig, ax = plt.subplots(figsize=(30, 15))
379 | div_type = 'train'
380 | vis_txt_height = 800
381 | for div_type in ['train','val','test']:
382 | for region in ['skyscraper','suburban','shopping']:
383 | vis_path_list = dataset_div[div_type][region][1]
384 | for path in vis_path_list:
385 | x = [path[1][0],path[2][0]]
386 | y = [path[1][1],path[2][1]]
387 |
388 | if div_type == 'train':
389 | ax.plot(x, y, color='red', zorder=1, lw=3)
390 | elif div_type == 'val':
391 | ax.plot(x, y, color='blue', zorder=1, lw=3)
392 | elif div_type == 'test':
393 | ax.plot(x, y, color='green', zorder=1, lw=3)
394 |
395 | ax.scatter(x, y,color='black', s=120, zorder=2)
396 |
397 | # Visualize distance text
398 | distance = dataset_div[div_type][region][0]
399 | if div_type == 'train':
400 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='red')
401 | elif div_type == 'val':
402 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='blue')
403 | elif div_type == 'test':
404 | ax.annotate(div_type+' - '+ region+': '+str(distance), (-800, vis_txt_height),fontsize=20,color='green')
405 | vis_txt_height-=30
406 |
407 | #ax.annotate(txt, (x, y))
408 | plt.savefig('train_val_test.png', dpi=200)
409 | plt.close()
410 |
411 | return dataset_div
412 | def read_selection_label(self, label_type):
413 |
414 | if label_type == 'when2com':
415 | with open(os.path.join(self.root,'gt_when_to_communicate.txt')) as f:
416 | content = f.readlines()
417 |
418 | com_label = {}
419 | for x in content:
420 | key = x.split(' ')[2].strip().split('/')[-3] + '/' + x.split(' ')[2].strip().split('/')[-1]+'.png'
421 | com_label[key] = int(x.split(' ')[1])
422 |
423 | elif label_type == 'mimo':
424 | with open(os.path.join(self.root,'gt_mimo_communicate.txt')) as f:
425 | content = f.readlines()
426 | com_label = {}
427 | for x in content:
428 | file_key = x.split(' ')[-1].strip().split('/')[-3] + '/' + x.split(' ')[-1].strip().split('/')[-1]+'.png'
429 | noise_label = make_tuple(x.split(' (')[0])
430 | link_label = make_tuple(x.split(') ')[1] + ')')
431 | com_label[file_key] = torch.tensor([noise_label, link_label])
432 |
433 | else:
434 | raise ValueError('Unknown label file name '+ str(label_type))
435 |
436 |
437 | print('Loaded: selection label.')
438 | return com_label
439 |
440 | def convert_link_label(self, link_label):
441 | div_list = []
442 | for i in link_label:
443 | div_list.append(int(i/2))
444 |
445 | new_link_label = []
446 | for i, elem_i in enumerate(div_list):
447 | for j, elem_j in enumerate(div_list):
448 | if j != i and elem_i == elem_j:
449 | new_link_label.append(j)
450 | new_link_label = tuple(new_link_label)
451 | return new_link_label
452 | def get_cam_pos(self, target_view):
453 | if target_view == "overhead":
454 | cam_pos = [ 'overhead', 'front', 'back', 'left', 'right']
455 | elif target_view == "front":
456 | cam_pos = [ 'front', 'back', 'left', 'right','overhead']
457 | elif target_view == "back":
458 | cam_pos = [ 'back', 'front', 'left', 'right','overhead']
459 | elif target_view == "left":
460 | cam_pos = [ 'left', 'back', 'front','right','overhead']
461 | elif target_view == "target":
462 | cam_pos = [ 'target', 'normal1', 'normal2','normal3','normal4']
463 | elif target_view == "6agent":
464 | cam_pos = ['agent1', 'agent2', 'agent3', 'agent4', 'agent5', 'agent6']
465 | elif target_view == "5agent":
466 | cam_pos = ['agent1', 'agent2', 'agent3', 'agent4', 'agent5']
467 | elif target_view == "DroneNP":
468 | cam_pos = ["DroneNN_main", "DroneNP_main", "DronePN_main", "DronePP_main", "DroneZZ_main"]
469 | elif target_view == "DroneNN_backNN":
470 | cam_pos = ["DroneNN_backNN", "DroneNP_backNP", "DronePN_backPN", "DroneNN_frontNN", "DroneNP_frontNP"]
471 | elif target_view == "5agentv7":
472 | cam_pos = ["agent1", "agent3", "agent5", "agent2", "agent4"]
473 | else:
474 | cam_pos = [ 'front', 'back', 'left', 'right', 'overhead']
475 | return cam_pos
476 | # <---- Functions for conversion of paths ---->
477 |
478 |
479 |
480 |
481 | def __len__(self):
482 | """__len__"""
483 | return len(self.imgs[self.split][self.cam_pos[0]][self.image_modes[0]])
484 |
485 | def __getitem__(self, index):
486 | """__getitem__
487 |
488 | :param index:
489 | """
490 | img_list = []
491 | lbl_list = []
492 |
493 | for k, camera in enumerate(self.cam_pos):
494 |
495 | img_path, mask_path = self.imgs[self.split][camera]['scene'][index], self.imgs[self.split][camera]['segmentation_decoded'][index]
496 | img, mask = np.array(cv2.imread(img_path),dtype=np.uint8)[:,:,:3], np.array(cv2.imread(mask_path),dtype=np.uint8)[:,:,0]
497 |
498 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
499 | lbl = mask
500 | if self.augmentations is not None:
501 | img, lbl = self.augmentations(img, lbl)  # Compose returns (img, mask)
502 |
503 | if self.is_transform:
504 | img, lbl = self.transform(img, lbl)
505 |
506 | img_list.append(img)
507 | lbl_list.append(lbl)
508 |
509 | if self.commun_label != 'None':
510 | return img_list, lbl_list, self.com_label[self.split][index] #, self.debug_file_path[self.split][index]
511 | else:
512 | return img_list, lbl_list
513 |
514 |
515 | def transform(self, img, lbl):
516 |
517 | """transform
518 | :param img:
519 | :param lbl:
520 | """
521 | img = img[:, :, ::-1] # RGB -> BGR
522 | img = img.astype(np.float64)
523 | img -= self.mean
524 | if self.img_norm:
525 | img = img.astype(float) / 255.0
526 | # NHWC -> NCHW
527 | img = img.transpose(2, 0, 1)
528 |
529 | classes = np.unique(lbl)
530 | lbl = lbl.astype(float)
531 | lbl = lbl.astype(int)
532 |
533 | if not np.all(np.unique(lbl[lbl != self.ignore_index]) < self.n_classes):
534 | print("after det", classes, np.unique(lbl))
535 | raise ValueError("Segmentation map contained invalid class values")
536 |
537 | img = torch.from_numpy(img).float()
538 | lbl = torch.from_numpy(lbl).long()
539 |
540 | return img, lbl
541 |
542 | def decode_segmap(self, temp):
543 | r = temp.copy()
544 | g = temp.copy()
545 | b = temp.copy()
546 | for i,name in self.id2name.items():
547 | r[(temp==i)] = self.name2color[name][0][0]
548 | g[(temp==i)] = self.name2color[name][0][1]
549 | b[(temp==i)] = self.name2color[name][0][2]
550 |
551 | rgb = np.zeros((temp.shape[0], temp.shape[1], 3))
552 | rgb[:, :, 0] = r / 255.0
553 | rgb[:, :, 1] = g / 255.0
554 | rgb[:, :, 2] = b / 255.0
555 | return rgb
556 |
557 |
558 | def save_tensor_imag(imgs):
559 | import numpy as np
560 | import cv2
561 | bs = imgs[0].shape[0]
562 | mean_rgb = np.array([103.939, 116.779, 123.68])
563 |
564 | for view in range(len(imgs)):
565 | for i in range(bs):
566 | image = imgs[view][i]
567 |
568 | image = image.cpu().numpy()
569 | image = np.transpose(image, (1, 2, 0))
570 | image = image * 255 + mean_rgb
571 | cv2.imwrite('debug_tmp/img_b' + str(i) +'_v'+str(view)+ '.png', image)
572 |
573 |
574 |
--------------------------------------------------------------------------------
/ptsemseg/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import functools
3 |
4 | from ptsemseg.loss.loss import (
5 | cross_entropy2d,
6 | bootstrapped_cross_entropy2d,
7 | multi_scale_cross_entropy2d
8 | )
9 | ## Access to loss functions
10 |
11 | logger = logging.getLogger("ptsemseg")
12 |
13 | key2loss = {
14 | "cross_entropy": cross_entropy2d,
15 | "bootstrapped_cross_entropy": bootstrapped_cross_entropy2d,
16 | "multi_scale_cross_entropy": multi_scale_cross_entropy2d
17 | }
18 |
19 |
20 | def get_loss_function(cfg):
21 | if cfg["training"]["loss"] is None:
22 | logger.info("Using default cross entropy loss")
23 | return cross_entropy2d
24 |
25 | else:
26 | loss_dict = cfg["training"]["loss"]
27 | loss_name = loss_dict["name"]
28 | loss_params = {k: v for k, v in loss_dict.items() if k != "name"}
29 |
30 | if loss_name not in key2loss:
31 | raise NotImplementedError("Loss {} not implemented".format(loss_name))
32 |
33 | logger.info("Using {} with {} params".format(loss_name, loss_params))
34 | return functools.partial(key2loss[loss_name], **loss_params)
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/ptsemseg/loss/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import pdb
4 |
5 | def cross_entropy2d(input, target, weight=None, size_average=True):
6 | n, c, h, w = input.size()
7 | nt, ht, wt = target.size()
8 |
9 | # Handle inconsistent size between input and target
10 | if h != ht and w != wt: # upsample labels
11 | input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)
12 |
13 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
14 | target = target.view(-1)
15 | loss = F.cross_entropy(
16 | input, target, weight=weight, size_average=size_average, ignore_index=250
17 | )
18 | return loss
19 |
20 |
21 | def multi_scale_cross_entropy2d(input, target, weight=None, size_average=True, scale_weight=None):
22 | if not isinstance(input, tuple):
23 | return cross_entropy2d(input=input, target=target, weight=weight, size_average=size_average)
24 |
25 | # Auxiliary training for PSPNet [1.0, 0.4] and ICNet [1.0, 0.4, 0.16]
26 | if scale_weight is None: # scale_weight: torch tensor type
27 | n_inp = len(input)
28 | scale = 0.4
29 | scale_weight = torch.pow(scale * torch.ones(n_inp), torch.arange(n_inp).float()).to('cuda' if target.is_cuda else 'cpu')
30 |
31 | loss = 0.0
32 | for i, inp in enumerate(input):
33 | loss = loss + scale_weight[i] * cross_entropy2d(
34 | input=inp, target=target, weight=weight, size_average=size_average
35 | )
36 |
37 | return loss
38 |
39 |
40 | def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True):
41 |
42 | batch_size = input.size()[0]
43 |
44 | def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True):
45 |
46 | n, c, h, w = input.size()
47 | input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
48 | target = target.view(-1)
49 | loss = F.cross_entropy(
50 | input, target, weight=weight, reduce=False, size_average=False, ignore_index=250
51 | )
52 |
53 | topk_loss, _ = loss.topk(K)
54 | reduced_topk_loss = topk_loss.sum() / K
55 |
56 | return reduced_topk_loss
57 |
58 | loss = 0.0
59 | # Bootstrap from each image not entire batch
60 | for i in range(batch_size):
61 | loss += _bootstrap_xentropy_single(
62 | input=torch.unsqueeze(input[i], 0),
63 | target=torch.unsqueeze(target[i], 0),
64 | K=K,
65 | weight=weight,
66 | size_average=size_average,
67 | )
68 | return loss / float(batch_size)
69 |
--------------------------------------------------------------------------------
/ptsemseg/metrics.py:
--------------------------------------------------------------------------------
1 | # Adapted from score written by wkentaro
2 | # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
3 |
4 | import numpy as np
5 | import torch
6 |
7 | class runningScore(object):
8 | def __init__(self, n_classes):
9 | self.n_classes = n_classes
10 | self.confusion_matrix = np.zeros((n_classes, n_classes))
11 | self.confusion_matrix_pos = np.zeros((n_classes, n_classes))
12 | self.confusion_matrix_neg = np.zeros((n_classes, n_classes))
13 | self.total_agent = 0
14 | self.correct_when2com = 0
15 | self.correct_who2com = 0
16 | self.total_bandW = 0
17 | self.count = 0
18 |
19 | def update_bandW(self, bandW):
20 | self.total_bandW += bandW
21 | self.count += 1.0
22 |
23 | def update_selection(self, if_commun_label, commun_label, action_argmax):
24 | if if_commun_label == 'when2com':
25 | action_argmax = torch.squeeze(action_argmax)
26 | commun_label = commun_label + 1 # -1,0,1,2,3 ->0, 1, 2 ,3 ,4
27 |
28 |
29 | self.total_agent += commun_label.size(0)
30 | when_to_commu_label = (commun_label == 0)
31 |
32 | if action_argmax.dim() == 2:
33 | predict_link = (action_argmax > 0.2).nonzero()
34 | link_num = predict_link.shape[0]
35 | when2com_pred = torch.zeros(commun_label.size(0), dtype=torch.int8)
36 |
37 | for row_idx in range(link_num):
38 | sample_idx = predict_link[row_idx,:][0]
39 | link_idx = predict_link[row_idx,:][1]
40 | if link_idx == commun_label[sample_idx]:
41 | self.correct_who2com = self.correct_who2com +1
42 | if link_idx != 0:
43 | when2com_pred[sample_idx] = True
44 | when2com_pred = when2com_pred.cuda()
45 | self.correct_when2com += (when2com_pred == when_to_commu_label).sum().item()
46 | elif action_argmax.dim() == 1:
47 | # Learn when to communicate accuracy
48 | when_to_commu_pred = (action_argmax == 0)
49 | self.correct_when2com += (when_to_commu_pred == when_to_commu_label).sum().item()
50 |
51 | # Learn who to communicate accuracy
52 | self.correct_who2com += (action_argmax == commun_label).sum().item()
53 | else:
54 | assert commun_label.shape == action_argmax.shape, "Shape of selection labels are different."
55 | elif if_commun_label == 'mimo':
56 |
57 | # commun_label = commun_label.cpu()
58 | self.total_agent += commun_label[:,0,:].shape[0]*commun_label[:,0,:].shape[1]
59 | # when2com
60 | id_tensor = torch.arange(action_argmax.shape[1]).repeat(action_argmax.shape[0], 1)
61 | when_to_commu_pred = (action_argmax.cpu() != id_tensor)
62 | when_to_commu_label = commun_label[:,0,:].type(torch.ByteTensor)
63 | self.correct_when2com += (when_to_commu_pred == when_to_commu_label).sum().item()
64 |
65 | # who2com (gpu) who *(need com)
66 | gt_action = commun_label[:,1,:] * commun_label[:,0,:] + id_tensor.cuda()*(1 - commun_label[:,0,:])
67 |
68 | self.correct_who2com += (action_argmax == gt_action).sum().item()
69 |
70 | def update_div(self, if_commun_label, label_trues, label_preds, commun_label):
71 | # import pdb;pdb.set_trace()
72 | if if_commun_label == 'when2com':
73 | commun_label = commun_label.cpu().numpy()
74 | when2comlab = (commun_label == -1)  # -1 --> normal; others --> noisy, need communication
75 | elif if_commun_label == 'mimo':
76 | commun_label = commun_label.cpu().numpy()[:, 0, :]
77 | when2comlab = (commun_label == 0)  # 0 --> normal; 1 --> noisy, needs communication
78 | when2comlab = when2comlab.transpose(1, 0) #[batch of agent1, batch of agent2 ]
79 | when2comlab = when2comlab.flatten()
80 |
81 | # import pdb; pdb.set_trace()
82 | pos_idx = (when2comlab == True).nonzero()
83 | neg_idx = (when2comlab == False).nonzero()
84 |
85 | label_trues_pos = label_trues[pos_idx]
86 | label_preds_pos = label_preds[pos_idx]
87 | if label_trues_pos.shape[0] != 0 :
88 | for lt, lp in zip(label_trues_pos, label_preds_pos):
89 | self.confusion_matrix_pos += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
90 |
91 | label_trues_neg = label_trues[neg_idx]
92 | label_preds_neg = label_preds[neg_idx]
93 | if label_trues_neg.shape[0] != 0:
94 | for lt, lp in zip(label_trues_neg, label_preds_neg):
95 | self.confusion_matrix_neg += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
96 |
97 |
98 |
99 | def _fast_hist(self, label_true, label_pred, n_class):
100 | mask = (label_true >= 0) & (label_true < n_class)
101 | hist = np.bincount(
102 | n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
103 | ).reshape(n_class, n_class)
104 | return hist
105 |
106 | def update(self, label_trues, label_preds):
107 | for lt, lp in zip(label_trues, label_preds):
108 | self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
109 |
110 | def get_avg_bandW(self):
111 | return self.total_bandW/self.count
112 |
113 | def get_only_noise_scores(self):
114 | """Returns accuracy score evaluation result.
115 | - overall accuracy
116 | - mean accuracy
117 | - mean IU
118 | - fwavacc
119 | """
120 | hist = self.confusion_matrix_neg
121 | acc = np.diag(hist).sum() / hist.sum()
122 | acc_cls = np.diag(hist) / hist.sum(axis=1)
123 | acc_cls = np.nanmean(acc_cls)
124 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
125 | mean_iu = np.nanmean(iu)
126 | freq = hist.sum(axis=1) / hist.sum()
127 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
128 | cls_iu = dict(zip(range(self.n_classes), iu))
129 |
130 | return (
131 | {
132 | "Overall Acc: \t": acc,
133 | "Mean Acc : \t": acc_cls,
134 | "FreqW Acc : \t": fwavacc,
135 | "Mean IoU : \t": mean_iu,
136 | },
137 | cls_iu,
138 | )
139 |
140 |
141 | def get_only_normal_scores(self):
142 | """Returns accuracy score evaluation result.
143 | - overall accuracy
144 | - mean accuracy
145 | - mean IU
146 | - fwavacc
147 | """
148 | hist = self.confusion_matrix_pos
149 | acc = np.diag(hist).sum() / hist.sum()
150 | acc_cls = np.diag(hist) / hist.sum(axis=1)
151 | acc_cls = np.nanmean(acc_cls)
152 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
153 | mean_iu = np.nanmean(iu)
154 | freq = hist.sum(axis=1) / hist.sum()
155 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
156 | cls_iu = dict(zip(range(self.n_classes), iu))
157 |
158 | return (
159 | {
160 | "Overall Acc: \t": acc,
161 | "Mean Acc : \t": acc_cls,
162 | "FreqW Acc : \t": fwavacc,
163 | "Mean IoU : \t": mean_iu,
164 | },
165 | cls_iu,
166 | )
167 |
168 | def get_scores(self):
169 | """Returns accuracy score evaluation result.
170 | - overall accuracy
171 | - mean accuracy
172 | - mean IU
173 | - fwavacc
174 | """
175 | hist = self.confusion_matrix
176 | acc = np.diag(hist).sum() / hist.sum()
177 | acc_cls = np.diag(hist) / hist.sum(axis=1)
178 | acc_cls = np.nanmean(acc_cls)
179 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
180 | mean_iu = np.nanmean(iu)
181 | freq = hist.sum(axis=1) / hist.sum()
182 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
183 | cls_iu = dict(zip(range(self.n_classes), iu))
184 |
185 | return (
186 | {
187 | "Overall Acc: \t": acc,
188 | "Mean Acc : \t": acc_cls,
189 | "FreqW Acc : \t": fwavacc,
190 | "Mean IoU : \t": mean_iu,
191 | },
192 | cls_iu,
193 | )
194 |
195 |
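# when2com accuracy: fraction of agent decisions where "request help or not" matches the
# noise label; who2com accuracy: fraction where the selected supporting agent matches the
# ground-truth link. Both are normalized by self.total_agent, the running count of
# evaluated agent decisions.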
196 | def get_selection_accuracy(self):
197 | when_com_accuracy = self.correct_when2com / self.total_agent * 100
198 | who_com_accuracy = self.correct_who2com / self.total_agent * 100
199 |
200 | return when_com_accuracy, who_com_accuracy
201 |
202 |
203 |
204 |
205 | def reset(self):
206 | self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
207 | self.total_agent = 0
208 | self.correct_when2com = 0
209 | self.correct_who2com = 0
210 | self.total_bandW = 0
211 | self.count = 0
212 |
213 |
214 | def print_score(self,n_classes, score, class_iou):
215 | metric_string = ""
216 | class_string = ""
217 |
218 | for i in range(n_classes):
219 | # print(i, class_iou[i])
220 | metric_string = metric_string + " " + str(i)
221 | class_string = class_string + " " + str(round(class_iou[i] * 100, 2))
222 |
223 | for k, v in score.items():
224 | metric_string = metric_string + " " + str(k)
225 | class_string = class_string + " " + str(round(v * 100, 2))
226 | # print(k, v)
227 | print(metric_string)
228 | print(class_string)
229 |
230 |
231 | class averageMeter(object):
232 | """Computes and stores the average and current value"""
233 |
234 | def __init__(self):
235 | self.reset()
236 |
237 | def reset(self):
238 | self.val = 0
239 | self.avg = 0
240 | self.sum = 0
241 | self.count = 0
242 |
243 | def update(self, val, n=1):
244 | self.val = val
245 | self.sum += val * n
246 | self.count += n
247 | self.avg = self.sum / self.count
248 |
--------------------------------------------------------------------------------
/ptsemseg/models/__init__.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torchvision.models as models
3 |
4 |
5 | from ptsemseg.models.agent import Single_agent, All_agents, LearnWho2Com, LearnWhen2Com, MIMOcom,MIMO_All_agents, MIMOcomWho
6 |
7 |
8 | def get_model(model_dict, n_classes, version=None):
9 | name = model_dict["model"]["arch"]
10 |
11 | model = _get_model_instance(name)
12 | in_channels = 3
13 | if name == "Single_agent":
14 | model = model(n_classes=n_classes, in_channels=in_channels,
15 | enc_backbone=model_dict["model"]['enc_backbone'],
16 | dec_backbone=model_dict["model"]['dec_backbone'],
17 | feat_squeezer=model_dict["model"]['feat_squeezer'],
18 | feat_channel=model_dict["model"]['feat_channel'])
19 | elif name == "All_agents":
20 | model = model(n_classes=n_classes, in_channels=in_channels,
21 | aux_agent_num=model_dict["model"]['agent_num'],
22 | shuffle_flag=model_dict["model"]['shuffle_features'],
23 | enc_backbone=model_dict["model"]['enc_backbone'],
24 | dec_backbone=model_dict["model"]['dec_backbone'],
25 | feat_squeezer=model_dict["model"]['feat_squeezer'],
26 | feat_channel=model_dict["model"]['feat_channel'])
27 | elif name == "MIMO_All_agents":
28 | model = model(n_classes=n_classes, in_channels=in_channels,
29 | aux_agent_num=model_dict["model"]['agent_num'],
30 | shuffle_flag=model_dict["model"]['shuffle_features'],
31 | enc_backbone=model_dict["model"]['enc_backbone'],
32 | dec_backbone=model_dict["model"]['dec_backbone'],
33 | feat_squeezer=model_dict["model"]['feat_squeezer'],
34 | feat_channel=model_dict["model"]['feat_channel'])
35 |
36 | elif name == "LearnWho2Com":
37 | model = model(n_classes=n_classes, in_channels=in_channels,
38 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'],
39 | sparse=model_dict["model"]['sparse'],
40 | aux_agent_num=model_dict["model"]['agent_num'],
41 | shared_img_encoder=model_dict["model"]["shared_img_encoder"],
42 | image_size=model_dict["data"]["img_rows"],
43 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"],
44 | enc_backbone=model_dict["model"]['enc_backbone'],
45 | dec_backbone=model_dict["model"]['dec_backbone']
46 | )
47 | elif name == "LearnWhen2Com":
48 | model = model(n_classes=n_classes, in_channels=in_channels,
49 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'],
50 | sparse=model_dict["model"]['sparse'],
51 | aux_agent_num=model_dict["model"]['agent_num'],
52 | shared_img_encoder=model_dict["model"]["shared_img_encoder"],
53 | image_size=model_dict["data"]["img_rows"],
54 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"],
55 | enc_backbone=model_dict["model"]['enc_backbone'],
56 | dec_backbone=model_dict["model"]['dec_backbone']
57 | )
58 | elif name == "MIMOcom":
59 | model = model(n_classes=n_classes, in_channels=in_channels,
60 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'],
61 | sparse=model_dict["model"]['sparse'],
62 | agent_num=model_dict["model"]['agent_num'],
63 | shared_img_encoder=model_dict["model"]["shared_img_encoder"],
64 | image_size=model_dict["data"]["img_rows"],
65 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"],
66 | enc_backbone=model_dict["model"]['enc_backbone'],
67 | dec_backbone=model_dict["model"]['dec_backbone']
68 | )
69 | elif name == "MIMOcomWho":
70 | model = model(n_classes=n_classes, in_channels=in_channels,
71 | attention=model_dict["model"]['attention'],has_query=model_dict["model"]['query'],
72 | sparse=model_dict["model"]['sparse'],
73 | agent_num=model_dict["model"]['agent_num'],
74 | shared_img_encoder=model_dict["model"]["shared_img_encoder"],
75 | image_size=model_dict["data"]["img_rows"],
76 | query_size=model_dict["model"]["query_size"],key_size=model_dict["model"]["key_size"],
77 | enc_backbone=model_dict["model"]['enc_backbone'],
78 | dec_backbone=model_dict["model"]['dec_backbone']
79 | )
80 |
81 | else:
82 | model = model(n_classes=n_classes, in_channels=in_channels,
83 | enc_backbone=model_dict["model"]['enc_backbone'],
84 | dec_backbone=model_dict["model"]['dec_backbone'])
85 |
86 | return model
87 |
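# Usage sketch (assumptions: `yaml` is available; `cfg_path` and `n_classes` are supplied by
# the caller, e.g. the training script, and the YAML file carries the "model"/"data" keys read above):
#
#   import yaml
#   with open(cfg_path) as fp:
#       cfg = yaml.safe_load(fp)
#   model = get_model(cfg, n_classes)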
88 |
89 | def _get_model_instance(name):
90 | try:
91 | return {
92 | "Single_agent": Single_agent,
93 | "All_agents": All_agents,
94 | "MIMO_All_agents": MIMO_All_agents,
95 | 'LearnWho2Com':LearnWho2Com,
96 | 'LearnWhen2Com': LearnWhen2Com,
97 | 'MIMOcom': MIMOcom,
98 | 'MIMOcomWho': MIMOcomWho,
99 | }[name]
100 | except KeyError:
101 | raise ValueError("Model {} not available".format(name))
102 |
--------------------------------------------------------------------------------
/ptsemseg/models/agent.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torchvision.models as models
4 | import torch.nn.functional as F
5 | import pretrainedmodels
6 | import copy
7 |
8 | from ptsemseg.models.utils import conv2DBatchNormRelu, deconv2DBatchNormRelu, Sparsemax
9 | import random
10 | from torch.distributions.categorical import Categorical
11 |
12 | from ptsemseg.models.backbone import n_segnet_encoder, resnet_encoder, n_segnet_decoder, simple_decoder, FCN_decoder
13 | import numpy as np
14 |
15 |
16 | def get_encoder(name):
17 | try:
18 | return {
19 | "n_segnet_encoder": n_segnet_encoder,
20 | "resnet_encoder": resnet_encoder,
21 | }[name]
22 | except KeyError:
23 | raise ValueError("Encoder {} not available".format(name))
24 |
25 |
26 | def get_decoder(name):
27 | try:
28 | return {
29 | "n_segnet_decoder": n_segnet_decoder,
30 | "simple_decoder": simple_decoder,
31 | "FCN_decoder": FCN_decoder
32 |
33 | }[name]
34 | except KeyError:
35 | raise ValueError("Decoder {} not available".format(name))
36 |
37 |
38 | ### ============= Modules ============= ###
39 | class img_encoder(nn.Module):
40 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1,
41 | enc_backbone='n_segnet_encoder'):
42 | super(img_encoder, self).__init__()
43 | feat_chn = 256
44 |
45 | self.feature_backbone = get_encoder(enc_backbone)(n_classes=n_classes, in_channels=in_channels)
46 | self.feat_squeezer = feat_squeezer
47 |
48 | # squeeze the feature map size
49 | if feat_squeezer == 2: # resolution/2
50 | self.squeezer = conv2DBatchNormRelu(512, feat_channel, k_size=3, stride=2, padding=1)
51 | elif feat_squeezer == 4: # resolution/4
52 | self.squeezer = conv2DBatchNormRelu(512, feat_channel, k_size=3, stride=4, padding=1)
53 | else:
54 | self.squeezer = conv2DBatchNormRelu(512, feat_channel, k_size=3, stride=1, padding=1)
55 |
56 | def forward(self, inputs):
57 | outputs = self.feature_backbone(inputs)
58 | outputs = self.squeezer(outputs)
59 |
60 | return outputs
61 |
62 |
63 | class img_decoder(nn.Module):
64 | def __init__(self, n_classes=21, in_channels=512, agent_num=5, feat_squeezer=-1, dec_backbone='n_segnet_decoder'):
65 | super(img_decoder, self).__init__()
66 |
67 | self.feat_squeezer = feat_squeezer
68 | if feat_squeezer == 2: # resolution/2
69 | self.desqueezer = deconv2DBatchNormRelu(in_channels, in_channels, k_size=3, stride=2, padding=1,
70 | output_padding=1)
71 | self.output_decoder = get_decoder(dec_backbone)(n_classes=n_classes, in_channels=in_channels)
72 |
73 | elif feat_squeezer == 4: # resolution/4
74 | self.desqueezer1 = deconv2DBatchNormRelu(in_channels, 512, k_size=3, stride=2, padding=1, output_padding=1)
75 | self.desqueezer2 = deconv2DBatchNormRelu(512, 512, k_size=3, stride=2, padding=1, output_padding=1)
76 | self.output_decoder = get_decoder(dec_backbone)(n_classes=n_classes, in_channels=512)
77 | else:
78 | self.output_decoder = get_decoder(dec_backbone)(n_classes=n_classes, in_channels=in_channels)
79 |
80 | def forward(self, inputs):
81 | if self.feat_squeezer == 2: # resolution/2
82 | inputs = self.desqueezer(inputs)
83 |
84 | elif self.feat_squeezer == 4: # resolution/4
85 | inputs = self.desqueezer1(inputs)
86 | inputs = self.desqueezer2(inputs)
87 |
88 | outputs = self.output_decoder(inputs)
89 | return outputs
90 |
91 |
92 | class msg_generator(nn.Module):
93 | def __init__(self, in_channels=512, message_size=32):
94 | super(msg_generator, self).__init__()
95 | self.in_channels = in_channels
96 |
97 | # Encoder
98 | # down 1
99 | self.conv1 = conv2DBatchNormRelu(self.in_channels, 256, k_size=3, stride=1, padding=1)
100 | self.conv2 = conv2DBatchNormRelu(256, 128, k_size=3, stride=1, padding=1)
101 | self.conv3 = conv2DBatchNormRelu(128, 64, k_size=3, stride=1, padding=1)
102 | self.conv4 = conv2DBatchNormRelu(64, 64, k_size=3, stride=1, padding=1)
103 | self.conv5 = conv2DBatchNormRelu(64, message_size, k_size=3, stride=1, padding=1)
104 |
105 | def forward(self, inputs):
106 | outputs = self.conv1(inputs)
107 | outputs = self.conv2(outputs)
108 | outputs = self.conv3(outputs)
109 | outputs = self.conv4(outputs)
110 | outputs = self.conv5(outputs)
111 | return outputs
112 |
113 |
114 | class policy_net4(nn.Module):
115 | def __init__(self, n_classes=21, in_channels=512, input_feat_sz=32, enc_backbone='n_segnet_encoder'):
116 | super(policy_net4, self).__init__()
117 | self.in_channels = in_channels
118 |
119 | feat_map_sz = input_feat_sz // 4
120 | self.n_feat = int(256 * feat_map_sz * feat_map_sz)
121 |
122 | self.img_encoder = img_encoder(n_classes=n_classes, in_channels=self.in_channels, enc_backbone=enc_backbone)
123 |
124 | # Encoder
125 | # down 1
126 | self.conv1 = conv2DBatchNormRelu(512, 512, k_size=3, stride=1, padding=1)
127 | self.conv2 = conv2DBatchNormRelu(512, 256, k_size=3, stride=1, padding=1)
128 | self.conv3 = conv2DBatchNormRelu(256, 256, k_size=3, stride=2, padding=1)
129 |
130 | # down 2
131 | self.conv4 = conv2DBatchNormRelu(256, 256, k_size=3, stride=1, padding=1)
132 | self.conv5 = conv2DBatchNormRelu(256, 256, k_size=3, stride=2, padding=1)
133 |
134 | def forward(self, features_map):
135 | outputs1 = self.img_encoder(features_map)
136 |
137 | outputs = self.conv1(outputs1)
138 | outputs = self.conv2(outputs)
139 | outputs = self.conv3(outputs)
140 | outputs = self.conv4(outputs)
141 | outputs = self.conv5(outputs)
142 | return outputs
143 |
144 |
145 | class km_generator(nn.Module):
146 | def __init__(self, out_size=128, input_feat_sz=32):
147 | super(km_generator, self).__init__()
148 | feat_map_sz = input_feat_sz // 4
149 | self.n_feat = int(256 * feat_map_sz * feat_map_sz)
150 | self.fc = nn.Sequential(
151 | nn.Linear(self.n_feat, 256), #
152 | nn.ReLU(inplace=True),
153 | nn.Linear(256, 128), #
154 | nn.ReLU(inplace=True),
155 | nn.Linear(128, out_size)) #
156 |
157 | def forward(self, features_map):
158 | outputs = self.fc(features_map.view(-1, self.n_feat))
159 | return outputs
160 |
161 |
162 | class linear(nn.Module):
163 | def __init__(self, out_size=128, input_feat_sz=32):
164 | super(linear, self).__init__()
165 | feat_map_sz = input_feat_sz // 4
166 | self.n_feat = int(256 * feat_map_sz * feat_map_sz)
167 |
168 | self.fc = nn.Sequential(
169 | nn.Linear(self.n_feat, 256),
170 | nn.ReLU(inplace=True),
171 | nn.Linear(256, 128),
172 | nn.ReLU(inplace=True),
173 | nn.Linear(128, out_size)
174 | )
175 |
176 | def forward(self, features_map):
177 | outputs = self.fc(features_map.view(-1, self.n_feat))
178 | return outputs
179 |
180 |
181 | class conv(nn.Module):
182 | def __init__(self, out_size=128, input_feat_sz=32):
183 | super(conv, self).__init__()
184 | feat_map_sz = input_feat_sz // 4
185 | self.conv = conv2DBatchNormRelu(256, out_size, k_size=1, stride=1, padding=1)
186 |
187 | def forward(self, features_map):
188 | outputs = self.conv(features_map)
189 | return outputs
190 |
191 |
192 |
193 | # <------ Attention ------> #
194 | class ScaledDotProductAttention(nn.Module):
195 | ''' Scaled Dot-Product Attention '''
196 |
197 | def __init__(self, temperature, attn_dropout=0.1):
198 | super().__init__()
199 | self.temperature = temperature
200 | self.sparsemax = Sparsemax(dim=1)
201 | self.softmax = nn.Softmax(dim=1)
202 |
203 | def forward(self, q, k, v, sparse=True):
204 | attn_orig = torch.bmm(k, q.transpose(2, 1))
205 | attn_orig = attn_orig / self.temperature
206 | if sparse:
207 | attn_orig = self.sparsemax(attn_orig)
208 | else:
209 | attn_orig = self.softmax(attn_orig)
210 | attn = torch.unsqueeze(torch.unsqueeze(attn_orig, 3), 4)
211 | output = attn * v # (batch,4,channel,size,size)
212 | output = output.sum(1) # (batch,1,channel,size,size)
213 | return output, attn_orig.transpose(2, 1)
214 |
215 | class AdditiveAttentin(nn.Module):
216 | def __init__(self):
217 | super().__init__()
218 | # self.dropout = nn.Dropout(attn_dropout)
219 | self.softmax = nn.Softmax(dim=1)
220 | self.sparsemax = Sparsemax(dim=1)
221 | self.linear_feat = nn.Linear(128, 128)
222 | self.linear_context = nn.Linear(128, 128)
223 | self.linear_out = nn.Linear(128, 1)
224 |
225 | def forward(self, q, k, v, sparse=True):
226 | # q (batch,1,128)
227 | # k (batch,4,128)
228 | # v (batch,4,channel,size,size)
229 | temp1 = self.linear_feat(k) # (batch,4,128)
230 | temp2 = self.linear_context(q) # (batch,1,128)
231 | attn_orig = self.linear_out(temp1 + temp2) # (batch,4,1)
232 | if sparse:
233 | attn_orig = self.sparsemax(attn_orig) # (batch,4,1)
234 | else:
235 | attn_orig = self.softmax(attn_orig) # (batch,4,1)
236 | attn = torch.unsqueeze(torch.unsqueeze(attn_orig, 3), 4) # (batch,4,1,1,1)
237 | output = attn * v
238 | output = output.sum(1) # (batch,1,channel,size,size)
239 | return output, attn_orig.transpose(2, 1)
240 |
241 | # MIMO (non warp)
242 | class MIMOGeneralDotProductAttention(nn.Module):
243 | ''' Scaled Dot-Product Attention '''
244 |
245 | def __init__(self, query_size, key_size, attn_dropout=0.1):
246 | super().__init__()
247 | self.sparsemax = Sparsemax(dim=1)
248 | self.softmax = nn.Softmax(dim=1)
249 | self.linear = nn.Linear(query_size, key_size)
250 | print('Msg size: ',query_size,' Key size: ', key_size)
251 |
252 | def forward(self, qu, k, v, sparse=True):
253 | # qu (batch,5,32)
254 | # k (batch,5,1024)
255 | # v (batch,5,channel,size,size)
256 | query = self.linear(qu) # (batch,5,key_size)
257 |
258 | # normalization
259 | # query_norm = query.norm(p=2,dim=2).unsqueeze(2).expand_as(query)
260 | # query = query.div(query_norm + 1e-9)
261 |
262 | # k_norm = k.norm(p=2,dim=2).unsqueeze(2).expand_as(k)
263 | # k = k.div(k_norm + 1e-9)
264 |
265 |
266 |
267 | # generate the attention scores
268 | attn_orig = torch.bmm(k, query.transpose(2, 1)) # (batch,5,5) column: different keys against the same query
269 |
270 | # scaling [not sure]
271 | # scaling = torch.sqrt(torch.tensor(k.shape[2],dtype=torch.float32)).cuda()
272 | # attn_orig = attn_orig/ scaling # (batch,5,5) column: differnt keys and the same query
273 |
274 | attn_orig_softmax = self.softmax(attn_orig) # (batch,5,5)
275 |
276 | attn_shape = attn_orig_softmax.shape
277 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2]
278 | attn_orig_softmax_exp = attn_orig_softmax.view(bats, key_num, query_num, 1, 1, 1)
279 |
280 | v_exp = torch.unsqueeze(v, 2)
281 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1)
282 |
283 | output = attn_orig_softmax_exp * v_exp # (batch,4,channel,size,size)
284 | output_sum = output.sum(1) # (batch,1,channel,size,size)
285 |
286 | return output_sum, attn_orig_softmax
287 |
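# Shape walk-through for MIMOGeneralDotProductAttention (sketch, assuming N agents,
# batch size B, and CxHxW feature maps):
#   qu: (B, N, query_size)   k: (B, N, key_size)   v: (B, N, C, H, W)
#   attn_orig = bmm(k, linear(qu)^T): (B, N, N), where column j holds every key's score for query j
#   after softmax over dim=1 (the keys), the fused output is (B, query_num, C, H, W): one map per query.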
288 | # MIMO always com
289 | class MIMOWhoGeneralDotProductAttention(nn.Module):
290 | ''' Scaled Dot-Product Attention '''
291 |
292 | def __init__(self, query_size, key_size, attn_dropout=0.1):
293 | super().__init__()
294 | self.sparsemax = Sparsemax(dim=1)
295 | self.softmax = nn.Softmax(dim=1)
296 | self.linear = nn.Linear(query_size, key_size)
297 | print('Msg size: ',query_size,' Key size: ', key_size)
298 |
299 | def forward(self, qu, k, v, sparse=True):
300 | # qu (batch,5,32)
301 | # k (batch,5,1024)
302 | # v (batch,5,channel,size,size)
303 | query = self.linear(qu) # (batch,5,key_size)
304 |
305 |
306 | attn_orig = torch.bmm(k, query.transpose(2, 1)) # (batch,5,5) column: different keys against the same query
307 |
308 |
309 | # remove the diagonal and softmax
310 | del_diag_att_orig = []
311 | for bi in range(attn_orig.shape[0]):
312 | up = torch.triu(attn_orig[bi],diagonal=1,out=None)[:-1,]
313 | dow = torch.tril(attn_orig[bi],diagonal=-1,out=None)[1:,]
314 | del_diag_att_orig_per_sample = torch.unsqueeze((up+dow),dim=0)
315 | del_diag_att_orig.append(del_diag_att_orig_per_sample)
316 | del_diag_att_orig = torch.cat(tuple(del_diag_att_orig), dim=0)
317 |
318 | attn_orig_softmax = self.softmax(del_diag_att_orig) # (batch,5,5)
319 |
320 | append_att_orig = []
321 | for bi in range(attn_orig_softmax.shape[0]):
322 | up = torch.triu(attn_orig_softmax[bi],diagonal=1,out=None)
323 | up_ext = torch.cat((up, torch.zeros((1, up.shape[1])).cuda()))
324 | dow = torch.tril(attn_orig_softmax[bi],diagonal=0,out=None)
325 | dow_ext = torch.cat((torch.zeros((1, dow.shape[1])).cuda(), dow))
326 |
327 | append_att_orig_per_sample = torch.unsqueeze((up_ext + dow_ext),dim=0)
328 | append_att_orig.append(append_att_orig_per_sample)
329 | append_att_orig = torch.cat(tuple(append_att_orig), dim=0)
330 |
331 |
332 |
333 | attn_shape = append_att_orig.shape
334 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2]
335 | attn_orig_softmax_exp = append_att_orig.view(bats, key_num, query_num, 1, 1, 1)
336 |
337 | v_exp = torch.unsqueeze(v, 2)
338 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1)
339 |
340 | output = attn_orig_softmax_exp * v_exp # (batch,4,channel,size,size)
341 | output_sum = output.sum(1) # (batch,1,channel,size,size)
342 |
343 | return output_sum, append_att_orig
344 |
345 | class GeneralDotProductAttention(nn.Module):
346 | ''' Scaled Dot-Product Attention '''
347 |
348 | def __init__(self, query_size, key_size, attn_dropout=0.1):
349 | super().__init__()
350 | self.sparsemax = Sparsemax(dim=1)
351 | self.softmax = nn.Softmax(dim=1)
352 | self.linear = nn.Linear(query_size, key_size)
353 | print('Msg size: ',query_size,' Key size: ', key_size)
354 |
355 | def forward(self, q, k, v, sparse=True):
356 | # q (batch,1,128)
357 | # k (batch,4,128)
358 | # v (batch,4,channel*size*size)
359 | query = self.linear(q) # (batch,1,key_size)
360 | attn_orig = torch.bmm(k, query.transpose(2, 1)) # (batch,4,1)
361 | if sparse:
362 | attn_orig = self.sparsemax(attn_orig) # (batch,4,1)
363 | else:
364 | attn_orig = self.softmax(attn_orig) # (batch,4,1)
365 | attn = torch.unsqueeze(torch.unsqueeze(attn_orig, 3), 4) # (batch,4,1,1,1)
366 | output = attn * v # (batch,4,channel,size,size)
367 | output = output.sum(1) # (batch,1,channel,size,size)
368 | return output, attn_orig.transpose(2, 1)
369 |
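# The three single-query attention modules above (AdditiveAttentin, GeneralDotProductAttention,
# ScaledDotProductAttention) all consume q: (batch, 1, query_size), k: (batch, #keys, key_size)
# and v: (batch, #keys, channel, size, size); passing sparse=True swaps softmax for sparsemax so
# most link weights collapse to exactly zero.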
370 | # ============= Single normal and Single degraded ============= #
371 |
372 |
373 | # ======================= Model =========================
374 |
375 | class Single_agent(nn.Module):
376 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, enc_backbone='n_segnet_encoder',
377 | dec_backbone='n_segnet_decoder', feat_squeezer=-1):
378 | '''
379 | feat_squeezer: -1 (No squeeze),
380 | '''
381 | super(Single_agent, self).__init__()
382 | self.in_channels = in_channels
383 |
384 | # Encoder
385 | self.encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
386 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
387 |
388 | # Decoder
389 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 1, feat_squeezer=feat_squeezer,
390 | dec_backbone=dec_backbone)
391 |
392 | def forward(self, inputs):
393 | feature_map = self.encoder(inputs)
394 | pred = self.decoder(feature_map)
395 | return pred
396 |
397 |
398 | # Randomly selection baseline and Concatenation of all observations
399 | class All_agents(nn.Module):
400 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, aux_agent_num=4, shuffle_flag=False,
401 | enc_backbone='n_segnet_encoder', dec_backbone='n_segnet_decoder', feat_squeezer=-1):
402 | super(All_agents, self).__init__()
403 |
404 | self.agent_num = aux_agent_num
405 | self.in_channels = in_channels
406 | self.shuffle_flag = shuffle_flag
407 |
408 | # Encoder
409 | self.encoder1 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
410 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
411 | self.encoder2 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
412 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
413 | self.encoder3 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
414 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
415 | self.encoder4 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
416 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
417 | self.encoder5 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
418 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
419 |
420 | # Decoder for interested agent
421 | if self.shuffle_flag == 'selection': # random selection
422 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 2 ,
423 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
424 | else: # catall
425 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * self.agent_num,
426 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
427 |
428 | def divide_inputs(self, inputs):
429 | '''
430 | Divide the input into a list of several images
431 | '''
432 | input_list = []
433 | divide_num = 5
434 | for i in range(divide_num):
435 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :])
436 |
437 | return input_list
438 |
439 | def forward(self, inputs):
440 | # agent_num = 5
441 |
442 | input_list = self.divide_inputs(inputs)
443 | feat_map1 = self.encoder1(input_list[0])
444 | feat_map2 = self.encoder2(input_list[1])
445 | feat_map3 = self.encoder3(input_list[2])
446 | feat_map4 = self.encoder4(input_list[3])
447 | feat_map5 = self.encoder5(input_list[4])
448 |
449 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers
450 | aux_view_feats = [feat_map1,feat_map2, feat_map3, feat_map4, feat_map5]
451 | aux_id = random.randint(0, 4)
452 | aux_view_feats = torch.unsqueeze(aux_view_feats[aux_id], 0)
453 | feat_map_list = (feat_map1,) + tuple(aux_view_feats)
454 | argmax_action = torch.ones(feat_map1.shape[0], dtype=torch.long)*aux_id
455 |
456 | elif self.shuffle_flag == 'fixed2':
457 | feat_map_list = (feat_map1, feat_map2)
458 |
459 | else:
460 | feat_map_list = (feat_map1, feat_map2, feat_map3, feat_map4, feat_map5)
461 |
462 | # combine the feat maps
463 | concat_featmaps = torch.cat(feat_map_list, 1)
464 | pred = self.decoder(concat_featmaps)
465 |
466 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers
467 | return pred, argmax_action.cuda()
468 | else:
469 | return pred
470 |
471 |
472 | class LearnWho2Com(nn.Module):
473 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive',
474 | has_query=True, sparse=False, aux_agent_num=4, shuffle_flag=False, image_size=512,
475 | shared_img_encoder=False,
476 | key_size=128, query_size=128, enc_backbone='n_segnet_encoder', dec_backbone='n_segnet_decoder'):
477 | super(LearnWho2Com, self).__init__()
478 | # agent_num = 2
479 | self.aux_agent_num = aux_agent_num
480 | self.in_channels = in_channels
481 | self.shuffle_flag = shuffle_flag
482 | self.feature_map_channel = 512
483 | self.key_size = key_size
484 | self.query_size = query_size
485 | self.shared_img_encoder = shared_img_encoder
486 | self.has_query = has_query
487 | self.sparse = sparse
488 | # Encoder
489 | # Non-shared
490 |
491 | if self.shared_img_encoder == 'unified':
492 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
493 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
494 | elif self.shared_img_encoder == 'only_normal_agents':
495 | self.degarded_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
496 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
497 | self.normal_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
498 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
499 |
500 | else:
501 | self.encoder1 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
502 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
503 | self.encoder2 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
504 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
505 | self.encoder3 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
506 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
507 | self.encoder4 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
508 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
509 | self.encoder5 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
510 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
511 |
512 | # # Message generator
513 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone)
514 | if self.has_query:
515 | self.query_net = linear(out_size=self.query_size, input_feat_sz=image_size / 32)
516 |
517 | self.key_net = linear(out_size=self.key_size, input_feat_sz=image_size / 32)
518 | if attention == 'additive':
519 | self.attention_net = AdditiveAttentin()
520 | elif attention == 'general':
521 | self.attention_net = GeneralDotProductAttention(self.query_size, self.key_size)
522 | else:
523 | self.attention_net = ScaledDotProductAttention(128 ** 0.5)
524 |
525 | # Segmentation decoder
526 |
527 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel * 2,
528 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
529 | # List the parameters of each modules
530 | self.attention_paras = list(self.attention_net.parameters())
531 | if self.shared_img_encoder == 'unified':
532 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters())
533 | elif self.shared_img_encoder == 'only_normal_agents':
534 | self.img_net_paras = list(self.degarded_encoder.parameters()) + list(
535 | self.normal_encoder.parameters()) + list(self.decoder.parameters())
536 | else:
537 | self.img_net_paras = list(self.encoder1.parameters()) + \
538 | list(self.encoder2.parameters()) + \
539 | list(self.encoder3.parameters()) + \
540 | list(self.encoder4.parameters()) + \
541 | list(self.encoder5.parameters()) + \
542 | list(self.decoder.parameters())
543 |
544 | self.policy_net_paras = list(self.query_key_net.parameters()) + list(
545 | self.key_net.parameters()) + self.attention_paras
546 | if self.has_query:
547 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters())
548 |
549 | self.all_paras = self.img_net_paras + self.policy_net_paras
550 |
551 | def divide_inputs(self, inputs):
552 | '''
553 | Divide the input into a list of several images
554 | '''
555 | input_list = []
556 | divide_num = 5
557 | for i in range(divide_num):
558 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :])
559 |
560 | return input_list
561 |
562 | def shape_feat_map(self, feat, size):
563 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1)
564 |
565 | def forward(self, inputs, training=True, inference='argmax'):
566 |
567 | batch_size, _, _, _ = inputs.size()
568 | input_list = self.divide_inputs(inputs)
569 | if self.shared_img_encoder == 'unified':
570 | # vectorize
571 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0)
572 | feat_map = self.u_encoder(unified_feat_map)
573 | feat_map1 = feat_map[0:batch_size * 1]
574 | feat_map2 = feat_map[batch_size * 1:batch_size * 2]
575 | feat_map3 = feat_map[batch_size * 2:batch_size * 3]
576 | feat_map4 = feat_map[batch_size * 3:batch_size * 4]
577 | feat_map5 = feat_map[batch_size * 4:batch_size * 5]
578 |
579 | elif self.shared_img_encoder == 'only_normal_agents':
580 | feat_map1 = self.degarded_encoder(input_list[0])
581 | # vectorize
582 | unified_normal_feat_map = torch.cat((input_list[1], input_list[2], input_list[3], input_list[4]), 0)
583 | feat_map = self.normal_encoder(unified_normal_feat_map)
584 | feat_map2 = feat_map[0:batch_size * 1]
585 | feat_map3 = feat_map[batch_size * 1:batch_size * 2]
586 | feat_map4 = feat_map[batch_size * 2:batch_size * 3]
587 | feat_map5 = feat_map[batch_size * 3:batch_size * 4]
588 | else:
589 | feat_map1 = self.encoder1(input_list[0])
590 | feat_map2 = self.encoder2(input_list[1])
591 | feat_map3 = self.encoder3(input_list[2])
592 | feat_map4 = self.encoder4(input_list[3])
593 | feat_map5 = self.encoder5(input_list[4])
594 |
595 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0)
596 | query_key_map = self.query_key_net(unified_feat_map)
597 | query_key_map1 = query_key_map[0:batch_size * 1]
598 | query_key_map2 = query_key_map[batch_size * 1:batch_size * 2]
599 | query_key_map3 = query_key_map[batch_size * 2:batch_size * 3]
600 | query_key_map4 = query_key_map[batch_size * 3:batch_size * 4]
601 | query_key_map5 = query_key_map[batch_size * 4:batch_size * 5]
602 |
603 | key2 = torch.unsqueeze(self.key_net(query_key_map2), 1)
604 | key3 = torch.unsqueeze(self.key_net(query_key_map3), 1)
605 | key4 = torch.unsqueeze(self.key_net(query_key_map4), 1)
606 | key5 = torch.unsqueeze(self.key_net(query_key_map5), 1)
607 |
608 | if self.has_query:
609 | query = torch.unsqueeze(self.query_net(query_key_map1), 1) # (batch,1,128)
610 | else:
611 | query = torch.ones(batch_size, 1, self.query_size).to('cuda')
612 |
613 | feat_map2 = torch.unsqueeze(feat_map2, 1) # (batch,1,channel,size,size)
614 | feat_map3 = torch.unsqueeze(feat_map3, 1) # (batch,1,channel,size,size)
615 | feat_map4 = torch.unsqueeze(feat_map4, 1) # (batch,1,channel,size,size)
616 | feat_map5 = torch.unsqueeze(feat_map5, 1) # (batch,1,channel,size,size)
617 |
618 | keys = torch.cat((key2, key3, key4, key5), 1) # (batch,4,128)
619 | vals = torch.cat((feat_map2, feat_map3, feat_map4, feat_map5), 1) # (batch,4,channel,size,size)
620 | aux_feat, prob_action = self.attention_net(query, keys, vals,
621 | sparse=self.sparse) # (batch,1,#channel*size*size),(batch,1,4)
622 | # print('Action: ', prob_action)
623 | concat_featmaps = torch.cat((feat_map1, aux_feat), 1)
624 | pred = self.decoder(concat_featmaps)
625 | if training:
626 | action = torch.argmax(prob_action, dim=2)
627 | return pred, prob_action, action
628 | else:
629 | if inference == 'softmax':
630 | action = torch.argmax(prob_action, dim=2)
631 | return pred, prob_action, action
632 | elif inference == 'argmax_test':
633 | action = torch.argmax(prob_action, dim=2)
634 | aux_feat_list = []
635 | for k in range(batch_size):
636 | if action[k] == 0:
637 | aux_feat_list.append(torch.unsqueeze(feat_map2[k], 0))
638 | elif action[k] == 1:
639 | aux_feat_list.append(torch.unsqueeze(feat_map3[k], 0))
640 | elif action[k] == 2:
641 | aux_feat_list.append(torch.unsqueeze(feat_map4[k], 0))
642 | elif action[k] == 3:
643 | aux_feat_list.append(torch.unsqueeze(feat_map5[k], 0))
644 | else:
645 | raise ValueError('Incorrect action')
646 | aux_feat_argmax = tuple(aux_feat_list)
647 | aux_feat_argmax = torch.cat(aux_feat_argmax, 0)
648 | aux_feat_argmax = torch.squeeze(aux_feat_argmax, 1)
649 | concat_featmaps_argmax = torch.cat((feat_map1.detach(), aux_feat_argmax.detach()), 1)
650 | pred_argmax = self.decoder(concat_featmaps_argmax)
651 | return pred_argmax, prob_action, action
652 | elif inference == 'argmax_train':
653 | action = torch.argmax(prob_action, dim=2)
654 | aux_feat_list = []
655 | for k in range(batch_size):
656 | if action[k] == 0:
657 | aux_feat_list.append(torch.unsqueeze(feat_map2[k], 0))
658 | elif action[k] == 1:
659 | aux_feat_list.append(torch.unsqueeze(feat_map3[k], 0))
660 | elif action[k] == 2:
661 | aux_feat_list.append(torch.unsqueeze(feat_map4[k], 0))
662 | elif action[k] == 3:
663 | aux_feat_list.append(torch.unsqueeze(feat_map5[k], 0))
664 | else:
665 | raise ValueError('Incorrect action')
666 | aux_feat_argmax = tuple(aux_feat_list)
667 | aux_feat_argmax = torch.cat(aux_feat_argmax, 0)
668 | aux_feat_argmax = torch.squeeze(aux_feat_argmax, 1)
669 | concat_featmaps_argmax = torch.cat((feat_map1.detach(), aux_feat_argmax.detach()), 1)
670 | pred_argmax = self.decoder(concat_featmaps_argmax)  # LearnWho2Com defines no separate argmax decoder
671 | return pred_argmax, prob_action, action
672 | else:
673 | raise ValueError('Incorrect inference mode')
674 |
675 |
676 | class LearnWhen2Com(nn.Module):
677 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive',
678 | has_query=True, sparse=False, aux_agent_num=4, shuffle_flag=False, image_size=512,
679 | shared_img_encoder=False, key_size=128, query_size=128, enc_backbone='n_segnet_encoder',
680 | dec_backbone='n_segnet_decoder'):
681 | super(LearnWhen2Com, self).__init__()
682 | # agent_num = 2
683 | self.aux_agent_num = aux_agent_num
684 | self.in_channels = in_channels
685 | self.shuffle_flag = shuffle_flag
686 | self.feature_map_channel = 512
687 | self.key_size = key_size
688 | self.query_size = query_size
689 | self.shared_img_encoder = shared_img_encoder
690 | self.has_query = has_query
691 | self.sparse = sparse
692 | # Encoder
693 | # Non-shared
694 | if self.shared_img_encoder == 'unified':
695 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
696 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
697 | elif self.shared_img_encoder == 'only_normal_agents':
698 | self.degarded_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
699 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
700 | self.normal_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
701 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
702 |
703 | else:
704 | self.encoder1 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
705 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
706 | self.encoder2 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
707 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
708 | self.encoder3 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
709 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
710 | self.encoder4 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
711 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
712 | self.encoder5 = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
713 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
714 |
715 | # # Message generator
716 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone)
717 | if self.has_query:
718 | self.query_net = linear(out_size=self.query_size,input_feat_sz=image_size / 32)
719 | self.key_net = linear(out_size=self.key_size,input_feat_sz=image_size / 32)
720 | if attention == 'additive':
721 | self.attention_net = AdditiveAttentin()
722 | elif attention == 'general':
723 | self.attention_net = GeneralDotProductAttention(self.query_size, self.key_size)
724 | else:
725 | self.attention_net = ScaledDotProductAttention(128 ** 0.5)
726 |
727 | # Segmentation decoder
728 | self.argmax_decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel,
729 | agent_num=self.aux_agent_num + 1, dec_backbone=dec_backbone)
730 |
731 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel,
732 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
733 | # List the parameters of each modules
734 | self.attention_paras = list(self.attention_net.parameters())
735 | if self.shared_img_encoder == 'unified':
736 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters())
737 | elif self.shared_img_encoder == 'only_normal_agents':
738 | self.img_net_paras = list(self.degarded_encoder.parameters()) + list(
739 | self.normal_encoder.parameters()) + list(self.decoder.parameters())
740 | else:
741 | self.img_net_paras = list(self.encoder1.parameters()) + \
742 | list(self.encoder2.parameters()) + \
743 | list(self.encoder3.parameters()) + \
744 | list(self.encoder4.parameters()) + \
745 | list(self.encoder5.parameters()) + \
746 | list(self.decoder.parameters())
747 |
748 | self.img_net_paras = self.img_net_paras + list(self.argmax_decoder.parameters())
749 |
750 | self.policy_net_paras = list(self.query_key_net.parameters()) + list(
751 | self.key_net.parameters()) + self.attention_paras
752 | if self.has_query:
753 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters())
754 |
755 |
756 | self.all_paras = self.img_net_paras + self.policy_net_paras
757 |
758 | def divide_inputs(self, inputs):
759 | '''
760 | Divide the input into a list of several images
761 | '''
762 | input_list = []
763 | divide_num = 5
764 | for i in range(divide_num):
765 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :])
766 |
767 | return input_list
768 |
769 | def shape_feat_map(self, feat, size):
770 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1)
771 |
772 | def argmax_select(self, feat_map1, feat_map2, feat_map3, feat_map4, feat_map5, action, batch_size):
773 |
774 | num_connect = 0
775 | feat_list = []
776 | for k in range(batch_size):
777 | if action[k] == 0:
778 | feat_list.append(torch.unsqueeze(feat_map1[k], 0))
779 | elif action[k] == 1:
780 | feat_list.append(torch.unsqueeze(feat_map2[k], 0))
781 | num_connect = num_connect + 1
782 | elif action[k] == 2:
783 | feat_list.append(torch.unsqueeze(feat_map3[k], 0))
784 | num_connect = num_connect + 1
785 | elif action[k] == 3:
786 | feat_list.append(torch.unsqueeze(feat_map4[k], 0))
787 | num_connect = num_connect + 1
788 | elif action[k] == 4:
789 | feat_list.append(torch.unsqueeze(feat_map5[k], 0))
790 | num_connect = num_connect + 1
791 | else:
792 | raise ValueError('Incorrect action')
793 | num_connect = num_connect / batch_size
794 |
795 | feat_argmax = tuple(feat_list)
796 | feat_argmax = torch.cat(feat_argmax, 0)
797 | feat_argmax = torch.squeeze(feat_argmax, 1)
798 | return feat_argmax, num_connect
799 |
800 | def activated_select(self, vals, W_mat, thres=0.2):
801 | action = torch.mul(W_mat, (W_mat > thres).float())
802 | attn = action.view(action.shape[0], action.shape[2], 1, 1, 1) # (batch,5,1,1,1)
803 | output = attn * vals
804 | feat_fuse = output.sum(1) # (batch,1,channel,size,size)
805 |
806 | batch_size = action.shape[0]
807 | num_connect = torch.nonzero(action[:, :, 1:]).shape[0] / batch_size
808 |
809 | return feat_fuse, action, num_connect
810 |
811 | def forward(self, inputs, training=True, inference='argmax'):
812 | batch_size, _, _, _ = inputs.size()
813 | input_list = self.divide_inputs(inputs)
814 | if self.shared_img_encoder == 'unified':
815 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0)
816 | feat_map = self.u_encoder(unified_feat_map)
817 | feat_map1 = feat_map[0:batch_size * 1]
818 | feat_map2 = feat_map[batch_size * 1:batch_size * 2]
819 | feat_map3 = feat_map[batch_size * 2:batch_size * 3]
820 | feat_map4 = feat_map[batch_size * 3:batch_size * 4]
821 | feat_map5 = feat_map[batch_size * 4:batch_size * 5]
822 |
823 | elif self.shared_img_encoder == 'only_normal_agents':
824 | feat_map1 = self.degarded_encoder(input_list[0])
825 | unified_normal_feat_map = torch.cat((input_list[1], input_list[2], input_list[3], input_list[4]), 0)
826 | feat_map = self.normal_encoder(unified_normal_feat_map)
827 | feat_map2 = feat_map[0:batch_size * 1]
828 | feat_map3 = feat_map[batch_size * 1:batch_size * 2]
829 | feat_map4 = feat_map[batch_size * 2:batch_size * 3]
830 | feat_map5 = feat_map[batch_size * 3:batch_size * 4]
831 | else:
832 | feat_map1 = self.encoder1(input_list[0])
833 | feat_map2 = self.encoder2(input_list[1])
834 | feat_map3 = self.encoder3(input_list[2])
835 | feat_map4 = self.encoder4(input_list[3])
836 | feat_map5 = self.encoder5(input_list[4])
837 |
838 | unified_feat_map = torch.cat((input_list[0], input_list[1], input_list[2], input_list[3], input_list[4]), 0)
839 | query_key_map = self.query_key_net(unified_feat_map)
840 | query_key_map1 = query_key_map[0:batch_size * 1]
841 |
842 | keys = self.key_net(query_key_map)
843 | key1 = torch.unsqueeze(keys[0:batch_size * 1], 1)
844 | key2 = torch.unsqueeze(keys[batch_size * 1:batch_size * 2], 1)
845 | key3 = torch.unsqueeze(keys[batch_size * 2:batch_size * 3], 1)
846 | key4 = torch.unsqueeze(keys[batch_size * 3:batch_size * 4], 1)
847 | key5 = torch.unsqueeze(keys[batch_size * 4:batch_size * 5], 1)
848 |
849 |
850 | if self.has_query:
851 | querys = self.query_net(query_key_map)
852 | query = torch.unsqueeze(querys[0:batch_size * 1], 1)
853 | else:
854 | query = torch.ones(batch_size, 1, self.query_size).to('cuda')
855 |
856 | feat_map1 = torch.unsqueeze(feat_map1, 1) # (batch,1,channel,size,size)
857 | feat_map2 = torch.unsqueeze(feat_map2, 1) # (batch,1,channel,size,size)
858 | feat_map3 = torch.unsqueeze(feat_map3, 1) # (batch,1,channel,size,size)
859 | feat_map4 = torch.unsqueeze(feat_map4, 1) # (batch,1,channel,size,size)
860 | feat_map5 = torch.unsqueeze(feat_map5, 1) # (batch,1,channel,size,size)
861 |
862 | keys = torch.cat((key1, key2, key3, key4, key5), 1) # (batch,4,128)
863 | vals = torch.cat((feat_map1, feat_map2, feat_map3, feat_map4, feat_map5), 1) # (batch,4,channel,size,size)
864 | aux_feat, prob_action = self.attention_net(query, keys, vals,
865 | sparse=self.sparse) # (batch,1,#channel*size*size),(batch,1,4)
866 | pred = self.decoder(aux_feat)
867 |
868 | if training:
869 | action = torch.argmax(prob_action, dim=2)
870 | return pred, prob_action, action
871 | else:
872 | if inference == 'softmax':
873 | action = torch.argmax(prob_action, dim=2)
874 | num_connect = 4
875 | return pred, prob_action, action, num_connect
876 | elif inference == 'argmax_test':
877 | action = torch.argmax(prob_action, dim=2)
878 | feat_argmax, num_connect = self.argmax_select(feat_map1, feat_map2, feat_map3, feat_map4, feat_map5,
879 | action, batch_size)
880 | featmaps_argmax = feat_argmax.detach()
881 | pred_argmax = self.decoder(featmaps_argmax)
882 | return pred_argmax, prob_action, action, num_connect
883 | elif inference == 'activated':
884 | feat_act, action, num_connect = self.activated_select(vals, prob_action)
885 | featmaps_act = feat_act.detach()
886 | pred_act = self.decoder(featmaps_act)
887 | return pred_act, prob_action, action, num_connect
888 | else:
889 | raise ValueError('Incorrect inference mode')
890 |
891 |
892 | class MIMO_All_agents(nn.Module):
893 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, aux_agent_num=4, shuffle_flag=False,
894 | enc_backbone='n_segnet_encoder', dec_backbone='n_segnet_decoder', feat_squeezer=-1):
895 | super(MIMO_All_agents, self).__init__()
896 |
897 | self.agent_num = aux_agent_num # include the target agent
898 | self.in_channels = in_channels
899 | self.shuffle_flag = shuffle_flag
900 |
901 | self.encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
902 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
903 |
904 |
905 |
906 | # Decoder for interested agent
907 | if self.shuffle_flag == 'selection': # random selection
908 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 2 ,
909 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
910 | elif self.shuffle_flag == 'ComNet': # random selection
911 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * 2 ,
912 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
913 | else: # Catall
914 | self.decoder = img_decoder(n_classes=n_classes, in_channels=feat_channel * self.agent_num,
915 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
916 |
917 | def divide_inputs(self, inputs):
918 | '''
919 | Divide the input into a list of several images
920 | '''
921 | input_list = []
922 | divide_num = self.agent_num
923 | for i in range(divide_num):
924 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :])
925 | return input_list
926 |
927 | def forward(self, inputs):
928 | input_list = self.divide_inputs(inputs)
929 |
930 | feat_map = []
931 | for i in range(self.agent_num):
932 | fmp = self.encoder(input_list[i])
933 | feat_map.append(fmp)
934 |
935 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers
936 | # randomly select
937 | concat_fmp_list = []
938 | rand_action_list = []
939 | for i in range(self.agent_num):
940 | randidx = random.randint(0, self.agent_num-1)
941 | rand_action_list.append( torch.ones((feat_map[0].shape[0],1), dtype=torch.long)*randidx)
942 |
943 | fmp_per_agent = torch.cat((feat_map[i], feat_map[randidx]), 1) # fmp_list = [bs, channel*2, h, w]
944 | concat_fmp_list.append(fmp_per_agent)
945 | concat_featmaps = torch.cat(tuple(concat_fmp_list), 0) # concat_featmaps = [bs*agent_num, channel*2, h, w]
946 | pred = self.decoder(concat_featmaps)
947 | argmax_action = torch.cat(tuple(rand_action_list), 1)
948 | elif self.shuffle_flag == 'ComNet': # use randomly picked feature and only specific numbers
949 |
950 | # randomly select
951 | concat_fmp_list = []
952 | for i in range(self.agent_num):
953 | sum_other_feat = torch.zeros((feat_map[0].shape[0], feat_map[0].shape[1],feat_map[0].shape[2],feat_map[0].shape[3])).cuda()
954 | for j in range(self.agent_num):
955 | if j!= i: # other agents
956 | sum_other_feat += feat_map[j]
957 | avg_other_feat = sum_other_feat/(self.agent_num-1)
958 |
959 | fmp_per_agent = torch.cat((feat_map[i], avg_other_feat), 1) # fmp_list = [bs, channel*2, h, w]
960 | concat_fmp_list.append(fmp_per_agent)
961 | concat_featmaps = torch.cat(tuple(concat_fmp_list), 0) # concat_featmaps = [bs*agent_num, channel*2, h, w]
962 | pred = self.decoder(concat_featmaps)
963 |
964 | else:
965 | concat_fmp_list = []
966 | for i in range(self.agent_num):
967 | fmp_list = []
968 | for j in range(self.agent_num):
969 | fmp_list.append(feat_map[ (i+j)%self.agent_num ])
970 | fmp_per_agent = torch.cat(tuple(fmp_list), 1) # fmp_list = [bs, channel*agent_num, h, w]
971 | concat_fmp_list.append(fmp_per_agent)
972 | concat_featmaps = torch.cat(tuple(concat_fmp_list), 0) # concat_featmaps = [bs*agent_num, channel*agent_num, h, w]
973 |
974 | pred = self.decoder(concat_featmaps)
975 |
976 |
977 | if self.shuffle_flag == 'selection': # use randomly picked feature and only specific numbers
978 | return pred, argmax_action.cuda()
979 | else:
980 | return pred
981 |
982 | # our model (no warping)
983 | class MIMOcom(nn.Module):
984 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive',
985 | has_query=True, sparse=False, agent_num=5, shuffle_flag=False, image_size=512,
986 | shared_img_encoder=False, key_size=128, query_size=128, enc_backbone='n_segnet_encoder',
987 | dec_backbone='n_segnet_decoder'):
988 | super(MIMOcom, self).__init__()
989 |
990 | self.agent_num = agent_num
991 | self.in_channels = in_channels
992 | self.shuffle_flag = shuffle_flag
993 | self.feature_map_channel = 512
994 | self.key_size = key_size
995 | self.query_size = query_size
996 | self.shared_img_encoder = shared_img_encoder
997 | self.has_query = has_query
998 | self.sparse = sparse
999 |
1000 |
1001 | print('When2com') # our model: detach the learning of values and keys
1002 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
1003 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
1004 |
1005 | self.key_net = km_generator(out_size=self.key_size, input_feat_sz=image_size / 32)
1006 | self.attention_net = MIMOGeneralDotProductAttention(self.query_size, self.key_size)
1007 |
1008 | # # Message generator
1009 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone)
1010 | if self.has_query:
1011 | self.query_net = km_generator(out_size=self.query_size, input_feat_sz=image_size / 32)
1012 |
1013 | # Segmentation decoder
1014 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel,
1015 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
1016 |
1017 |
1018 | # List the parameters of each modules
1019 | self.attention_paras = list(self.attention_net.parameters())
1020 | if self.shared_img_encoder == 'unified':
1021 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters())
1022 |
1023 |
1024 |
1025 | self.policy_net_paras = list(self.query_key_net.parameters()) + list(
1026 | self.key_net.parameters()) + self.attention_paras
1027 | if self.has_query:
1028 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters())
1029 |
1030 | self.all_paras = self.img_net_paras + self.policy_net_paras
1031 |
1032 |
1033 | def shape_feat_map(self, feat, size):
1034 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1)
1035 |
1036 | def argmax_select(self, val_mat, prob_action):
1037 | # v(batch, query_num, channel, size, size)
1038 | cls_num = prob_action.shape[1]
1039 |
1040 | coef_argmax = F.one_hot(prob_action.max(dim=1)[1], num_classes=cls_num).type(torch.cuda.FloatTensor)
1041 | coef_argmax = coef_argmax.transpose(1, 2)
1042 | attn_shape = coef_argmax.shape
1043 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2]
1044 | coef_argmax_exp = coef_argmax.view(bats, key_num, query_num, 1, 1, 1)
1045 |
1046 | v_exp = torch.unsqueeze(val_mat, 2)
1047 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1)
1048 |
1049 | output = coef_argmax_exp * v_exp # (batch,4,channel,size,size)
1050 | feat_argmax = output.sum(1) # (batch,1,channel,size,size)
1051 |
1052 | # compute connect
1053 | count_coef = copy.deepcopy(coef_argmax)
1054 | ind = np.diag_indices(self.agent_num)
1055 | count_coef[:, ind[0], ind[1]] = 0
1056 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0])
1057 |
1058 | return feat_argmax, coef_argmax, num_connect
1059 |
1060 | def activated_select(self, val_mat, prob_action, thres=0.2):
1061 |
1062 | coef_act = torch.mul(prob_action, (prob_action > thres).float())
1063 | attn_shape = coef_act.shape
1064 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2]
1065 | coef_act_exp = coef_act.view(bats, key_num, query_num, 1, 1, 1)
1066 |
1067 | v_exp = torch.unsqueeze(val_mat, 2)
1068 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1)
1069 |
1070 | output = coef_act_exp * v_exp # (batch,4,channel,size,size)
1071 | feat_act = output.sum(1) # (batch,1,channel,size,size)
1072 |
1073 | # compute connect
1074 | count_coef = coef_act.clone()
1075 | ind = np.diag_indices(self.agent_num)
1076 | count_coef[:, ind[0], ind[1]] = 0
1077 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0])
1078 | return feat_act, coef_act, num_connect
1079 |
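# In both argmax_select() and activated_select(), num_connect is the bandwidth proxy:
# the diagonal (self) links are zeroed out before counting, so it measures the average
# number of off-diagonal links actually used per agent in the batch.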
1080 | def agents2batch(self, feats):
1081 | agent_num = feats.shape[1]
1082 | feat_list = []
1083 | for i in range(agent_num):
1084 | feat_list.append(feats[:, i, :, :, :])
1085 | feat_mat = torch.cat(tuple(feat_list), 0)
1086 | return feat_mat
1087 |
1088 | def divide_inputs(self, inputs):
1089 | '''
1090 | Divide the input into a list of several images
1091 | '''
1092 | input_list = []
1093 | for i in range(self.agent_num):
1094 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :])
1095 |
1096 | return input_list
1097 |
1098 | def forward(self, inputs, training=True, MO_flag=False , inference='argmax'):
1099 | batch_size, _, _, _ = inputs.size()
1100 |
1101 | input_list = self.divide_inputs(inputs)
1102 |
1103 | if self.shared_img_encoder == 'unified':
1104 | # vectorize input list
1105 | img_list = []
1106 | for i in range(self.agent_num):
1107 | img_list.append(input_list[i])
1108 | unified_img_list = torch.cat(tuple(img_list), 0)
1109 |
1110 | # pass encoder
1111 | feat_maps = self.u_encoder(unified_img_list)
1112 |
1113 | # get feat maps for each image
1114 | feat_map = {}
1115 | feat_list = []
1116 | for i in range(self.agent_num):
1117 | feat_map[i] = torch.unsqueeze(feat_maps[batch_size * i:batch_size * (i + 1)], 1)
1118 | feat_list.append(feat_map[i])
1119 | val_mat = torch.cat(tuple(feat_list), 1)
1120 | else:
1121 | raise ValueError('Incorrect encoder')
1122 |
1123 | # pass feature maps through key and query generator
1124 | query_key_maps = self.query_key_net(unified_img_list)
1125 |
1126 | keys = self.key_net(query_key_maps)
1127 |
1128 | if self.has_query:
1129 | querys = self.query_net(query_key_maps)
1130 |
1131 | # get key and query
1132 | key = {}
1133 | query = {}
1134 | key_list = []
1135 | query_list = []
1136 |
1137 | for i in range(self.agent_num):
1138 | key[i] = torch.unsqueeze(keys[batch_size * i:batch_size * (i + 1)], 1)
1139 | key_list.append(key[i])
1140 | if self.has_query:
1141 | query[i] = torch.unsqueeze(querys[batch_size * i:batch_size * (i + 1)], 1)
1142 | else:
1143 | query[i] = torch.ones(batch_size, 1, self.query_size).to('cuda')
1144 | query_list.append(query[i])
1145 |
1146 |
1147 | key_mat = torch.cat(tuple(key_list), 1)
1148 | query_mat = torch.cat(tuple(query_list), 1)
1149 |
1150 |         if not MO_flag:
1151 |             # single-request mode: keep only the first agent's query
1152 |             query_mat = torch.unsqueeze(query_mat[:, 0, :], 1)
1153 |
1154 |
1155 | feat_fuse, prob_action = self.attention_net(query_mat, key_mat, val_mat, sparse=self.sparse)
1156 |
1157 |
1158 | # weighted feature maps is passed to decoder
1159 | feat_fuse_mat = self.agents2batch(feat_fuse)
1160 |
1161 | pred = self.decoder(feat_fuse_mat)
1162 |
1163 |         # small bias toward each agent's own frame: does not change the fused features, it only nudges the selected action toward local frames (reducing bandwidth)
1164 |         small_bias = torch.eye(prob_action.shape[1]) * 0.001
1165 |         small_bias = small_bias.reshape((1, prob_action.shape[1], prob_action.shape[2]))
1166 |         small_bias = small_bias.repeat(prob_action.shape[0], 1, 1).cuda()
1167 |         prob_action = prob_action + small_bias
1168 |
1169 |
1170 | if training:
1171 | action = torch.argmax(prob_action, dim=1)
1172 | num_connect = self.agent_num - 1
1173 |
1174 | return pred, prob_action, action, num_connect
1175 | else:
1176 | if inference == 'softmax':
1177 | action = torch.argmax(prob_action, dim=1)
1178 | num_connect = self.agent_num - 1
1179 |
1180 | return pred, prob_action, action, num_connect
1181 |
1182 | elif inference == 'argmax_test':
1183 |
1184 | feat_argmax, connect_mat, num_connect = self.argmax_select(val_mat, prob_action)
1185 |
1186 | feat_argmax_mat = self.agents2batch(feat_argmax) # (batchsize*agent_num, channel, size, size)
1187 | feat_argmax_mat = feat_argmax_mat.detach()
1188 | pred_argmax = self.decoder(feat_argmax_mat)
1189 | action = torch.argmax(connect_mat, dim=1)
1190 | return pred_argmax, prob_action, action, num_connect
1191 |
1192 | elif inference == 'activated':
1193 | feat_act, connect_mat, num_connect = self.activated_select(val_mat, prob_action)
1194 |
1195 | feat_act_mat = self.agents2batch(feat_act) # (batchsize*agent_num, channel, size, size)
1196 | feat_act_mat = feat_act_mat.detach()
1197 |
1198 | pred_act = self.decoder(feat_act_mat)
1199 |
1200 | action = torch.argmax(connect_mat, dim=1)
1201 |
1202 | return pred_act, prob_action, action, num_connect
1203 | else:
1204 | raise ValueError('Incorrect inference mode')
1205 |
1206 | # AuxAtt: who2com-style MIMO variant; each agent's own feature map is concatenated with the fused feature before decoding
1207 | class MIMOcomWho(nn.Module):
1208 | def __init__(self, n_classes=21, in_channels=3, feat_channel=512, feat_squeezer=-1, attention='additive',
1209 | has_query=True, sparse=False, agent_num=5, shuffle_flag=False, image_size=512,
1210 | shared_img_encoder=False, key_size=128, query_size=128, enc_backbone='n_segnet_encoder',
1211 | dec_backbone='n_segnet_decoder'):
1212 | super(MIMOcomWho, self).__init__()
1213 |
1214 | self.agent_num = agent_num
1215 | self.in_channels = in_channels
1216 | self.shuffle_flag = shuffle_flag
1217 | self.feature_map_channel = 512
1218 | self.key_size = key_size
1219 | self.query_size = query_size
1220 | self.shared_img_encoder = shared_img_encoder
1221 | self.has_query = has_query
1222 | self.sparse = sparse
1223 | # Encoder
1224 |
1225 | # Non-shared
1226 | if self.shared_img_encoder == 'unified':
1227 | self.u_encoder = img_encoder(n_classes=n_classes, in_channels=in_channels, feat_channel=feat_channel,
1228 | feat_squeezer=feat_squeezer, enc_backbone=enc_backbone)
1229 | else:
1230 | raise ValueError('Incorrect shared_img_encoder flag')
1231 |
1232 |         # Message generator
1233 | self.query_key_net = policy_net4(n_classes=n_classes, in_channels=in_channels, enc_backbone=enc_backbone)
1234 | if self.has_query:
1235 | self.query_net = linear(out_size=self.query_size, input_feat_sz=image_size / 32)
1236 |
1237 | self.key_net = linear(out_size=self.key_size, input_feat_sz=image_size / 32)
1238 |
1239 | self.attention_net = MIMOWhoGeneralDotProductAttention(self.query_size, self.key_size)
1240 |
1241 | # Segmentation decoder
1242 | self.decoder = img_decoder(n_classes=n_classes, in_channels=self.feature_map_channel*2,
1243 | feat_squeezer=feat_squeezer, dec_backbone=dec_backbone)
1244 |
1245 |
1246 |         # List the parameters of each module
1247 | self.attention_paras = list(self.attention_net.parameters())
1248 | if self.shared_img_encoder == 'unified':
1249 | self.img_net_paras = list(self.u_encoder.parameters()) + list(self.decoder.parameters())
1250 |
1251 |
1252 |
1253 | self.policy_net_paras = list(self.query_key_net.parameters()) + list(
1254 | self.key_net.parameters()) + self.attention_paras
1255 | if self.has_query:
1256 | self.policy_net_paras = self.policy_net_paras + list(self.query_net.parameters())
1257 |
1258 | self.all_paras = self.img_net_paras + self.policy_net_paras
1259 |
1260 |
1261 | def shape_feat_map(self, feat, size):
1262 | return torch.unsqueeze(feat.view(-1, size[1] * size[2] * size[3]), 1)
1263 |
1264 | def argmax_select(self, val_mat, prob_action):
1265 |         # val_mat: (batch, agent_num, channel, size, size)
1266 | cls_num = prob_action.shape[1]
1267 |
1268 | coef_argmax = F.one_hot(prob_action.max(dim=1)[1], num_classes=cls_num).type(torch.cuda.FloatTensor)
1269 | coef_argmax = coef_argmax.transpose(1, 2)
1270 | attn_shape = coef_argmax.shape
1271 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2]
1272 | coef_argmax_exp = coef_argmax.view(bats, key_num, query_num, 1, 1, 1)
1273 |
1274 | v_exp = torch.unsqueeze(val_mat, 2)
1275 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1)
1276 |
1277 |         output = coef_argmax_exp * v_exp  # (batch, key_num, query_num, channel, size, size)
1278 |         feat_argmax = output.sum(1)  # (batch, query_num, channel, size, size)
1279 |
1280 | # compute connect
1281 | count_coef = coef_argmax.clone()
1282 | ind = np.diag_indices(self.agent_num)
1283 | count_coef[:, ind[0], ind[1]] = 0
1284 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0])
1285 |
1286 | return feat_argmax, coef_argmax, num_connect
1287 |
1288 | def activated_select(self, val_mat, prob_action, thres=0.2):
1289 |
1290 | coef_act = torch.mul(prob_action, (prob_action > thres).float())
1291 | attn_shape = coef_act.shape
1292 | bats, key_num, query_num = attn_shape[0], attn_shape[1], attn_shape[2]
1293 | coef_act_exp = coef_act.view(bats, key_num, query_num, 1, 1, 1)
1294 |
1295 | v_exp = torch.unsqueeze(val_mat, 2)
1296 | v_exp = v_exp.expand(-1, -1, query_num, -1, -1, -1)
1297 |
1298 |         output = coef_act_exp * v_exp  # (batch, key_num, query_num, channel, size, size)
1299 |         feat_act = output.sum(1)  # (batch, query_num, channel, size, size)
1300 |
1301 | # compute connect
1302 | count_coef = coef_act.clone()
1303 | ind = np.diag_indices(self.agent_num)
1304 | count_coef[:, ind[0], ind[1]] = 0
1305 | num_connect = torch.nonzero(count_coef).shape[0] / (self.agent_num * count_coef.shape[0])
1306 |
1307 | return feat_act, coef_act, num_connect
1308 |
1309 | def agents2batch(self, feats):
1310 | agent_num = feats.shape[1]
1311 | feat_list = []
1312 | for i in range(agent_num):
1313 | feat_list.append(feats[:, i, :, :, :])
1314 | feat_mat = torch.cat(tuple(feat_list), 0)
1315 | return feat_mat
1316 |
1317 | def divide_inputs(self, inputs):
1318 | '''
1319 | Divide the input into a list of several images
1320 | '''
1321 | input_list = []
1322 | for i in range(self.agent_num):
1323 | input_list.append(inputs[:, 3 * i:3 * i + 3, :, :])
1324 |
1325 | return input_list
1326 |
1327 |     def forward(self, inputs, training=True, MO_flag=False, inference='argmax'):
1328 | batch_size, _, _, _ = inputs.size()
1329 |
1330 | input_list = self.divide_inputs(inputs)
1331 |
1332 | if self.shared_img_encoder == 'unified':
1333 | # vectorize input list
1334 | img_list = []
1335 | for i in range(self.agent_num):
1336 | img_list.append(input_list[i])
1337 | unified_img_list = torch.cat(tuple(img_list), 0)
1338 |
1339 | # pass encoder
1340 | feat_maps = self.u_encoder(unified_img_list)
1341 |
1342 | # get feat maps for each image
1343 | feat_map = {}
1344 | feat_list = []
1345 | for i in range(self.agent_num):
1346 | feat_map[i] = torch.unsqueeze(feat_maps[batch_size * i:batch_size * (i + 1)], 1)
1347 | feat_list.append(feat_map[i])
1348 | val_mat = torch.cat(tuple(feat_list), 1)
1349 | else:
1350 | raise ValueError('Incorrect encoder')
1351 |
1352 | # pass feature maps through key and query generator
1353 | query_key_maps = self.query_key_net(unified_img_list)
1354 | keys = self.key_net(query_key_maps)
1355 | if self.has_query:
1356 | querys = self.query_net(query_key_maps)
1357 |
1358 | # get key and query
1359 | key = {}
1360 | query = {}
1361 | key_list = []
1362 | query_list = []
1363 |
1364 | for i in range(self.agent_num):
1365 | key[i] = torch.unsqueeze(keys[batch_size * i:batch_size * (i + 1)], 1)
1366 | key_list.append(key[i])
1367 | if self.has_query:
1368 | query[i] = torch.unsqueeze(querys[batch_size * i:batch_size * (i + 1)], 1)
1369 | else:
1370 | query[i] = torch.ones(batch_size, 1, self.query_size).to('cuda')
1371 | query_list.append(query[i])
1372 |
1373 | key_mat = torch.cat(tuple(key_list), 1)
1374 | query_mat = torch.cat(tuple(query_list), 1)
1375 |
1376 |         if not MO_flag:
1377 |             # single-request mode: keep only the first agent's query
1378 |             query_mat = torch.unsqueeze(query_mat[:, 0, :], 1)
1379 |
1380 |
1381 | feat_fuse, prob_action = self.attention_net(query_mat, key_mat, val_mat, sparse=self.sparse)
1382 | fuse_map = torch.cat( (feat_fuse, val_mat), dim=2)
1383 |
1384 |
1385 | feat_fuse_mat = self.agents2batch(fuse_map)
1386 | pred = self.decoder(feat_fuse_mat)
1387 |
1388 | if training:
1389 | action = torch.argmax(prob_action, dim=1)
1390 | num_connect = self.agent_num - 1
1391 |
1392 | return pred, prob_action, action, num_connect
1393 | else:
1394 | if inference == 'softmax':
1395 | action = torch.argmax(prob_action, dim=1)
1396 | num_connect = self.agent_num - 1
1397 |
1398 | return pred, prob_action, action, num_connect
1399 |
1400 | elif inference == 'argmax_test':
1401 | feat_argmax, connect_mat, num_connect = self.argmax_select(val_mat, prob_action)
1402 | fuse_map = torch.cat((feat_argmax, val_mat), dim=2)
1403 |
1404 | # argmax feature map is passed to decoder
1405 | feat_argmax_mat = self.agents2batch(fuse_map) # (batchsize*agent_num, channel, size, size)
1406 | feat_argmax_mat = feat_argmax_mat.detach()
1407 | pred_argmax = self.decoder(feat_argmax_mat)
1408 | action = torch.argmax(prob_action, dim=1)
1409 |
1410 | return pred_argmax, prob_action, action, num_connect
1411 |
1412 | elif inference == 'activated':
1413 | feat_act, connect_mat, num_connect = self.activated_select(val_mat, prob_action)
1414 | fuse_map = torch.cat((feat_act, val_mat), dim=2)
1415 |
1416 | feat_act_mat = self.agents2batch(fuse_map) # (batchsize*agent_num, channel, size, size)
1417 | feat_act_mat = feat_act_mat.detach()
1418 | pred_act = self.decoder(feat_act_mat)
1419 | action = torch.argmax(prob_action, dim=1)
1420 |
1421 | return pred_act, prob_action, action, num_connect
1422 | else:
1423 | raise ValueError('Incorrect inference mode')
1424 |
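
if __name__ == '__main__':
    # Minimal smoke-test sketch (added note, not part of the original training code).
    # Each sample packs all agents' RGB views along the channel axis, i.e. the input
    # is (batch, 3 * agent_num, H, W); the class count, CUDA device and 512x512
    # resolution below are assumptions for illustration.
    model = MIMOcomWho(n_classes=11, agent_num=5, image_size=512,
                       shared_img_encoder='unified', has_query=True).cuda()
    imgs = torch.rand(2, 3 * 5, 512, 512).cuda()
    pred, prob_action, action, num_connect = model(
        imgs, training=False, MO_flag=True, inference='activated')
    # Expected: pred is (batch * agent_num, n_classes, 512, 512) and
    # prob_action is (batch, agent_num, agent_num).
    print(pred.shape, prob_action.shape, num_connect)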
--------------------------------------------------------------------------------
/ptsemseg/models/backbone.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torchvision.models as models
4 |
5 | import pretrainedmodels
6 |
7 | from ptsemseg.models.utils import conv2DBatchNormRelu, deconv2DBatchNormRelu, Sparsemax
8 | import random
9 |
10 |
11 |
12 | class n_segnet_encoder(nn.Module):
13 | def __init__(self, n_classes=21, in_channels=3):
14 | super(n_segnet_encoder, self).__init__()
15 | self.in_channels = in_channels
16 |
17 | # Encoder
18 | # down 1
19 | self.conv1 = conv2DBatchNormRelu(self.in_channels, 64, k_size=3, stride=1, padding=1)
20 | self.conv2 = conv2DBatchNormRelu(64 , 64, k_size=3, stride=2, padding=1)
21 |
22 | # down 2
23 | self.conv3 = conv2DBatchNormRelu(64 ,128, k_size=3, stride=1, padding=1)
24 | self.conv4 = conv2DBatchNormRelu(128 ,128, k_size=3, stride=2, padding=1)
25 |
26 | # down 3
27 | self.conv5 = conv2DBatchNormRelu(128 , 256, k_size=3, stride=1, padding=1)
28 | self.conv6 = conv2DBatchNormRelu(256 , 256, k_size=3, stride=1, padding=1)
29 | self.conv7 = conv2DBatchNormRelu(256 , 256, k_size=3, stride=2, padding=1)
30 |
31 | # down 4
32 | self.conv8 = conv2DBatchNormRelu(256 , 512, k_size=3, stride=1, padding=1)
33 | self.conv9 = conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1)
34 | self.conv10= conv2DBatchNormRelu(512 , 512, k_size=3, stride=2, padding=1)
35 |
36 | # down 5
37 | self.conv11= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1)
38 | self.conv12= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1)
39 | self.conv13= conv2DBatchNormRelu(512 , 512, k_size=3, stride=2, padding=1)
40 |
41 | def forward(self, inputs):
42 | outputs = self.conv1(inputs)
43 | outputs = self.conv2(outputs)
44 | outputs = self.conv3(outputs)
45 | outputs = self.conv4(outputs)
46 | outputs = self.conv5(outputs)
47 | outputs = self.conv6(outputs)
48 | outputs = self.conv7(outputs)
49 | outputs = self.conv8(outputs)
50 | outputs = self.conv9(outputs)
51 | outputs = self.conv10(outputs)
52 | outputs = self.conv11(outputs)
53 | outputs = self.conv12(outputs)
54 | outputs = self.conv13(outputs)
55 | return outputs
56 |
57 |
58 | class resnet_encoder(nn.Module):
59 | def __init__(self, n_classes=21, in_channels=3):
60 | super(resnet_encoder, self).__init__()
61 | feat_chn = 256
62 | #self.feature_backbone = n_segnet_encoder(n_classes=n_classes, in_channels=in_channels)
63 | self.feature_backbone = pretrainedmodels.__dict__['resnet18'](num_classes=1000, pretrained=None)
64 |
65 | self.backbone_0 = self.feature_backbone.conv1
66 | self.backbone_1 = nn.Sequential(self.feature_backbone.bn1, self.feature_backbone.relu, self.feature_backbone.maxpool, self.feature_backbone.layer1)
67 | self.backbone_2 = self.feature_backbone.layer2
68 | self.backbone_3 = self.feature_backbone.layer3
69 | self.backbone_4 = self.feature_backbone.layer4
70 |
71 |
72 | def forward(self, inputs):
73 | # print('input:')
74 | # print(inputs.size())
75 | # import pdb; pdb.set_trace()
76 | outputs = self.backbone_0(inputs)
77 | # print('base_0 size: ')
78 | # print(base_0.size())
79 |
80 | outputs = self.backbone_1(outputs)
81 | # print('base_1 size: ')
82 | # print(base_1.size())
83 |
84 | outputs = self.backbone_2(outputs)
85 | # print('base_2 size: ')
86 | # print(base_2.size())
87 |
88 | outputs = self.backbone_3(outputs)
89 | # print('base_3 size: ')
90 | # print(base_3.size())
91 |
92 | outputs = self.backbone_4(outputs)
93 | # print('base_4 size: ')
94 | # print(base_4.size())
95 |
96 | return outputs
97 |
98 | ### ============= Decoder Backbone ============= ###
99 | class n_segnet_decoder(nn.Module):
100 | def __init__(self, n_classes=21, in_channels=512):
101 | #def __init__(self, n_classes=21, in_channels=512,agent_num=5):
102 | super(n_segnet_decoder, self).__init__()
103 | self.in_channels = in_channels
104 | # Decoder
105 | self.deconv1= deconv2DBatchNormRelu(self.in_channels, 512, k_size=3, stride=2, padding=1,output_padding=1)
106 | self.deconv2= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1)
107 | self.deconv3= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1)
108 |
109 | # up 4
110 | self.deconv4= deconv2DBatchNormRelu(512 , 512, k_size=3, stride=2, padding=1,output_padding=1)
111 | self.deconv5= conv2DBatchNormRelu(512 , 512, k_size=3, stride=1, padding=1)
112 | self.deconv6= conv2DBatchNormRelu(512 , 256, k_size=3, stride=1, padding=1)
113 |
114 | # up 3
115 | self.deconv7= deconv2DBatchNormRelu(256 , 256, k_size=3, stride=2, padding=1,output_padding=1)
116 | self.deconv8= conv2DBatchNormRelu(256 , 128, k_size=3, stride=1, padding=1)
117 |
118 | # up 2
119 | self.deconv9= deconv2DBatchNormRelu(128 , 128, k_size=3, stride=2, padding=1,output_padding=1)
120 | self.deconv10= conv2DBatchNormRelu(128 , 64, k_size=3, stride=1, padding=1)
121 |
122 | # up 1
123 | self.deconv11= deconv2DBatchNormRelu(64 , 64, k_size=3, stride=2, padding=1,output_padding=1)
124 | self.deconv12= conv2DBatchNormRelu(64 , n_classes, k_size=3, stride=1, padding=1)
125 |
126 | def forward(self, inputs):
127 | outputs = self.deconv1(inputs)
128 | outputs = self.deconv2(outputs)
129 | outputs = self.deconv3(outputs)
130 |
131 | outputs = self.deconv4(outputs)
132 | outputs = self.deconv5(outputs)
133 | outputs = self.deconv6(outputs)
134 | outputs = self.deconv7(outputs)
135 | outputs = self.deconv8(outputs)
136 | outputs = self.deconv9(outputs)
137 | outputs = self.deconv10(outputs)
138 | outputs = self.deconv11(outputs)
139 | outputs = self.deconv12(outputs)
140 | return outputs
141 |
142 |
143 | class simple_decoder(nn.Module):
144 | def __init__(self, n_classes=21, in_channels=512):
145 | super(simple_decoder, self).__init__()
146 | self.in_channels = in_channels
147 |
148 | feat_chn = 256
149 |
150 | self.pred = nn.Sequential(
151 | nn.Conv2d(self.in_channels, feat_chn, kernel_size=3, padding=1),
152 | nn.ReLU(inplace=True),
153 | nn.Conv2d(feat_chn, n_classes, kernel_size=3, padding=1)
154 | )
155 |
156 | def forward(self, inputs):
157 | pred = self.pred(inputs)
158 | # print('pred size: ')
159 | # print(pred.size())
160 | pred = nn.functional.interpolate(pred, size=torch.Size([inputs.size()[2]*32,inputs.size()[3]*32]), mode='bilinear', align_corners=False)
161 | # print('pred size: ')
162 |         # print(pred.size())
163 |
164 | return pred
165 |
166 |
167 | class FCN_decoder(nn.Module):
168 | def __init__(self, n_classes=21, in_channels=512):
169 | super(FCN_decoder, self).__init__()
170 | feat_chn = 256
171 |
172 | self.pred = nn.Sequential(
173 | nn.Conv2d(in_channels, feat_chn, kernel_size=3, padding=1),
174 | nn.ReLU(inplace=True),
175 | nn.Conv2d(feat_chn, n_classes, kernel_size=3, padding=1)
176 | )
177 |
178 | def forward(self, inputs):
179 |         pred = self.pred(inputs)
180 |         # print('pred size: ')
181 |         # print(pred.size())
182 |         pred = nn.functional.interpolate(pred, size=inputs.size()[-2:], mode='bilinear', align_corners=False)
183 |         # print('pred size: ')
184 |         # print(pred.size())
185 |
186 | return pred
187 |
188 |
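
if __name__ == '__main__':
    # Shape sanity-check sketch (added note): the five stride-2 stages in
    # n_segnet_encoder reduce the spatial resolution by 32x, and the five stride-2
    # transposed convolutions in n_segnet_decoder bring it back. The 512x512 input
    # and 11-class setting are illustrative.
    enc = n_segnet_encoder(n_classes=11, in_channels=3)
    dec = n_segnet_decoder(n_classes=11, in_channels=512)
    x = torch.rand(1, 3, 512, 512)
    feat = enc(x)      # -> (1, 512, 16, 16)
    out = dec(feat)    # -> (1, 11, 512, 512)
    print(feat.shape, out.shape)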
--------------------------------------------------------------------------------
/ptsemseg/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 |
6 | from torch.autograd import Variable
7 |
8 |
9 | class conv2DBatchNorm(nn.Module):
10 | def __init__(
11 | self,
12 | in_channels,
13 | n_filters,
14 | k_size,
15 | stride,
16 | padding,
17 | bias=True,
18 | dilation=1,
19 | is_batchnorm=True,
20 | ):
21 | super(conv2DBatchNorm, self).__init__()
22 |
23 | conv_mod = nn.Conv2d(
24 | int(in_channels),
25 | int(n_filters),
26 | kernel_size=k_size,
27 | padding=padding,
28 | stride=stride,
29 | bias=bias,
30 | dilation=dilation,
31 | )
32 |
33 | if is_batchnorm:
34 | self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)))
35 | else:
36 | self.cb_unit = nn.Sequential(conv_mod)
37 |
38 | def forward(self, inputs):
39 | outputs = self.cb_unit(inputs)
40 | return outputs
41 |
42 |
43 | class conv2DGroupNorm(nn.Module):
44 | def __init__(
45 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16
46 | ):
47 | super(conv2DGroupNorm, self).__init__()
48 |
49 | conv_mod = nn.Conv2d(
50 | int(in_channels),
51 | int(n_filters),
52 | kernel_size=k_size,
53 | padding=padding,
54 | stride=stride,
55 | bias=bias,
56 | dilation=dilation,
57 | )
58 |
59 | self.cg_unit = nn.Sequential(conv_mod, nn.GroupNorm(n_groups, int(n_filters)))
60 |
61 | def forward(self, inputs):
62 | outputs = self.cg_unit(inputs)
63 | return outputs
64 |
65 |
66 | class deconv2DBatchNorm(nn.Module):
67 | def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
68 | super(deconv2DBatchNorm, self).__init__()
69 |
70 | self.dcb_unit = nn.Sequential(
71 | nn.ConvTranspose2d(
72 | int(in_channels),
73 | int(n_filters),
74 | kernel_size=k_size,
75 | padding=padding,
76 | stride=stride,
77 | bias=bias,
78 | ),
79 | nn.BatchNorm2d(int(n_filters)),
80 | )
81 |
82 | def forward(self, inputs):
83 | outputs = self.dcb_unit(inputs)
84 | return outputs
85 |
86 |
87 | class conv2DBatchNormRelu(nn.Module):
88 | def __init__(
89 | self,
90 | in_channels,
91 | n_filters,
92 | k_size,
93 | stride,
94 | padding,
95 | bias=True,
96 | dilation=1,
97 | is_batchnorm=True,
98 | ):
99 | super(conv2DBatchNormRelu, self).__init__()
100 |
101 | conv_mod = nn.Conv2d(
102 | int(in_channels),
103 | int(n_filters),
104 | kernel_size=k_size,
105 | padding=padding,
106 | stride=stride,
107 | bias=bias,
108 | dilation=dilation,
109 | )
110 |
111 | if is_batchnorm:
112 | self.cbr_unit = nn.Sequential(
113 | conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True)
114 | )
115 | else:
116 | self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))
117 |
118 | def forward(self, inputs):
119 | outputs = self.cbr_unit(inputs)
120 | return outputs
121 |
122 |
123 | class conv2DGroupNormRelu(nn.Module):
124 | def __init__(
125 | self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, n_groups=16
126 | ):
127 | super(conv2DGroupNormRelu, self).__init__()
128 |
129 | conv_mod = nn.Conv2d(
130 | int(in_channels),
131 | int(n_filters),
132 | kernel_size=k_size,
133 | padding=padding,
134 | stride=stride,
135 | bias=bias,
136 | dilation=dilation,
137 | )
138 |
139 | self.cgr_unit = nn.Sequential(
140 | conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True)
141 | )
142 |
143 | def forward(self, inputs):
144 | outputs = self.cgr_unit(inputs)
145 | return outputs
146 |
147 |
148 | class deconv2DBatchNormRelu(nn.Module):
149 | def __init__(self, in_channels, n_filters, k_size, stride, padding, output_padding,bias=True):
150 | super(deconv2DBatchNormRelu, self).__init__()
151 |
152 | self.dcbr_unit = nn.Sequential(
153 | nn.ConvTranspose2d(
154 | int(in_channels),
155 | int(n_filters),
156 | kernel_size=k_size,
157 | padding=padding,
158 | output_padding=output_padding,
159 | stride=stride,
160 | bias=bias,
161 | ),
162 | nn.BatchNorm2d(int(n_filters)),
163 | nn.ReLU(inplace=True),
164 | )
165 |
166 | def forward(self, inputs):
167 | outputs = self.dcbr_unit(inputs)
168 | return outputs
169 |
170 |
171 | class unetConv2(nn.Module):
172 | def __init__(self, in_size, out_size, is_batchnorm):
173 | super(unetConv2, self).__init__()
174 |
175 | if is_batchnorm:
176 | self.conv1 = nn.Sequential(
177 | nn.Conv2d(in_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU()
178 | )
179 | self.conv2 = nn.Sequential(
180 | nn.Conv2d(out_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU()
181 | )
182 | else:
183 | self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), nn.ReLU())
184 | self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), nn.ReLU())
185 |
186 | def forward(self, inputs):
187 | outputs = self.conv1(inputs)
188 | outputs = self.conv2(outputs)
189 | return outputs
190 |
191 |
192 | class unetUp(nn.Module):
193 | def __init__(self, in_size, out_size, is_deconv):
194 | super(unetUp, self).__init__()
195 | self.conv = unetConv2(in_size, out_size, False)
196 | if is_deconv:
197 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
198 | else:
199 | self.up = nn.UpsamplingBilinear2d(scale_factor=2)
200 |
201 | def forward(self, inputs1, inputs2):
202 | outputs2 = self.up(inputs2)
203 | offset = outputs2.size()[2] - inputs1.size()[2]
204 | padding = 2 * [offset // 2, offset // 2]
205 | outputs1 = F.pad(inputs1, padding)
206 | return self.conv(torch.cat([outputs1, outputs2], 1))
207 |
208 |
209 | class segnetDown2(nn.Module):
210 | def __init__(self, in_size, out_size):
211 | super(segnetDown2, self).__init__()
212 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, k_size=3, stride=1, padding=1)
213 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, k_size=3, stride=1, padding=1)
214 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
215 |
216 | def forward(self, inputs):
217 | outputs = self.conv1(inputs)
218 | outputs = self.conv2(outputs)
219 | unpooled_shape = outputs.size()
220 | outputs, indices = self.maxpool_with_argmax(outputs)
221 | return outputs, indices, unpooled_shape
222 |
223 |
224 | class segnetDown3(nn.Module):
225 | def __init__(self, in_size, out_size):
226 | super(segnetDown3, self).__init__()
227 | self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
228 | self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
229 | self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
230 | self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)
231 |
232 | def forward(self, inputs):
233 | outputs = self.conv1(inputs)
234 | outputs = self.conv2(outputs)
235 | outputs = self.conv3(outputs)
236 | unpooled_shape = outputs.size()
237 | outputs, indices = self.maxpool_with_argmax(outputs)
238 | return outputs, indices, unpooled_shape
239 |
240 |
241 | class segnetUp2(nn.Module):
242 | def __init__(self, in_size, out_size):
243 | super(segnetUp2, self).__init__()
244 | self.unpool = nn.MaxUnpool2d(2, 2)
245 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
246 | self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
247 |
248 | def forward(self, inputs, indices, output_shape):
249 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
250 | outputs = self.conv1(outputs)
251 | outputs = self.conv2(outputs)
252 | return outputs
253 |
254 |
255 | class segnetUp3(nn.Module):
256 | def __init__(self, in_size, out_size):
257 | super(segnetUp3, self).__init__()
258 | self.unpool = nn.MaxUnpool2d(2, 2)
259 | self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
260 | self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
261 | self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
262 |
263 | def forward(self, inputs, indices, output_shape):
264 | outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
265 | outputs = self.conv1(outputs)
266 | outputs = self.conv2(outputs)
267 | outputs = self.conv3(outputs)
268 | return outputs
269 |
270 |
271 |
272 |
273 |
274 |
275 | class residualBlock(nn.Module):
276 | expansion = 1
277 |
278 | def __init__(self, in_channels, n_filters, stride=1, downsample=None):
279 | super(residualBlock, self).__init__()
280 |
281 | self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False)
282 | self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
283 | self.downsample = downsample
284 | self.stride = stride
285 | self.relu = nn.ReLU(inplace=True)
286 |
287 | def forward(self, x):
288 | residual = x
289 |
290 | out = self.convbnrelu1(x)
291 | out = self.convbn2(out)
292 |
293 | if self.downsample is not None:
294 | residual = self.downsample(x)
295 |
296 | out += residual
297 | out = self.relu(out)
298 | return out
299 |
300 |
301 | class residualBottleneck(nn.Module):
302 | expansion = 4
303 |
304 | def __init__(self, in_channels, n_filters, stride=1, downsample=None):
305 | super(residualBottleneck, self).__init__()
306 |         self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False)
307 |         self.convbn2 = conv2DBatchNorm(
308 |             n_filters, n_filters, k_size=3, padding=1, stride=stride, bias=False
309 |         )
310 |         self.convbn3 = conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False)
311 | self.relu = nn.ReLU(inplace=True)
312 | self.downsample = downsample
313 | self.stride = stride
314 |
315 | def forward(self, x):
316 | residual = x
317 |
318 | out = self.convbn1(x)
319 | out = self.convbn2(out)
320 | out = self.convbn3(out)
321 |
322 | if self.downsample is not None:
323 | residual = self.downsample(x)
324 |
325 | out += residual
326 | out = self.relu(out)
327 |
328 | return out
329 |
330 |
331 | class linknetUp(nn.Module):
332 | def __init__(self, in_channels, n_filters):
333 | super(linknetUp, self).__init__()
334 |
335 |         # B, 2C, H, W -> B, C/2, H, W
336 |         self.convbnrelu1 = conv2DBatchNormRelu(
337 |             in_channels, n_filters // 2, k_size=1, stride=1, padding=0
338 |         )
339 |
340 |         # B, C/2, H, W -> B, C/2, 2H, 2W
341 |         self.deconvbnrelu2 = deconv2DBatchNormRelu(
342 |             n_filters // 2, n_filters // 2, k_size=3, stride=2, padding=1, output_padding=1
343 |         )
344 |
345 |         # B, C/2, 2H, 2W -> B, C, 2H, 2W
346 |         self.convbnrelu3 = conv2DBatchNormRelu(
347 |             n_filters // 2, n_filters, k_size=1, stride=1, padding=0
348 |         )
349 |
350 | def forward(self, x):
351 | x = self.convbnrelu1(x)
352 | x = self.deconvbnrelu2(x)
353 | x = self.convbnrelu3(x)
354 | return x
355 |
356 |
357 | class FRRU(nn.Module):
358 | """
359 | Full Resolution Residual Unit for FRRN
360 | """
361 |
362 | def __init__(self, prev_channels, out_channels, scale, group_norm=False, n_groups=None):
363 | super(FRRU, self).__init__()
364 | self.scale = scale
365 | self.prev_channels = prev_channels
366 | self.out_channels = out_channels
367 | self.group_norm = group_norm
368 | self.n_groups = n_groups
369 |
370 | if self.group_norm:
371 | conv_unit = conv2DGroupNormRelu
372 | self.conv1 = conv_unit(
373 | prev_channels + 32,
374 | out_channels,
375 | k_size=3,
376 | stride=1,
377 | padding=1,
378 | bias=False,
379 | n_groups=self.n_groups,
380 | )
381 | self.conv2 = conv_unit(
382 | out_channels,
383 | out_channels,
384 | k_size=3,
385 | stride=1,
386 | padding=1,
387 | bias=False,
388 | n_groups=self.n_groups,
389 | )
390 |
391 | else:
392 | conv_unit = conv2DBatchNormRelu
393 | self.conv1 = conv_unit(
394 | prev_channels + 32, out_channels, k_size=3, stride=1, padding=1, bias=False
395 | )
396 | self.conv2 = conv_unit(
397 | out_channels, out_channels, k_size=3, stride=1, padding=1, bias=False
398 | )
399 |
400 | self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)
401 |
402 | def forward(self, y, z):
403 | x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1)
404 | y_prime = self.conv1(x)
405 | y_prime = self.conv2(y_prime)
406 |
407 | x = self.conv_res(y_prime)
408 | upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
409 | x = F.upsample(x, size=upsample_size, mode="nearest")
410 | z_prime = z + x
411 |
412 | return y_prime, z_prime
413 |
414 |
415 | class RU(nn.Module):
416 | """
417 | Residual Unit for FRRN
418 | """
419 |
420 | def __init__(self, channels, kernel_size=3, strides=1, group_norm=False, n_groups=None):
421 | super(RU, self).__init__()
422 | self.group_norm = group_norm
423 | self.n_groups = n_groups
424 |
425 | if self.group_norm:
426 | self.conv1 = conv2DGroupNormRelu(
427 | channels,
428 | channels,
429 | k_size=kernel_size,
430 | stride=strides,
431 | padding=1,
432 | bias=False,
433 | n_groups=self.n_groups,
434 | )
435 | self.conv2 = conv2DGroupNorm(
436 | channels,
437 | channels,
438 | k_size=kernel_size,
439 | stride=strides,
440 | padding=1,
441 | bias=False,
442 | n_groups=self.n_groups,
443 | )
444 |
445 | else:
446 | self.conv1 = conv2DBatchNormRelu(
447 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
448 | )
449 | self.conv2 = conv2DBatchNorm(
450 | channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
451 | )
452 |
453 | def forward(self, x):
454 | incoming = x
455 | x = self.conv1(x)
456 | x = self.conv2(x)
457 | return x + incoming
458 |
459 |
460 | class residualConvUnit(nn.Module):
461 | def __init__(self, channels, kernel_size=3):
462 | super(residualConvUnit, self).__init__()
463 |
464 |         self.residual_conv_unit = nn.Sequential(
465 |             nn.ReLU(inplace=True),
466 |             nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),  # keep H, W for the residual sum
467 |             nn.ReLU(inplace=True),
468 |             nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
469 |         )
470 |
471 | def forward(self, x):
472 | input = x
473 | x = self.residual_conv_unit(x)
474 | return x + input
475 |
476 |
477 | class multiResolutionFusion(nn.Module):
478 | def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
479 | super(multiResolutionFusion, self).__init__()
480 |
481 | self.up_scale_high = up_scale_high
482 | self.up_scale_low = up_scale_low
483 |
484 | self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3)
485 |
486 | if low_shape is not None:
487 | self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3)
488 |
489 | def forward(self, x_high, x_low):
490 | high_upsampled = F.upsample(
491 | self.conv_high(x_high), scale_factor=self.up_scale_high, mode="bilinear"
492 | )
493 |
494 | if x_low is None:
495 | return high_upsampled
496 |
497 | low_upsampled = F.upsample(
498 | self.conv_low(x_low), scale_factor=self.up_scale_low, mode="bilinear"
499 | )
500 |
501 | return low_upsampled + high_upsampled
502 |
503 |
504 | class chainedResidualPooling(nn.Module):
505 | def __init__(self, channels, input_shape):
506 | super(chainedResidualPooling, self).__init__()
507 |
508 |         self.chained_residual_pooling = nn.Sequential(
509 |             nn.ReLU(inplace=True),
510 |             nn.MaxPool2d(5, 1, 2),
511 |             nn.Conv2d(input_shape[1], channels, kernel_size=3, padding=1),  # keep H, W for the residual sum
512 |         )
513 |
514 | def forward(self, x):
515 | input = x
516 | x = self.chained_residual_pooling(x)
517 | return x + input
518 |
519 |
520 | class pyramidPooling(nn.Module):
521 | def __init__(
522 | self, in_channels, pool_sizes, model_name="pspnet", fusion_mode="cat", is_batchnorm=True
523 | ):
524 | super(pyramidPooling, self).__init__()
525 |
526 | bias = not is_batchnorm
527 |
528 | self.paths = []
529 | for i in range(len(pool_sizes)):
530 | self.paths.append(
531 | conv2DBatchNormRelu(
532 | in_channels,
533 | int(in_channels / len(pool_sizes)),
534 | 1,
535 | 1,
536 | 0,
537 | bias=bias,
538 | is_batchnorm=is_batchnorm,
539 | )
540 | )
541 |
542 | self.path_module_list = nn.ModuleList(self.paths)
543 | self.pool_sizes = pool_sizes
544 | self.model_name = model_name
545 | self.fusion_mode = fusion_mode
546 |
547 | def forward(self, x):
548 | h, w = x.shape[2:]
549 |
550 | if self.training or self.model_name != "icnet": # general settings or pspnet
551 | k_sizes = []
552 | strides = []
553 | for pool_size in self.pool_sizes:
554 | k_sizes.append((int(h / pool_size), int(w / pool_size)))
555 | strides.append((int(h / pool_size), int(w / pool_size)))
556 | else: # eval mode and icnet: pre-trained for 1025 x 2049
557 | k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
558 | strides = [(5, 10), (10, 20), (16, 32), (33, 65)]
559 |
560 | if self.fusion_mode == "cat": # pspnet: concat (including x)
561 | output_slices = [x]
562 |
563 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
564 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
565 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
566 | if self.model_name != "icnet":
567 | out = module(out)
568 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
569 | output_slices.append(out)
570 |
571 | return torch.cat(output_slices, dim=1)
572 | else: # icnet: element-wise sum (including x)
573 | pp_sum = x
574 |
575 | for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
576 | out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
577 | # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
578 | if self.model_name != "icnet":
579 | out = module(out)
580 | out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
581 | pp_sum = pp_sum + out
582 |
583 | return pp_sum
584 |
585 |
586 | class bottleNeckPSP(nn.Module):
587 | def __init__(
588 | self, in_channels, mid_channels, out_channels, stride, dilation=1, is_batchnorm=True
589 | ):
590 | super(bottleNeckPSP, self).__init__()
591 |
592 | bias = not is_batchnorm
593 |
594 | self.cbr1 = conv2DBatchNormRelu(
595 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
596 | )
597 | if dilation > 1:
598 | self.cbr2 = conv2DBatchNormRelu(
599 | mid_channels,
600 | mid_channels,
601 | 3,
602 | stride=stride,
603 | padding=dilation,
604 | bias=bias,
605 | dilation=dilation,
606 | is_batchnorm=is_batchnorm,
607 | )
608 | else:
609 | self.cbr2 = conv2DBatchNormRelu(
610 | mid_channels,
611 | mid_channels,
612 | 3,
613 | stride=stride,
614 | padding=1,
615 | bias=bias,
616 | dilation=1,
617 | is_batchnorm=is_batchnorm,
618 | )
619 | self.cb3 = conv2DBatchNorm(
620 | mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
621 | )
622 | self.cb4 = conv2DBatchNorm(
623 | in_channels,
624 | out_channels,
625 | 1,
626 | stride=stride,
627 | padding=0,
628 | bias=bias,
629 | is_batchnorm=is_batchnorm,
630 | )
631 |
632 | def forward(self, x):
633 | conv = self.cb3(self.cbr2(self.cbr1(x)))
634 | residual = self.cb4(x)
635 | return F.relu(conv + residual, inplace=True)
636 |
637 |
638 | class bottleNeckIdentifyPSP(nn.Module):
639 | def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True):
640 | super(bottleNeckIdentifyPSP, self).__init__()
641 |
642 | bias = not is_batchnorm
643 |
644 | self.cbr1 = conv2DBatchNormRelu(
645 | in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
646 | )
647 | if dilation > 1:
648 | self.cbr2 = conv2DBatchNormRelu(
649 | mid_channels,
650 | mid_channels,
651 | 3,
652 | stride=1,
653 | padding=dilation,
654 | bias=bias,
655 | dilation=dilation,
656 | is_batchnorm=is_batchnorm,
657 | )
658 | else:
659 | self.cbr2 = conv2DBatchNormRelu(
660 | mid_channels,
661 | mid_channels,
662 | 3,
663 | stride=1,
664 | padding=1,
665 | bias=bias,
666 | dilation=1,
667 | is_batchnorm=is_batchnorm,
668 | )
669 | self.cb3 = conv2DBatchNorm(
670 | mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
671 | )
672 |
673 | def forward(self, x):
674 | residual = x
675 | x = self.cb3(self.cbr2(self.cbr1(x)))
676 | return F.relu(x + residual, inplace=True)
677 |
678 |
679 | class residualBlockPSP(nn.Module):
680 | def __init__(
681 | self,
682 | n_blocks,
683 | in_channels,
684 | mid_channels,
685 | out_channels,
686 | stride,
687 | dilation=1,
688 | include_range="all",
689 | is_batchnorm=True,
690 | ):
691 | super(residualBlockPSP, self).__init__()
692 |
693 | if dilation > 1:
694 | stride = 1
695 |
696 | # residualBlockPSP = convBlockPSP + identityBlockPSPs
697 | layers = []
698 | if include_range in ["all", "conv"]:
699 | layers.append(
700 | bottleNeckPSP(
701 | in_channels,
702 | mid_channels,
703 | out_channels,
704 | stride,
705 | dilation,
706 | is_batchnorm=is_batchnorm,
707 | )
708 | )
709 | if include_range in ["all", "identity"]:
710 | for i in range(n_blocks - 1):
711 | layers.append(
712 | bottleNeckIdentifyPSP(
713 | out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm
714 | )
715 | )
716 |
717 | self.layers = nn.Sequential(*layers)
718 |
719 | def forward(self, x):
720 | return self.layers(x)
721 |
722 |
723 | class cascadeFeatureFusion(nn.Module):
724 | def __init__(
725 | self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True
726 | ):
727 | super(cascadeFeatureFusion, self).__init__()
728 |
729 | bias = not is_batchnorm
730 |
731 | self.low_dilated_conv_bn = conv2DBatchNorm(
732 | low_in_channels,
733 | out_channels,
734 | 3,
735 | stride=1,
736 | padding=2,
737 | bias=bias,
738 | dilation=2,
739 | is_batchnorm=is_batchnorm,
740 | )
741 | self.low_classifier_conv = nn.Conv2d(
742 | int(low_in_channels),
743 | int(n_classes),
744 | kernel_size=1,
745 | padding=0,
746 | stride=1,
747 | bias=True,
748 | dilation=1,
749 | ) # Train only
750 | self.high_proj_conv_bn = conv2DBatchNorm(
751 | high_in_channels,
752 | out_channels,
753 | 1,
754 | stride=1,
755 | padding=0,
756 | bias=bias,
757 | is_batchnorm=is_batchnorm,
758 | )
759 |
760 | def forward(self, x_low, x_high):
761 | x_low_upsampled = F.interpolate(
762 | x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
763 | )
764 |
765 | low_cls = self.low_classifier_conv(x_low_upsampled)
766 |
767 | low_fm = self.low_dilated_conv_bn(x_low_upsampled)
768 | high_fm = self.high_proj_conv_bn(x_high)
769 | high_fused_fm = F.relu(low_fm + high_fm, inplace=True)
770 |
771 | return high_fused_fm, low_cls
772 |
773 |
774 | def get_interp_size(input, s_factor=1, z_factor=1): # for caffe
775 | ori_h, ori_w = input.shape[2:]
776 |
777 | # shrink (s_factor >= 1)
778 | ori_h = (ori_h - 1) / s_factor + 1
779 | ori_w = (ori_w - 1) / s_factor + 1
780 |
781 | # zoom (z_factor >= 1)
782 | ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
783 | ori_w = ori_w + (ori_w - 1) * (z_factor - 1)
784 |
785 | resize_shape = (int(ori_h), int(ori_w))
786 | return resize_shape
787 |
788 |
789 | def interp(input, output_size, mode="bilinear"):
790 | n, c, ih, iw = input.shape
791 | oh, ow = output_size
792 |
793 | # normalize to [-1, 1]
794 | h = torch.arange(0, oh, dtype=torch.float, device=input.device) / (oh - 1) * 2 - 1
795 | w = torch.arange(0, ow, dtype=torch.float, device=input.device) / (ow - 1) * 2 - 1
796 |
797 | grid = torch.zeros(oh, ow, 2, dtype=torch.float, device=input.device)
798 | grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
799 | grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
800 | grid = grid.unsqueeze(0).repeat(n, 1, 1, 1) # grid.shape: [n, oh, ow, 2]
801 | grid = Variable(grid)
802 | if input.is_cuda:
803 | grid = grid.cuda()
804 |
805 | return F.grid_sample(input, grid, mode=mode)
806 |
807 |
808 | def get_upsampling_weight(in_channels, out_channels, kernel_size):
809 | """Make a 2D bilinear kernel suitable for upsampling"""
810 | factor = (kernel_size + 1) // 2
811 | if kernel_size % 2 == 1:
812 | center = factor - 1
813 | else:
814 | center = factor - 0.5
815 | og = np.ogrid[:kernel_size, :kernel_size]
816 | filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
817 | weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
818 | weight[range(in_channels), range(out_channels), :, :] = filt
819 | return torch.from_numpy(weight).float()
820 |
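# Usage sketch (added note): the classic FCN-style initialisation of a transposed
# convolution with a fixed bilinear kernel; the layer sizes are illustrative.
#
#   up = nn.ConvTranspose2d(21, 21, kernel_size=4, stride=2, bias=False)
#   up.weight.data.copy_(get_upsampling_weight(21, 21, 4))
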
821 | class Sparsemax(nn.Module):
822 | """Sparsemax function."""
823 |
824 | def __init__(self, dim=None):
825 | """Initialize sparsemax activation
826 |
827 | Args:
828 | dim (int, optional): The dimension over which to apply the sparsemax function.
829 | """
830 | super(Sparsemax, self).__init__()
831 |
832 | self.dim = -1 if dim is None else dim
833 |
834 | def forward(self, input):
835 | """Forward function.
836 | Args:
837 | input (torch.Tensor): Input tensor. First dimension should be the batch size
838 | Returns:
839 | torch.Tensor: [batch_size x number_of_logits] Output tensor
840 | """
841 | # Sparsemax currently only handles 2-dim tensors,
842 | # so we reshape and reshape back after sparsemax
843 | original_size = input.size()
844 | input = input.view(-1, input.size(self.dim))
845 |
846 | dim = 1
847 | number_of_logits = input.size(dim)
848 |
849 | # Translate input by max for numerical stability
850 | input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
851 |
852 | # Sort input in descending order.
853 | # (NOTE: Can be replaced with linear time selection method described here:
854 | # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
855 | zs = torch.sort(input=input, dim=dim, descending=True)[0]
856 |         rng = torch.arange(start=1, end=number_of_logits + 1, dtype=input.dtype, device=input.device).view(1, -1)
857 |         rng = rng.expand_as(zs)
858 |
859 |         # Determine sparsity of projection
860 |         bound = 1 + rng * zs
861 |         cumulative_sum_zs = torch.cumsum(zs, dim)
862 |         is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
863 |         k = torch.max(is_gt * rng, dim, keepdim=True)[0]
864 |
865 | # Compute threshold function
866 | zs_sparse = is_gt * zs
867 |
868 | # Compute taus
869 | taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
870 | taus = taus.expand_as(input)
871 |
872 | # Sparsemax
873 | self.output = torch.max(torch.zeros_like(input), input - taus)
874 |
875 | output = self.output.view(original_size)
876 |
877 | return output
878 |
879 | def backward(self, grad_output):
880 | """Backward function."""
881 | dim = 1
882 |
883 | nonzeros = torch.ne(self.output, 0)
884 | sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
885 | self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
886 |
887 | return self.grad_input
888 |
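
if __name__ == '__main__':
    # Illustration sketch (added note): unlike softmax, sparsemax maps weak logits
    # to exact zeros, which is what the `sparse` option of the attention modules
    # relies on.
    logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
    print(F.softmax(logits, dim=1))   # every entry strictly positive
    print(Sparsemax(dim=1)(logits))   # -> tensor([[1., 0., 0., 0.]]) for this input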
--------------------------------------------------------------------------------
/ptsemseg/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from torch.optim import SGD, Adam, ASGD, Adamax, Adadelta, Adagrad, RMSprop
4 |
5 | logger = logging.getLogger("ptsemseg")
6 |
7 | key2opt = {
8 | "sgd": SGD,
9 | "adam": Adam,
10 | "asgd": ASGD,
11 | "adamax": Adamax,
12 | "adadelta": Adadelta,
13 | "adagrad": Adagrad,
14 | "rmsprop": RMSprop,
15 | }
16 |
17 |
18 | def get_optimizer(cfg):
19 | if cfg["training"]["optimizer"] is None:
20 | logger.info("Using SGD optimizer")
21 | return SGD
22 |
23 | else:
24 | opt_name = cfg["training"]["optimizer"]["name"]
25 | if opt_name not in key2opt:
26 | raise NotImplementedError("Optimizer {} not implemented".format(opt_name))
27 |
28 | logger.info("Using {} optimizer".format(opt_name))
29 | return key2opt[opt_name]
30 |
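
if __name__ == '__main__':
    # Usage sketch (added note): get_optimizer returns the optimizer *class*; the
    # caller instantiates it with the model parameters. The config keys mirror the
    # lookup above; the extra "lr" entry is illustrative.
    cfg = {"training": {"optimizer": {"name": "adam", "lr": 1e-4}}}
    opt_cls = get_optimizer(cfg)
    # optimizer = opt_cls(model.parameters(), lr=cfg["training"]["optimizer"]["lr"])
    print(opt_cls)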
--------------------------------------------------------------------------------
/ptsemseg/probe.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def get_vectorize_grad(model_parameter):
4 | grad_vec = None
5 |     for para in model_parameter:
6 |         if para.requires_grad and para.grad is not None:
7 |             grad = para.grad.view(para.grad.numel())
8 |             if grad_vec is None:
9 |                 grad_vec = grad
10 |             else:
11 |                 grad_vec = torch.cat((grad_vec, grad), 0)
12 |     return grad_vec
13 |
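
if __name__ == '__main__':
    import torch.nn as nn
    # Minimal sketch (added note): flatten the gradients of a tiny linear model into
    # a single vector, e.g. for gradient-similarity probing.
    model = nn.Linear(4, 2)
    loss = model(torch.ones(1, 4)).sum()
    loss.backward()
    grad_vec = get_vectorize_grad(model.parameters())
    print(grad_vec.shape)   # torch.Size([10]) = 4*2 weights + 2 biases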
--------------------------------------------------------------------------------
/ptsemseg/process_img.py:
--------------------------------------------------------------------------------
1 | import os, logging, datetime
2 | import numpy as np, scipy.misc as m
3 | import cv2
4 | from collections import OrderedDict
5 | from torch.autograd import Variable
6 | def generate_noise(img,noisy_type=None):
7 | '''
8 | Input: img: RGB image, noisy_type: string of noisy type
9 | generate noisy image
10 | * image must be RGB
11 | '''
12 | image_batch = img.shape[0]
13 | img_ch = img.shape[1]
14 | img_row = img.shape[2]
15 | img_col = img.shape[3]
16 |
17 | if noisy_type == 'occlusion':
18 | #print('Noisy_type: Occlusion')
19 | img[:,:,int(img_row/5):(img_row),:] = 0
20 | elif noisy_type == 'random_noisy':
21 | noise = Variable(img.data.new(img.size()).normal_(0, 0.8))
22 | img = img + noise
23 | img_np = img.data.numpy()
24 | #m.imsave('noisy_image.png',img_np[0].transpose(1,2,0))
25 | elif noisy_type == 'grayscale':
26 | #print('Noisy_type: Grayscale')
27 | img = np.dot(img[...,:3], [0.299, 0.587, 0.114])
28 | elif noisy_type == 'low_resolution':
29 | #print('Noisy_type: Low resolution (but now is original image)')
30 | pass
31 | else:
32 | # print('Noisy_type: original image)')
33 | pass
34 |
35 | return img
36 |
37 | def recursive_glob(rootdir=".", suffix=""):
38 | """Performs recursive glob with given suffix and rootdir
39 | :param rootdir is the root directory
40 | :param suffix is the suffix to be searched
41 | """
42 | return [
43 | os.path.join(looproot, filename)
44 | for looproot, _, filenames in os.walk(rootdir)
45 | for filename in filenames
46 | if filename.endswith(suffix)
47 | ]
48 |
49 |
50 | def alpha_blend(input_image, segmentation_mask, alpha=0.5):
51 | """Alpha Blending utility to overlay RGB masks on RBG images
52 | :param input_image is a np.ndarray with 3 channels
53 | :param segmentation_mask is a np.ndarray with 3 channels
54 | :param alpha is a float value
55 | """
56 | blended = np.zeros(input_image.size, dtype=np.float32)
57 | blended = input_image * alpha + segmentation_mask * (1 - alpha)
58 | return blended
59 |
60 |
61 | def convert_state_dict(state_dict):
62 | """Converts a state dict saved from a dataParallel module to normal
63 | module state_dict inplace
64 | :param state_dict is the loaded DataParallel model_state
65 | """
66 | new_state_dict = OrderedDict()
67 | for k, v in state_dict.items():
68 | name = k[7:] # remove `module.`
69 | new_state_dict[name] = v
70 | return new_state_dict
71 |
72 |
73 | def get_logger(logdir):
74 | logger = logging.getLogger("ptsemseg")
75 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_")
76 | ts = ts.replace(":", "_").replace("-", "_")
77 | file_path = os.path.join(logdir, "run_{}.log".format(ts))
78 | hdlr = logging.FileHandler(file_path)
79 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
80 | hdlr.setFormatter(formatter)
81 | logger.addHandler(hdlr)
82 | logger.setLevel(logging.INFO)
83 | return logger
84 |
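
if __name__ == '__main__':
    import torch
    # Degradation sketch (added note): generate_noise operates on (batch, 3, H, W)
    # tensors; 'occlusion' zeroes the lower part of every view and 'random_noisy'
    # adds Gaussian noise. The 256x256 batch below is illustrative.
    imgs = torch.rand(2, 3, 256, 256)
    occluded = generate_noise(imgs.clone(), noisy_type='occlusion')
    noisy = generate_noise(imgs.clone(), noisy_type='random_noisy')
    print(occluded.shape, noisy.shape)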
--------------------------------------------------------------------------------
/ptsemseg/schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CosineAnnealingLR
4 |
5 | from ptsemseg.schedulers.schedulers import WarmUpLR, ConstantLR, PolynomialLR
6 |
7 | logger = logging.getLogger("ptsemseg")
8 |
9 | key2scheduler = {
10 | "constant_lr": ConstantLR,
11 | "poly_lr": PolynomialLR,
12 | "multi_step": MultiStepLR,
13 | "cosine_annealing": CosineAnnealingLR,
14 | "exp_lr": ExponentialLR,
15 | }
16 |
17 |
18 | def get_scheduler(optimizer, scheduler_dict):
19 | if scheduler_dict is None:
20 | logger.info("Using No LR Scheduling")
21 | return ConstantLR(optimizer)
22 |
23 | s_type = scheduler_dict["name"]
24 | scheduler_dict.pop("name")
25 |
26 |     logger.info("Using {} scheduler with {} params".format(s_type, scheduler_dict))
27 |
28 | warmup_dict = {}
29 | if "warmup_iters" in scheduler_dict:
30 | # This can be done in a more pythonic way...
31 | warmup_dict["warmup_iters"] = scheduler_dict.get("warmup_iters", 100)
32 | warmup_dict["mode"] = scheduler_dict.get("warmup_mode", "linear")
33 | warmup_dict["gamma"] = scheduler_dict.get("warmup_factor", 0.2)
34 |
35 | logger.info(
36 | "Using Warmup with {} iters {} gamma and {} mode".format(
37 | warmup_dict["warmup_iters"], warmup_dict["gamma"], warmup_dict["mode"]
38 | )
39 | )
40 |
41 | scheduler_dict.pop("warmup_iters", None)
42 | scheduler_dict.pop("warmup_mode", None)
43 | scheduler_dict.pop("warmup_factor", None)
44 |
45 | base_scheduler = key2scheduler[s_type](optimizer, **scheduler_dict)
46 | return WarmUpLR(optimizer, base_scheduler, **warmup_dict)
47 |
48 | return key2scheduler[s_type](optimizer, **scheduler_dict)
49 |
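
if __name__ == '__main__':
    # Usage sketch (added note): a polynomial decay schedule wrapped in a linear
    # warmup. The dict keys mirror the ones consumed above; the SGD setup is
    # illustrative.
    import torch
    from torch.optim import SGD
    optimizer = SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.01)
    sched_cfg = {"name": "poly_lr", "max_iter": 1000, "warmup_iters": 100,
                 "warmup_mode": "linear", "warmup_factor": 0.2}
    scheduler = get_scheduler(optimizer, sched_cfg)   # -> WarmUpLR wrapping PolynomialLR
    # optimizer.step(); scheduler.step()  # per-iteration update in the training loop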
--------------------------------------------------------------------------------
/ptsemseg/schedulers/schedulers.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 |
3 |
4 | class ConstantLR(_LRScheduler):
5 | def __init__(self, optimizer, last_epoch=-1):
6 | super(ConstantLR, self).__init__(optimizer, last_epoch)
7 |
8 | def get_lr(self):
9 | return [base_lr for base_lr in self.base_lrs]
10 |
11 |
12 | class PolynomialLR(_LRScheduler):
13 | def __init__(self, optimizer, max_iter, decay_iter=1, gamma=0.9, last_epoch=-1):
14 | self.decay_iter = decay_iter
15 | self.max_iter = max_iter
16 | self.gamma = gamma
17 | super(PolynomialLR, self).__init__(optimizer, last_epoch)
18 |
19 | def get_lr(self):
20 | if self.last_epoch % self.decay_iter or self.last_epoch % self.max_iter:
21 | return [base_lr for base_lr in self.base_lrs]
22 | else:
23 | factor = (1 - self.last_epoch / float(self.max_iter)) ** self.gamma
24 | return [base_lr * factor for base_lr in self.base_lrs]
25 |
26 |
27 | class WarmUpLR(_LRScheduler):
28 | def __init__(
29 | self, optimizer, scheduler, mode="linear", warmup_iters=100, gamma=0.2, last_epoch=-1
30 | ):
31 | self.mode = mode
32 | self.scheduler = scheduler
33 | self.warmup_iters = warmup_iters
34 | self.gamma = gamma
35 | super(WarmUpLR, self).__init__(optimizer, last_epoch)
36 |
37 | def get_lr(self):
38 | cold_lrs = self.scheduler.get_lr()
39 |
40 | if self.last_epoch < self.warmup_iters:
41 | if self.mode == "linear":
42 | alpha = self.last_epoch / float(self.warmup_iters)
43 | factor = self.gamma * (1 - alpha) + alpha
44 |
45 | elif self.mode == "constant":
46 | factor = self.gamma
47 | else:
48 | raise KeyError("WarmUp type {} not implemented".format(self.mode))
49 |
50 | return [factor * base_lr for base_lr in cold_lrs]
51 |
52 | return cold_lrs
53 |
--------------------------------------------------------------------------------
/ptsemseg/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc Utility functions
3 | """
4 | import os
5 | import logging
6 | import datetime
7 | import numpy as np
8 |
9 | from collections import OrderedDict
10 | from torch import nn
11 | import torch.nn.init as init
12 |
13 | def init_weights(m):
14 | if isinstance(m, nn.Conv1d):
15 | init.normal_(m.weight.data)
16 | if m.bias is not None:
17 | init.normal_(m.bias.data)
18 | elif isinstance(m, nn.Conv2d):
19 | init.xavier_normal_(m.weight.data)
20 | if m.bias is not None:
21 | init.normal_(m.bias.data)
22 | elif isinstance(m, nn.Conv3d):
23 | init.xavier_normal_(m.weight.data)
24 | if m.bias is not None:
25 | init.normal_(m.bias.data)
26 | elif isinstance(m, nn.ConvTranspose1d):
27 | init.normal_(m.weight.data)
28 | if m.bias is not None:
29 | init.normal_(m.bias.data)
30 | elif isinstance(m, nn.ConvTranspose2d):
31 | init.xavier_normal_(m.weight.data)
32 | if m.bias is not None:
33 | init.normal_(m.bias.data)
34 | elif isinstance(m, nn.ConvTranspose3d):
35 | init.xavier_normal_(m.weight.data)
36 | if m.bias is not None:
37 | init.normal_(m.bias.data)
38 | elif isinstance(m, nn.BatchNorm1d):
39 | init.normal_(m.weight.data, mean=1, std=0.02)
40 | init.constant_(m.bias.data, 0)
41 | elif isinstance(m, nn.BatchNorm2d):
42 | init.normal_(m.weight.data, mean=1, std=0.02)
43 | init.constant_(m.bias.data, 0)
44 | elif isinstance(m, nn.BatchNorm3d):
45 | init.normal_(m.weight.data, mean=1, std=0.02)
46 | init.constant_(m.bias.data, 0)
47 | elif isinstance(m, nn.Linear):
48 | init.xavier_normal_(m.weight.data)
49 | init.normal_(m.bias.data)
50 | elif isinstance(m, nn.LSTM):
51 | for param in m.parameters():
52 | if len(param.shape) >= 2:
53 | init.orthogonal_(param.data)
54 | else:
55 | init.normal_(param.data)
56 | elif isinstance(m, nn.LSTMCell):
57 | for param in m.parameters():
58 | if len(param.shape) >= 2:
59 | init.orthogonal_(param.data)
60 | else:
61 | init.normal_(param.data)
62 | elif isinstance(m, nn.GRU):
63 | for param in m.parameters():
64 | if len(param.shape) >= 2:
65 | init.orthogonal_(param.data)
66 | else:
67 | init.normal_(param.data)
68 | elif isinstance(m, nn.GRUCell):
69 | for param in m.parameters():
70 | if len(param.shape) >= 2:
71 | init.orthogonal_(param.data)
72 | else:
73 | init.normal_(param.data)
74 |
75 |
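# Usage sketch (added note): init_weights is written to be applied recursively via
# nn.Module.apply, e.g.
#
#   model = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16))
#   model.apply(init_weights)
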
76 | def recursive_glob(rootdir=".", suffix=""):
77 | """Performs recursive glob with given suffix and rootdir
78 | :param rootdir is the root directory
79 | :param suffix is the suffix to be searched
80 | """
81 | return [
82 | os.path.join(looproot, filename)
83 | for looproot, _, filenames in os.walk(rootdir)
84 | for filename in filenames
85 | if filename.endswith(suffix)
86 | ]
87 |
88 |
89 | def alpha_blend(input_image, segmentation_mask, alpha=0.5):
90 | """Alpha Blending utility to overlay RGB masks on RBG images
91 | :param input_image is a np.ndarray with 3 channels
92 | :param segmentation_mask is a np.ndarray with 3 channels
93 | :param alpha is a float value
94 | """
95 | blended = np.zeros(input_image.size, dtype=np.float32)
96 | blended = input_image * alpha + segmentation_mask * (1 - alpha)
97 | return blended
98 |
99 |
100 | def convert_state_dict(state_dict):
101 | """Converts a state dict saved from a dataParallel module to normal
102 | module state_dict inplace
103 | :param state_dict is the loaded DataParallel model_state
104 | """
105 | new_state_dict = OrderedDict()
106 | for k, v in state_dict.items():
107 | name = k[7:] # remove `module.`
108 | new_state_dict[name] = v
109 | return new_state_dict
110 |
111 |
112 | def get_logger(logdir):
113 | logger = logging.getLogger("ptsemseg")
114 | ts = str(datetime.datetime.now()).split(".")[0].replace(" ", "_")
115 | ts = ts.replace(":", "_").replace("-", "_")
116 | file_path = os.path.join(logdir, "run_{}.log".format(ts))
117 | hdlr = logging.FileHandler(file_path)
118 | formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
119 | hdlr.setFormatter(formatter)
120 | logger.addHandler(hdlr)
121 | logger.setLevel(logging.INFO)
122 | return logger
123 |
124 |
125 |
--------------------------------------------------------------------------------
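The utilities above are consumed by train.py and test.py. Below is a minimal usage sketch (illustrative only; the toy nn.Sequential network and the data/ path are placeholders): init_weights is meant to be applied recursively via nn.Module.apply, and convert_state_dict undoes the "module." prefix that nn.DataParallel adds to parameter names.

from torch import nn
from ptsemseg.utils import init_weights, convert_state_dict, recursive_glob

# Apply the layer-wise initialisation rules recursively over all submodules.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net.apply(init_weights)

# nn.DataParallel prefixes every parameter name with "module."; convert_state_dict
# strips that prefix so the weights load back into a plain (non-parallel) module.
parallel_state = nn.DataParallel(net).state_dict()
net.load_state_dict(convert_state_dict(parallel_state))

# Recursively collect files with a given suffix, e.g. every .png under ./data.
png_paths = recursive_glob(rootdir="data", suffix=".png")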
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.0.0
2 | numpy==1.22.0
3 | scipy==0.19.0
4 | torch==0.4.1
5 | torchvision==0.2.0
6 | tqdm==4.11.2
7 | pydensecrf
8 | protobuf
9 | tensorboardX
10 | pyyaml
11 | pretrainedmodels
12 | opencv-python
13 |
--------------------------------------------------------------------------------
/teaser/1359-teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GT-RIPL/MultiAgentPerception/4ef300547a7f7af2676a034f7cf742b009f57d99/teaser/1359-teaser.gif
--------------------------------------------------------------------------------
/teaser/pytorch-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GT-RIPL/MultiAgentPerception/4ef300547a7f7af2676a034f7cf742b009f57d99/teaser/pytorch-logo-dark.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import torch
3 | import argparse
4 | import timeit
5 | import numpy as np
6 | import cv2
7 | import os
8 |
9 | from torch.utils import data
10 | from ptsemseg.models import get_model
11 | from ptsemseg.loader import get_loader
12 | from ptsemseg.metrics import runningScore
13 | from ptsemseg.utils import convert_state_dict
14 | # from ptsemseg.visual import draw_bounding  # ptsemseg.visual is not included in this repository; draw_bounding is unused here
15 | from ptsemseg.trainer import *
16 |
17 | torch.backends.cudnn.benchmark = True
18 |
19 |
20 | if __name__ == "__main__":
21 | parser = argparse.ArgumentParser(description="config")
22 | parser.add_argument(
23 | "--config",
24 | nargs="?",
25 | type=str,
26 | default="configs/your_configs.yml",
27 | help="Configuration file to use",
28 | )
29 |
30 | parser.add_argument(
31 | "--gpu",
32 | nargs="?",
33 | type=str,
34 | default="0",
35 | help="Used GPUs",
36 | )
37 |
38 | parser.add_argument(
39 | "--run_time",
40 | nargs="?",
41 | type=int,
42 | default=1,
43 | help="run_time",
44 | )
45 |
46 |
47 | parser.add_argument(
48 | "--model_path",
49 | nargs="?",
50 | type=str,
51 | default="Single_Agent.pkl",
52 | help="Path to the saved model",
53 | )
54 |
55 |
56 | args = parser.parse_args()
57 |
58 | # Set the gpu
59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
60 | run_times = args.run_time
61 |
62 | with open(args.config) as fp:
63 | cfg = yaml.load(fp, Loader=yaml.SafeLoader)
64 |
65 | # ============= Testing =============
66 |
67 | # Setup device
68 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69 |
70 | # Setup Dataloader
71 | data_loader = get_loader(cfg["data"]["dataset"])
72 | data_path = cfg["data"]["path"]
73 |
74 | # Load communication label (note that some datasets do not provide this)
75 | if 'commun_label' in cfg["data"]:
76 | if_commun_label = cfg["data"]['commun_label']
77 | else:
78 | if_commun_label = 'None'
79 |
80 | # test data loader
81 | te_loader = data_loader(
82 | data_path,
83 | split=cfg["data"]['test_split'],
84 | is_transform=True,
85 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
86 | target_view=cfg["data"]["target_view"],
87 | commun_label=if_commun_label)
88 |
89 | testloader = data.DataLoader(te_loader, batch_size=cfg["training"]["batch_size"], num_workers=8)
90 |
91 |
92 | # Setup Model
93 | model = get_model(cfg, te_loader.n_classes).to(device)
94 |
95 | # set up the model
96 | if cfg['model']['arch'] == 'LearnWhen2Com': # Our when2com
97 | trainer = Trainer_LearnWhen2Com(cfg, None, None, model, None, None, None, None, None, device)
98 | elif cfg['model']['arch'] == 'LearnWho2Com': # Our who2com
99 | trainer = Trainer_LearnWho2Com(cfg, None, None, model, None, None, None, None, None, device)
100 | elif cfg['model']['arch'] == 'MIMOcom':
101 | trainer = Trainer_MIMOcom(cfg, None, None, model, None, None, None, None, None, device)
102 | elif cfg['model']['arch'] == 'MIMOcomMultiWarp':
103 | trainer = Trainer_MIMOcomMultiWarp(cfg, None, None, model, None, None, None, None, None, device)
104 | elif cfg['model']['arch'] == 'MIMOcomWho':
105 | trainer = Trainer_MIMOcomWho(cfg, None, None, model, None, None, None, None, None, device)
106 | elif cfg['model']['arch'] == 'Single_agent':
107 | trainer = Trainer_Single_agent(cfg, None, None, model, None, None, None, None, None, device)
108 | elif cfg['model']['arch'] == 'All_agents':
109 | trainer = Trainer_All_agents(cfg, None, None, model, None, None, None, None, None, device)
110 | elif cfg['model']['arch'] == 'MIMO_All_agents':
111 | trainer = Trainer_MIMO_All_agents(cfg, None, None, model, None, None, None, None, None, device)
112 | else:
113 | raise ValueError('Unknown arch name for testing')
114 |
115 |
116 | print(args.model_path)
117 | # load best weight
118 | trainer.load_weight(args.model_path)
119 |
120 | # to obtain qualitative results or other statistics, capture the values returned by evaluate instead of discarding them
121 | _ = trainer.evaluate(testloader)
122 |
123 |
124 |
125 |
126 |
127 |
128 |
--------------------------------------------------------------------------------
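test.py selects its trainer through a long if/elif chain over cfg['model']['arch']. As a design note, the same mapping can be expressed as a lookup table; the sketch below is illustrative and assumes the Trainer_* classes exported by ptsemseg.trainer keep the positional constructor signature used above (build_test_trainer is a hypothetical helper, not part of the repository).

from ptsemseg.trainer import (
    Trainer_LearnWhen2Com, Trainer_LearnWho2Com, Trainer_MIMOcom,
    Trainer_MIMOcomMultiWarp, Trainer_MIMOcomWho, Trainer_Single_agent,
    Trainer_All_agents, Trainer_MIMO_All_agents,
)

# Dispatch table: arch name in the config -> trainer class.
TRAINERS = {
    "LearnWhen2Com": Trainer_LearnWhen2Com,
    "LearnWho2Com": Trainer_LearnWho2Com,
    "MIMOcom": Trainer_MIMOcom,
    "MIMOcomMultiWarp": Trainer_MIMOcomMultiWarp,
    "MIMOcomWho": Trainer_MIMOcomWho,
    "Single_agent": Trainer_Single_agent,
    "All_agents": Trainer_All_agents,
    "MIMO_All_agents": Trainer_MIMO_All_agents,
}

def build_test_trainer(cfg, model, device):
    # Same constructor call as in test.py: only cfg, model, and device are needed
    # at test time; the remaining positional slots are filled with None.
    try:
        trainer_cls = TRAINERS[cfg["model"]["arch"]]
    except KeyError:
        raise ValueError("Unknown arch name for testing: {}".format(cfg["model"]["arch"]))
    return trainer_cls(cfg, None, None, model, None, None, None, None, None, device)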
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import time
4 | import shutil
5 | import torch
6 | import random
7 | import argparse
8 | import numpy as np
9 | import copy
10 | import timeit
11 | import statistics
12 | import datetime
13 | from torch.utils import data
14 | from tqdm import tqdm
15 | import cv2
16 |
17 | from ptsemseg.process_img import generate_noise
18 | from ptsemseg.models import get_model
19 | from ptsemseg.loss import get_loss_function
20 | from ptsemseg.loader import get_loader
21 | from ptsemseg.utils import get_logger, init_weights
22 | from ptsemseg.metrics import runningScore
23 | from ptsemseg.augmentations import get_composed_augmentations
24 | from ptsemseg.schedulers import get_scheduler
25 | from ptsemseg.optimizers import get_optimizer
26 | from ptsemseg.utils import convert_state_dict
27 |
28 | from ptsemseg.trainer import *
29 |
30 | from tensorboardX import SummaryWriter
31 |
32 |
33 | # entry point
34 | if __name__ == "__main__":
35 |
36 | parser = argparse.ArgumentParser(description="config")
37 | parser.add_argument(
38 | "--config",
39 | nargs="?",
40 | type=str,
41 | default="configs/your_configs.yml",
42 | help="Configuration file to use",
43 | )
44 |
45 | parser.add_argument(
46 | "--gpu",
47 | nargs="?",
48 | type=str,
49 | default="0",
50 | help="Used GPUs",
51 | )
52 |
53 | parser.add_argument(
54 | "--run_time",
55 | nargs="?",
56 | type=int,
57 | default=1,
58 | help="run_time",
59 | )
60 |
61 | args = parser.parse_args()
62 |
63 | # Set the gpu
64 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
65 | run_times = args.run_time
66 |
67 | with open(args.config) as fp:
68 | cfg = yaml.load(fp, Loader=yaml.SafeLoader)
69 |
70 | data_splits = ['val_split', 'test_split']
71 |
72 | # initialize for results stats
73 | score_list = {}
74 | class_iou_list = {}
75 | acc_list = {}
76 | if cfg['model']['arch'] == 'LearnWho2Com':
77 | for infer in ['softmax', 'argmax_test']:
78 | score_list[infer] = {}
79 | class_iou_list[infer] = {}
80 | acc_list[infer] = {}
81 | for data_sp in data_splits:
82 | score_list[infer][data_sp] = []
83 | class_iou_list[infer][data_sp] = []
84 | acc_list[infer][data_sp] = []
85 | elif cfg['model']['arch'] == 'LearnWhen2Com' or \
86 | cfg['model']['arch'] == 'MIMOcom' or \
87 | cfg['model']['arch'] == 'MIMOcomMultiWarp' or \
88 | cfg['model']['arch'] == 'MIMOcomWho' :
89 | for infer in ['softmax', 'argmax_test', 'activated']:
90 | score_list[infer] = {}
91 | class_iou_list[infer] = {}
92 | acc_list[infer] = {}
93 | for data_sp in data_splits:
94 | score_list[infer][data_sp] = []
95 | class_iou_list[infer][data_sp] = []
96 | acc_list[infer][data_sp] = []
97 | elif cfg['model']['arch'] == 'Single_agent' or cfg['model']['arch'] == 'All_agents' or cfg['model']['arch'] == 'MIMO_All_agents':
98 | for infer in ['default']:
99 | score_list[infer] = {}
100 | class_iou_list[infer] = {}
101 | acc_list[infer] = {}
102 | for data_sp in data_splits:
103 | score_list[infer][data_sp] = []
104 | class_iou_list[infer][data_sp] = []
105 | acc_list[infer][data_sp] = []
106 |
107 | for _ in range(run_times):
108 | run_id = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
109 | logdir = os.path.join("runs", os.path.basename(args.config)[:-4], str(run_id))
110 | writer = SummaryWriter(logdir=logdir)
111 |
112 | print("RUNDIR: {}".format(logdir))
113 | shutil.copy(args.config, logdir)
114 |
115 |
116 | # ============= Training =============
117 | # logger
118 | logger = get_logger(logdir)
119 | logger.info("Begin")
120 |
121 | # Setup seeds
122 | torch.manual_seed(cfg.get("seed", 1337))
123 | torch.cuda.manual_seed(cfg.get("seed", 1337))
124 | np.random.seed(cfg.get("seed", 1337))
125 | random.seed(cfg.get("seed", 1337))
126 |
127 | # Setup device
128 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
129 |
130 | # Setup Dataloader
131 | data_loader = get_loader(cfg["data"]["dataset"])
132 | data_path = cfg["data"]["path"]
133 |
134 | # Load communication label (note that some datasets do not provide this)
135 | if 'commun_label' in cfg["data"]:
136 | if_commun_label = cfg["data"]['commun_label']
137 | else:
138 | if_commun_label = 'None'
139 |
140 |
141 | # dataloaders
142 | t_loader = data_loader(
143 | data_path,
144 | is_transform=True,
145 | split=cfg["data"]["train_split"],
146 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
147 | augmentations=get_composed_augmentations(cfg["training"].get("augmentations", None)),
148 | target_view=cfg["data"]["target_view"],
149 | commun_label=if_commun_label
150 | )
151 |
152 | v_loader = data_loader(
153 | data_path,
154 | is_transform=True,
155 | split=cfg["data"]["val_split"],
156 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
157 | target_view=cfg["data"]["target_view"],
158 | commun_label=if_commun_label
159 | )
160 |
161 | trainloader = data.DataLoader(
162 | t_loader,
163 | batch_size=cfg["training"]["batch_size"],
164 | num_workers=cfg["training"]["n_workers"],
165 | shuffle=True,
166 | drop_last=True
167 | )
168 |
169 | valloader = data.DataLoader(
170 | v_loader,
171 | batch_size=cfg["training"]["batch_size"],
172 | num_workers=cfg["training"]["n_workers"]
173 | )
174 |
175 | # Setup Model
176 | model = get_model(cfg, t_loader.n_classes).to(device)
177 | model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
178 | # import pdb; pdb.set_trace()
179 |
180 | # Setup optimizer
181 | optimizer_cls = get_optimizer(cfg)
182 | optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}
183 | optimizer = optimizer_cls(model.parameters(), **optimizer_params)
184 | logger.info("Using optimizer {}".format(optimizer))
185 |
186 | # Setup scheduler
187 | scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])
188 |
189 | # Setup loss
190 | loss_fn = get_loss_function(cfg)
191 | logger.info("Using loss {}".format(loss_fn))
192 |
193 |
194 | # ================== TRAINING ==================
195 | if cfg['model']['arch'] == 'LearnWhen2Com': # Our when2com
196 | trainer = Trainer_LearnWhen2Com(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
197 | elif cfg['model']['arch'] == 'LearnWho2Com': # Our who2com
198 | trainer = Trainer_LearnWho2Com(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
199 | elif cfg['model']['arch'] == 'MIMOcom':
200 | trainer = Trainer_MIMOcom(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
201 | elif cfg['model']['arch'] == 'MIMOcomMultiWarp':
202 | trainer = Trainer_MIMOcomMultiWarp(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
203 | elif cfg['model']['arch'] == 'MIMOcomWho':
204 | trainer = Trainer_MIMOcomWho(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
205 | elif cfg['model']['arch'] == 'Single_agent':
206 | trainer = Trainer_Single_agent(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
207 | elif cfg['model']['arch'] == 'All_agents':
208 | trainer = Trainer_All_agents(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
209 | elif cfg['model']['arch'] == 'MIMO_All_agents':
210 | trainer = Trainer_MIMO_All_agents(cfg, writer, logger, model, loss_fn, trainloader, valloader, optimizer, scheduler, device)
211 | else:
212 | raise ValueError('Unknown arch name for training')
213 |
214 | model_path = trainer.train()
215 |
216 |
217 | # ================ Val + Test ================
218 |
219 | te_loader = data_loader(
220 | data_path,
221 | split=cfg["data"]['test_split'],
222 | is_transform=True,
223 | img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
224 | target_view=cfg["data"]["target_view"],
225 | commun_label=if_commun_label)
226 |
227 | n_classes = te_loader.n_classes
228 | testloader = data.DataLoader(te_loader, batch_size=cfg["training"]["batch_size"], num_workers=8)
229 |
230 | # load best weight
231 | trainer.load_weight(model_path)
232 | _ = trainer.evaluate(testloader)
233 |
234 |
235 |
236 |
--------------------------------------------------------------------------------
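Both scripts default --config to the placeholder configs/your_configs.yml; the real configuration files live under configs/. The following is an editor's sketch of the keys that train.py and test.py actually read, expressed as a Python dict; every concrete value is a placeholder, and the shipped .yml files remain the authoritative reference.

example_cfg = {
    "seed": 1337,                       # optional; cfg.get("seed", 1337)
    "model": {"arch": "MIMOcom"},       # one of the arch names dispatched in train.py/test.py
    "data": {
        "dataset": "airsim",            # placeholder; must be a name known to get_loader
        "path": "data/",                # dataset root (placeholder)
        "train_split": "train",
        "val_split": "val",
        "test_split": "test",
        "img_rows": 512,                # placeholder image size
        "img_cols": 512,
        "target_view": 0,               # placeholder
        "commun_label": "None",         # optional; the scripts fall back to 'None' when absent
    },
    "training": {
        "batch_size": 2,                # placeholder
        "n_workers": 4,                 # placeholder
        "optimizer": {"name": "adam", "lr": 1e-4},  # "name" is stripped before **kwargs
        "lr_schedule": {},              # passed to get_scheduler
        "augmentations": None,          # optional; cfg["training"].get("augmentations", None)
    },
}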