├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── LEOD.png
│   ├── det-results.png
│   └── gen1-video.mp4
├── callbacks
│   ├── custom.py
│   ├── detection.py
│   ├── gradflow.py
│   ├── utils
│   │   └── visualization.py
│   └── viz_base.py
├── config
│   ├── dataset
│   │   ├── base.yaml
│   │   ├── gen1-tflip.yaml
│   │   ├── gen1.yaml
│   │   ├── gen1x0.01_seq.yaml
│   │   ├── gen1x0.01_ss-1round.yaml
│   │   ├── gen1x0.01_ss.yaml
│   │   ├── gen1x0.02_seq.yaml
│   │   ├── gen1x0.02_ss.yaml
│   │   ├── gen1x0.05_seq.yaml
│   │   ├── gen1x0.05_ss.yaml
│   │   ├── gen1x0.1_seq.yaml
│   │   ├── gen1x0.1_ss.yaml
│   │   ├── gen4-tflip.yaml
│   │   ├── gen4.yaml
│   │   ├── gen4x0.01_seq.yaml
│   │   ├── gen4x0.01_ss.yaml
│   │   ├── gen4x0.02_seq.yaml
│   │   ├── gen4x0.02_ss.yaml
│   │   ├── gen4x0.05_seq.yaml
│   │   ├── gen4x0.05_ss.yaml
│   │   ├── gen4x0.1_seq.yaml
│   │   └── gen4x0.1_ss.yaml
│   ├── experiment
│   │   ├── gen1
│   │   │   ├── base.yaml
│   │   │   ├── default.yaml
│   │   │   ├── small.yaml
│   │   │   └── tiny.yaml
│   │   └── gen4
│   │       ├── base.yaml
│   │       ├── default.yaml
│   │       ├── small.yaml
│   │       └── tiny.yaml
│   ├── general.yaml
│   ├── model
│   │   ├── base.yaml
│   │   ├── maxvit_yolox
│   │   │   └── default.yaml
│   │   ├── pseudo_labeler-gen4-wsod.yaml
│   │   ├── pseudo_labeler.yaml
│   │   ├── rnndet-soft-gen4-wsod.yaml
│   │   ├── rnndet-soft.yaml
│   │   └── rnndet.yaml
│   ├── modifier.py
│   ├── predict.yaml
│   ├── train.yaml
│   ├── val.yaml
│   └── vis.yaml
├── data
│   ├── genx_utils
│   │   ├── collate.py
│   │   ├── collate_from_pytorch.py
│   │   ├── dataset_rnd.py
│   │   ├── dataset_streaming.py
│   │   ├── labels.py
│   │   ├── sequence_base.py
│   │   ├── sequence_rnd.py
│   │   ├── sequence_streaming.py
│   │   └── splits
│   │       ├── gen1
│   │       │   ├── ssod_0.010-off0.pkl
│   │       │   ├── ssod_0.020-off0.pkl
│   │       │   ├── ssod_0.050-off0.pkl
│   │       │   └── ssod_0.100-off0.pkl
│   │       └── gen4
│   │           ├── ssod_0.010-off0.pkl
│   │           ├── ssod_0.020-off0.pkl
│   │           ├── ssod_0.050-off0.pkl
│   │           └── ssod_0.100-off0.pkl
│   └── utils
│       ├── augmentor.py
│       ├── misc.py
│       ├── representations.py
│       ├── spatial.py
│       ├── ssod_augmentor.py
│       ├── stream_concat_datapipe.py
│       ├── stream_sharded_datapipe.py
│       └── types.py
├── datasets
│   └── .gitkeep
├── docs
│   ├── benchmark.md
│   └── install.md
├── environment.yml
├── loggers
│   └── utils.py
├── models
│   ├── detection
│   │   ├── __init__.py
│   │   ├── recurrent_backbone
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── maxvit_rnn.py
│   │   ├── yolox
│   │   │   ├── models
│   │   │   │   ├── __init__.py
│   │   │   │   ├── losses.py
│   │   │   │   ├── network_blocks.py
│   │   │   │   └── yolo_head.py
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── boxes.py
│   │   │       └── compat.py
│   │   └── yolox_extension
│   │       └── models
│   │           ├── __init__.py
│   │           ├── build.py
│   │           ├── detector.py
│   │           └── yolo_pafpn.py
│   └── layers
│       ├── maxvit
│       │   ├── __init__.py
│       │   ├── layers
│       │   │   ├── __init__.py
│       │   │   ├── activations.py
│       │   │   ├── activations_jit.py
│       │   │   ├── activations_me.py
│       │   │   ├── adaptive_avgmax_pool.py
│       │   │   ├── attention_pool2d.py
│       │   │   ├── blur_pool.py
│       │   │   ├── bottleneck_attn.py
│       │   │   ├── cbam.py
│       │   │   ├── classifier.py
│       │   │   ├── cond_conv2d.py
│       │   │   ├── config.py
│       │   │   ├── conv2d_same.py
│       │   │   ├── conv_bn_act.py
│       │   │   ├── create_act.py
│       │   │   ├── create_attn.py
│       │   │   ├── create_conv2d.py
│       │   │   ├── create_norm.py
│       │   │   ├── create_norm_act.py
│       │   │   ├── drop.py
│       │   │   ├── eca.py
│       │   │   ├── evo_norm.py
│       │   │   ├── fast_norm.py
│       │   │   ├── filter_response_norm.py
│       │   │   ├── gather_excite.py
│       │   │   ├── global_context.py
│       │   │   ├── halo_attn.py
│       │   │   ├── helpers.py
│       │   │   ├── inplace_abn.py
│       │   │   ├── lambda_layer.py
│       │   │   ├── linear.py
│       │   │   ├── median_pool.py
│       │   │   ├── mixed_conv2d.py
│       │   │   ├── ml_decoder.py
│       │   │   ├── mlp.py
│       │   │   ├── non_local_attn.py
│       │   │   ├── norm.py
│       │   │   ├── norm_act.py
│       │   │   ├── padding.py
│       │   │   ├── patch_embed.py
│       │   │   ├── pool2d_same.py
│       │   │   ├── pos_embed.py
│       │   │   ├── selective_kernel.py
│       │   │   ├── separable_conv.py
│       │   │   ├── space_to_depth.py
│       │   │   ├── split_attn.py
│       │   │   ├── split_batchnorm.py
│       │   │   ├── squeeze_excite.py
│       │   │   ├── std_conv.py
│       │   │   ├── test_time_pool.py
│       │   │   ├── trace_utils.py
│       │   │   └── weight_init.py
│       │   └── maxvit.py
│       └── rnn.py
├── modules
│   ├── __init__.py
│   ├── data
│   │   └── genx.py
│   ├── detection.py
│   ├── pseudo_labeler.py
│   ├── tracking
│   │   ├── __init__.py
│   │   ├── linear.py
│   │   ├── tracker.py
│   │   └── utils.py
│   └── utils
│       ├── detection.py
│       ├── fetch.py
│       ├── ssod.py
│       └── tta.py
├── predict.py
├── pretrained
│   └── .gitkeep
├── train.py
├── utils
│   ├── bbox.py
│   ├── evaluation
│   │   └── prophesee
│   │       ├── __init__.py
│   │       ├── evaluation.py
│   │       ├── evaluator.py
│   │       ├── io
│   │       │   ├── __init__.py
│   │       │   ├── box_filtering.py
│   │       │   ├── box_loading.py
│   │       │   ├── dat_events_tools.py
│   │       │   ├── npy_events_tools.py
│   │       │   └── psee_loader.py
│   │       ├── metrics
│   │       │   ├── __init__.py
│   │       │   └── coco_eval.py
│   │       └── visualize
│   │           ├── __init__.py
│   │           └── vis_utils.py
│   ├── helpers.py
│   ├── padding.py
│   ├── preprocessing.py
│   └── timers.py
├── val.py
├── val_dst.py
└── vis_pred.py
/.gitignore:
--------------------------------------------------------------------------------
1 | pretrained
2 | pretrained/
3 | checkpoint/
4 | checkpoints/
5 | datasets/
6 | vis/
7 | old_vis/
8 | outputs/
9 | validation_logs/
10 | sbatch/
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Mathias Gehrig & Ziyi Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LEOD
2 |
3 | This is the official PyTorch implementation of our CVPR 2024 paper:
4 |
5 | [**LEOD: Label-Efficient Object Detection for Event Cameras**](https://arxiv.org/abs/2311.17286)
6 | [Ziyi Wu](https://wuziyi616.github.io/),
7 | [Mathias Gehrig](https://magehrig.github.io/),
8 | Qing Lyu,
9 | Xudong Liu,
10 | [Igor Gilitschenski](https://tisl.cs.utoronto.ca/author/igor-gilitschenski/)
11 | _CVPR'24 |
12 | [GitHub](https://github.com/Wuziyi616/LEOD?tab=readme-ov-file#leod) |
13 | [arXiv](https://arxiv.org/abs/2311.17286)_
14 |
15 |
16 |
17 |
18 |
19 | ## TL;DR
20 |
21 | [Event cameras](https://tub-rip.github.io/eventvision2023/#null) are bio-inspired low-latency sensors, which hold great potential for safety-critical applications such as object detection in self-driving.
22 | Due to the high temporal resolution (>1000 FPS) of event data, existing datasets are annotated at a low frame rate (e.g., 4 FPS).
23 | As a result, models are only trained on these annotated frames, leading to sub-optimal performance and slow convergence.
24 | In this paper, we tackle this problem from the perspective of weakly-/semi-supervised learning.
25 | We design a novel self-training framework that pseudo-labels unannotated events with reliable model predictions, which achieves SOTA performance on the two largest event camera detection benchmarks.
26 |
27 |
28 |
29 |
30 |
31 | ## Install
32 |
33 | This codebase builds upon [RVT](https://github.com/uzh-rpg/RVT).
34 | Please refer to [install.md](./docs/install.md) for detailed instructions.
35 |
36 | ## Experiments
37 |
38 | **This codebase is tailored to [Slurm](https://slurm.schedmd.com/documentation.html) GPU clusters with a preemption mechanism.**
39 | Some functions in the code (e.g., auto-detecting and loading previous checkpoints) might therefore not be needed in your environment.
40 | Please go through all fields marked with `TODO` in `train.py` in case any of them conflict with your setup.
41 | To reproduce the results in the paper, please refer to [benchmark.md](docs/benchmark.md).
42 |
43 | ## Citation
44 |
45 | Please cite our paper if you find it useful in your research:
46 | ```bibtex
47 | @inproceedings{wu2024leod,
48 | title={LEOD: Label-Efficient Object Detection for Event Cameras},
49 | author={Wu, Ziyi and Gehrig, Mathias and Lyu, Qing and Liu, Xudong and Gilitschenski, Igor},
50 | booktitle={CVPR},
51 | year={2024}
52 | }
53 | ```
54 |
55 | ## Acknowledgement
56 |
57 | We thank the authors of [RVT](https://github.com/uzh-rpg/RVT), [SORT](https://github.com/abewley/sort), [Soft Teacher](https://github.com/microsoft/SoftTeacher), [Unbiased Teacher](https://github.com/facebookresearch/unbiased-teacher) and all the packages we use in this repo for open-sourcing their wonderful work.
58 |
59 | ## License
60 |
61 | LEOD is released under the MIT License. See the LICENSE file for more details.
62 |
63 | ## Contact
64 |
65 | If you have any questions about the code, please contact Ziyi Wu at dazitu616@gmail.com.
66 |
--------------------------------------------------------------------------------
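The config groups referenced in the README (`config/train.yaml`, `config/dataset/*`, `config/model/*`, `config/experiment/*`) are composed with Hydra. As a rough sketch of how they fit together, the composed training config can be inspected offline; the overrides below are illustrative, and the snippet assumes a recent Hydra version and that it is run from the repository root. The exact training commands are documented in `docs/benchmark.md`.

```python
from hydra import compose, initialize

# Compose the training config roughly the way train.py would (hypothetical overrides).
# `dataset` is mandatory ("???" in config/train.yaml); `model` defaults to rnndet.
with initialize(version_base=None, config_path='config'):
    cfg = compose(config_name='train',
                  overrides=['dataset=gen1x0.01_ss', 'model=rnndet'])

print(cfg.dataset.name, cfg.dataset.ratio)  # gen1 0.01
print(cfg.model.name)                       # rnndet
```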
/assets/LEOD.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/assets/LEOD.png
--------------------------------------------------------------------------------
/assets/det-results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/assets/det-results.png
--------------------------------------------------------------------------------
/assets/gen1-video.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/assets/gen1-video.mp4
--------------------------------------------------------------------------------
/callbacks/custom.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from omegaconf import DictConfig
3 | from pytorch_lightning.callbacks import Callback
4 | from pytorch_lightning.callbacks import ModelCheckpoint
5 |
6 | from callbacks.detection import DetectionVizCallback
7 |
8 |
9 | def get_ckpt_callback(config: DictConfig, ckpt_dir: str = None) -> ModelCheckpoint:
10 | prefix = 'val'
11 | metric = 'AP'
12 | mode = 'max'
13 | ckpt_callback_monitor = prefix + '/' + metric
14 | filename_monitor_str = prefix + '_' + metric
15 |
16 | ckpt_filename = 'epoch_{epoch:03d}-step_{step}-' + filename_monitor_str + '_{' + ckpt_callback_monitor + ':.4f}'
17 | every_n_min = config.logging.ckpt_every_min
18 | cktp_callback = ModelCheckpoint(
19 | dirpath=ckpt_dir,
20 | monitor=ckpt_callback_monitor,
21 | filename=ckpt_filename,
22 | auto_insert_metric_name=False, # because backslash would create a directory
23 | save_top_k=2, # in case the best one is broken
24 | mode=mode,
25 | train_time_interval=timedelta(minutes=every_n_min),
26 | save_last=True,
27 | verbose=True)
28 | cktp_callback.CHECKPOINT_NAME_LAST = 'last_epoch_{epoch:03d}-step_{step}'
29 | return cktp_callback
30 |
31 |
32 | def get_viz_callback(config: DictConfig) -> Callback:
33 | if hasattr(config.model, 'pseudo_label'):
34 | prefixs = ['', 'pseudo_']
35 | else:
36 | prefixs = ['']
37 | return DetectionVizCallback(config=config, prefixs=prefixs)
38 |
--------------------------------------------------------------------------------
/callbacks/detection.py:
--------------------------------------------------------------------------------
1 | from enum import Enum, auto
2 | from typing import Any, List
3 |
4 | import torch
5 | from einops import rearrange
6 | from omegaconf import DictConfig
7 | from pytorch_lightning.loggers.wandb import WandbLogger
8 |
9 | from data.utils.types import ObjDetOutput
10 | from utils.evaluation.prophesee.visualize.vis_utils import get_labelmap, draw_bboxes
11 | from .viz_base import VizCallbackBase
12 |
13 |
14 | class DetectionVizEnum(Enum):
15 | EV_IMG = auto()
16 | LABEL_IMG_PROPH = auto()
17 | PRED_IMG_PROPH = auto()
18 |
19 |
20 | class DetectionVizCallback(VizCallbackBase):
21 | """Visualize predicted and GT bbox on event-converted RGB frames."""
22 |
23 | def __init__(self, config: DictConfig, prefixs: List[str] = ['']):
24 | super().__init__(config=config, buffer_entries=DetectionVizEnum)
25 |
26 | self.label_map = get_labelmap(dst_name=config.dataset.name)
27 | self.prefixs = prefixs
28 |
29 | def on_train_batch_end_custom(self, *args, **kwargs) -> None:
30 | for prefix in self.prefixs:
31 | self._on_train_batch_end_custom(*args, **kwargs, prefix=prefix)
32 |
33 | def _on_train_batch_end_custom(self,
34 | logger: WandbLogger,
35 | outputs: Any,
36 | batch: Any,
37 | log_n_samples: int,
38 | global_step: int,
39 | prefix: str) -> None:
40 | """May need to load images from different labeled data."""
41 | if outputs is None:
42 | # If we tried to skip the training step (not supported in DDP in PL, atm)
43 | return
44 | if f'{prefix}{ObjDetOutput.EV_REPR}' not in outputs:
45 | return
46 | ev_tensors = outputs[f'{prefix}{ObjDetOutput.EV_REPR}']
47 | num_samples = len(ev_tensors)
48 | assert num_samples > 0
49 | log_n_samples = min(num_samples, log_n_samples)
50 |
51 | merged_img = []
52 | captions = []
53 | start_idx = num_samples - 1
54 | end_idx = start_idx - log_n_samples
55 | # for sample_idx in range(log_n_samples):
56 | for sample_idx in range(start_idx, end_idx, -1):
57 | ev_img = self.ev_repr_to_img(ev_tensors[sample_idx].cpu().numpy())
58 |
59 | predictions_proph = outputs[f'{prefix}{ObjDetOutput.PRED_PROPH}'][sample_idx]
60 | prediction_img = ev_img.copy()
61 | draw_bboxes(prediction_img, predictions_proph, labelmap=self.label_map)
62 |
63 | labels_proph = outputs[f'{prefix}{ObjDetOutput.LABELS_PROPH}'][sample_idx]
64 | label_img = ev_img.copy()
65 | draw_bboxes(label_img, labels_proph, labelmap=self.label_map)
66 |
67 | merged_img.append(rearrange([prediction_img, label_img], 'pl H W C -> (pl H) W C', pl=2, C=3))
68 | captions.append(f'sample_{sample_idx}')
69 |
70 | logger.log_image(key=f'train/{prefix}predictions', # PL's native wandb
71 | images=merged_img,
72 | caption=captions,
73 | step=global_step)
74 |
75 | def on_validation_batch_end_custom(self, batch: Any, outputs: Any) -> None:
76 | """Val is not affected by pseudo-label training."""
77 | if outputs[ObjDetOutput.SKIP_VIZ]:
78 | return
79 | ev_tensor = outputs[ObjDetOutput.EV_REPR]
80 | assert isinstance(ev_tensor, torch.Tensor)
81 |
82 | ev_img = self.ev_repr_to_img(ev_tensor.cpu().numpy())
83 |
84 | predictions_proph = outputs[ObjDetOutput.PRED_PROPH]
85 | prediction_img = ev_img.copy()
86 | draw_bboxes(prediction_img, predictions_proph, labelmap=self.label_map)
87 | self.add_to_buffer(DetectionVizEnum.PRED_IMG_PROPH, prediction_img)
88 |
89 | labels_proph = outputs[ObjDetOutput.LABELS_PROPH]
90 | label_img = ev_img.copy()
91 | draw_bboxes(label_img, labels_proph, labelmap=self.label_map)
92 | self.add_to_buffer(DetectionVizEnum.LABEL_IMG_PROPH, label_img)
93 |
94 | def on_validation_epoch_end_custom(self, logger: WandbLogger):
95 | pred_imgs = self.get_from_buffer(DetectionVizEnum.PRED_IMG_PROPH)
96 | label_imgs = self.get_from_buffer(DetectionVizEnum.LABEL_IMG_PROPH)
97 | assert len(pred_imgs) == len(label_imgs)
98 | merged_img = []
99 | captions = []
100 | for idx, (pred_img, label_img) in enumerate(zip(pred_imgs, label_imgs)):
101 | merged_img.append(rearrange([pred_img, label_img], 'pl H W C -> (pl H) W C', pl=2, C=3))
102 | captions.append(f'sample_{idx}')
103 |
104 | logger.log_image(key='val/predictions',
105 | images=merged_img,
106 | caption=captions)
107 |
--------------------------------------------------------------------------------
/callbacks/gradflow.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.callbacks import Callback
5 | from pytorch_lightning.utilities.rank_zero import rank_zero_only
6 |
7 | from callbacks.utils.visualization import get_grad_flow_figure
8 |
9 |
10 | class GradFlowLogCallback(Callback):
11 | def __init__(self, log_every_n_train_steps: int):
12 | super().__init__()
13 | assert log_every_n_train_steps > 0
14 | self.log_every_n_train_steps = log_every_n_train_steps
15 |
16 | @rank_zero_only
17 | def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Any) -> None:
18 | # NOTE: before we had this in the on_after_backward callback.
19 | # This was fine for fp32 but showed unscaled gradients for fp16.
20 | # That is why we move it to on_before_zero_grad where gradients are scaled.
21 | global_step = trainer.global_step
22 | if global_step % self.log_every_n_train_steps != 0:
23 | return
24 | named_parameters = pl_module.named_parameters()
25 | figure = get_grad_flow_figure(named_parameters)
26 | trainer.logger.log_metrics({'train/gradients': figure}, step=global_step)
27 |
--------------------------------------------------------------------------------
/callbacks/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import plotly.express as px
3 |
4 |
5 | def get_grad_flow_figure(named_params):
6 | """Creates figure to visualize gradients flowing through different layers in the net during training.
7 | Can be used for checking for possible gradient vanishing / exploding problems.
8 | Usage: Use this function after loss.backward()
9 | """
10 | data_dict = {
11 | 'name': list(),
12 | 'grad_abs': list(),
13 | }
14 | for name, param in named_params:
15 | if param.requires_grad and param.grad is not None:
16 | grad_abs = param.grad.abs()
17 | data_dict['name'].append(name)
18 | data_dict['grad_abs'].append(grad_abs.mean().cpu().item())
19 |
20 | data_frame = pd.DataFrame.from_dict(data_dict)
21 |
22 | fig = px.bar(data_frame, x='name', y='grad_abs')
23 | return fig
24 |
--------------------------------------------------------------------------------
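A minimal usage sketch for `get_grad_flow_figure`, following the docstring's advice to call it after back-propagation. The toy model and output path are made up for illustration; in training, the figure is logged by `callbacks/gradflow.py` instead.

```python
import torch
import torch.nn as nn

from callbacks.utils.visualization import get_grad_flow_figure

# Toy network standing in for the detector; any nn.Module works.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
loss = model(torch.randn(4, 8)).sum()
loss.backward()  # populate .grad so the figure has something to show

fig = get_grad_flow_figure(model.named_parameters())
fig.write_html('gradflow.html')  # or log it, e.g. logger.log_metrics({'train/gradients': fig})
```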
/config/dataset/base.yaml:
--------------------------------------------------------------------------------
1 | name: ???
2 | path: ???
3 | ssod: False # for semi-supervised object detection, see `modifier.py`
4 | ratio: -1 # sub-sample the labeling frequency of each event sequence
5 | train_ratio: -1 # sub-sample the training set
6 | val_ratio: -1 # to accelerate the val process, currently takes ~20min
7 | test_ratio: -1 # to accelerate the test process, currently takes ~20min
8 | only_load_labels: False # only load the label, not the events
9 | reverse_event_order: False # reverse the temporal order of events
10 | train:
11 | sampling: 'mixed' # ('random', 'stream', 'mixed')
12 | random:
13 | weighted_sampling: False
14 | mixed:
15 | w_stream: 1
16 | w_random: 1
17 | eval:
18 | sampling: 'stream'
19 | data_augmentation:
20 | tflip_offset: -1
21 | random:
22 | prob_hflip: 0.5
23 | prob_tflip: 0
24 | rotate:
25 | prob: 0
26 | min_angle_deg: 2
27 | max_angle_deg: 6
28 | zoom:
29 | prob: 0.8
30 | zoom_in:
31 | weight: 8
32 | factor:
33 | min: 1
34 | max: 1.5
35 | zoom_out:
36 | weight: 2
37 | factor:
38 | min: 1
39 | max: 1.2
40 | stream:
41 | start_from_zero: False
42 | prob_hflip: 0.5
43 | prob_tflip: 0
44 | rotate:
45 | prob: 0
46 | min_angle_deg: 2
47 | max_angle_deg: 6
48 | zoom:
49 | prob: 0.5
50 | zoom_out:
51 | factor:
52 | min: 1
53 | max: 1.2
54 |
--------------------------------------------------------------------------------
/config/dataset/gen1-tflip.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | # time-flip data augmentation
5 | data_augmentation:
6 | random:
7 | prob_tflip: 0.5
8 | stream:
9 | prob_tflip: 0.5
10 |
--------------------------------------------------------------------------------
/config/dataset/gen1.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | name: gen1
5 | path: ./datasets/gen1/
6 | ev_repr_name: 'stacked_histogram_dt=50_nbins=10'
7 | sequence_length: 21
8 | resolution_hw: [240, 304]
9 | downsample_by_factor_2: False
10 | only_load_end_labels: False
11 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.01_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | train_ratio: 0.01
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.01_ss-1round.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | ratio: -1 # we now use all the pseudo labels
5 | val_ratio: 0.5
6 |
7 | path: ./datasets/pseudo_gen1/gen1x0.01_ss-1round
8 |
9 | # time-flip data augmentation
10 | data_augmentation:
11 | random:
12 | prob_tflip: 0.5
13 | stream:
14 | prob_tflip: 0.5
15 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.01_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | ratio: 0.01
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.02_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | train_ratio: 0.02
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.02_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | ratio: 0.02
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.05_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | train_ratio: 0.05
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.05_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | ratio: 0.05
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.1_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | train_ratio: 0.1
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen1x0.1_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen1
3 |
4 | ratio: 0.1
5 | val_ratio: 0.5
6 |
7 | # time-flip data augmentation
8 | data_augmentation:
9 | random:
10 | prob_tflip: 0.5
11 | stream:
12 | prob_tflip: 0.5
13 |
--------------------------------------------------------------------------------
/config/dataset/gen4-tflip.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | sequence_length: 5
5 |
6 | # time-flip data augmentation
7 | data_augmentation:
8 | random:
9 | prob_tflip: 0.5
10 | stream:
11 | prob_tflip: 0.5
12 |
--------------------------------------------------------------------------------
/config/dataset/gen4.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | name: gen4
5 | path: ./datasets/gen4/
6 | ev_repr_name: 'stacked_histogram_dt=50_nbins=10'
7 | sequence_length: 5
8 | resolution_hw: [720, 1280]
9 | downsample_by_factor_2: True
10 | only_load_end_labels: False
11 |
12 | data_augmentation:
13 | tflip_offset: -2 # the GT labels are not well-aligned with the event frames
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.01_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | train_ratio: 0.01
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.01_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | ratio: 0.01
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.02_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | train_ratio: 0.02
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.02_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | ratio: 0.02
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.05_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | train_ratio: 0.05
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.05_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | ratio: 0.05
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.1_seq.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | train_ratio: 0.1
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/dataset/gen4x0.1_ss.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - gen4
3 |
4 | ratio: 0.1
5 | val_ratio: 0.5
6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage
7 |
8 | # time-flip data augmentation
9 | data_augmentation:
10 | random:
11 | prob_tflip: 0.5
12 | stream:
13 | prob_tflip: 0.5
14 |
--------------------------------------------------------------------------------
/config/experiment/gen1/base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - default
4 |
5 | model:
6 | backbone:
7 | embed_dim: 64
8 | fpn:
9 | depth: 0.67
10 |
--------------------------------------------------------------------------------
/config/experiment/gen1/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /model/maxvit_yolox: default
4 |
5 | training:
6 | precision: 16
7 | max_epochs: 10000
8 | max_steps: 400000
9 | learning_rate: 0.0002
10 | lr_scheduler:
11 | use: True
12 | total_steps: ${..max_steps}
13 | pct_start: 0.005
14 | div_factor: 20
15 | final_div_factor: 10000
16 | batch_size:
17 | train: 8
18 | eval: 8
19 | hardware:
20 | num_workers:
21 | train: 8
22 | eval: 8
23 | dataset:
24 | train:
25 | sampling: 'mixed'
26 | random:
27 | weighted_sampling: False
28 | mixed:
29 | w_stream: 1
30 | w_random: 1
31 | eval:
32 | sampling: 'stream'
33 | model:
34 | backbone:
35 | partition_split_32: 1
36 |
--------------------------------------------------------------------------------
/config/experiment/gen1/small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - default
4 |
5 | model:
6 | backbone:
7 | embed_dim: 48
8 | stage:
9 | attention:
10 | dim_head: 24
11 | fpn:
12 | depth: 0.33
13 |
--------------------------------------------------------------------------------
/config/experiment/gen1/tiny.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - default
4 |
5 | model:
6 | backbone:
7 | embed_dim: 32
8 | fpn:
9 | depth: 0.33
10 |
--------------------------------------------------------------------------------
/config/experiment/gen4/base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - default
4 |
5 | model:
6 | backbone:
7 | embed_dim: 64
8 | fpn:
9 | depth: 0.67
10 |
--------------------------------------------------------------------------------
/config/experiment/gen4/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /model/maxvit_yolox: default
4 |
5 | training:
6 | precision: 16
7 | max_epochs: 10000
8 | max_steps: 400000
9 | learning_rate: 0.000346
10 | lr_scheduler:
11 | use: True
12 | total_steps: ${..max_steps}
13 | pct_start: 0.005
14 | div_factor: 20
15 | final_div_factor: 10000
16 | batch_size:
17 | train: 12
18 | eval: 12
19 | hardware:
20 | num_workers:
21 | train: 8
22 | eval: 4
23 | dataset:
24 | train:
25 | sampling: 'mixed'
26 | random:
27 | weighted_sampling: False
28 | mixed:
29 | w_stream: 1
30 | w_random: 1
31 | eval:
32 | sampling: 'stream'
33 |
--------------------------------------------------------------------------------
/config/experiment/gen4/small.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - default
4 |
5 | model:
6 | backbone:
7 | embed_dim: 48
8 | stage:
9 | attention:
10 | dim_head: 24
11 | fpn:
12 | depth: 0.33
13 |
--------------------------------------------------------------------------------
/config/experiment/gen4/tiny.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - default
4 |
5 | model:
6 | backbone:
7 | embed_dim: 32
8 | fpn:
9 | depth: 0.33
10 |
--------------------------------------------------------------------------------
/config/general.yaml:
--------------------------------------------------------------------------------
1 | reproduce:
2 | seed_everything: null # Union[int, null]
3 | deterministic_flag: False # Must be true for fully deterministic behaviour (slows down training)
4 | benchmark: True # Should be set to false for fully deterministic behaviour. Could potentially speed up training.
5 | training:
6 | precision: 16
7 | max_epochs: 10000
8 | max_steps: 400000
9 | learning_rate: 0.0002
10 | weight_decay: 0
11 | gradient_clip_val: 1.0
12 | limit_train_batches: 1.0
13 | lr_scheduler:
14 | use: True
15 | total_steps: ${..max_steps}
16 | pct_start: 0.005
17 | div_factor: 25 # init_lr = max_lr / div_factor
18 | final_div_factor: 10000 # final_lr = max_lr / final_div_factor (this is different from PyTorch's OneCycleLR param)
19 | validation:
20 | limit_val_batches: 1.0
21 | val_check_interval: 20000 # Optional[int]
22 | check_val_every_n_epoch: null # Optional[int]
23 | batch_size:
24 | train: 8
25 | eval: 8
26 | hardware:
27 | num_workers:
28 | train: 8
29 | eval: 8
30 | gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3,5,6])
31 | dist_backend: "nccl"
32 | logging:
33 | ckpt_every_min: 18 # checkpoint every x minutes
34 | train:
35 | metrics:
36 | compute: false
37 | detection_metrics_every_n_steps: null # Optional[int] -> null: every train epoch, int: every N steps
38 | log_model_every_n_steps: 5000
39 | log_every_n_steps: 100
40 | high_dim:
41 | enable: True
42 | every_n_steps: 5000
43 | n_samples: 4
44 | validation:
45 | high_dim:
46 | enable: True
47 | every_n_epochs: 1
48 | n_samples: 8
49 | wandb:
50 | wandb_name: null # name of the run
51 | wandb_id: null # id of the run
52 | wandb_runpath: null # WandB run path. E.g. USERNAME/PROJECTNAME/1grv5kg6
53 | group_name: null # Specify group name of the run
54 | project_name: RVT
55 | suffix: "" # full dup run name suffix
56 | weight: "" # only resume weight
57 | checkpoint: "" # resume weight + training state
58 | pretrain_teacher_checkpoint: "" # pre-trained weight of the teacher model
59 | pretrain_student_checkpoint: "" # pre-trained weight of the student model
60 |
--------------------------------------------------------------------------------
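A quick worked example of the learning-rate scheduler arithmetic described by the comments in `general.yaml` above, using the values from that file:

```python
max_lr = 0.0002           # training.learning_rate
div_factor = 25           # training.lr_scheduler.div_factor
final_div_factor = 10000  # training.lr_scheduler.final_div_factor

init_lr = max_lr / div_factor         # 8e-06: LR at the start of warm-up
final_lr = max_lr / final_div_factor  # 2e-08: LR at the end of training
# Note: PyTorch's own OneCycleLR instead defines min_lr = initial_lr / final_div_factor,
# which is the difference the comment in general.yaml warns about.
```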
/config/model/base.yaml:
--------------------------------------------------------------------------------
1 | name: ???
2 |
--------------------------------------------------------------------------------
/config/model/maxvit_yolox/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - override /model: rnndet
4 |
5 | model:
6 | backbone:
7 | name: MaxViTRNN
8 | compile:
9 | enable: False
10 | args:
11 | mode: reduce-overhead
12 | input_channels: 20
13 | enable_masking: False
14 | partition_split_32: 2
15 | embed_dim: 64
16 | dim_multiplier: [1, 2, 4, 8]
17 | num_blocks: [1, 1, 1, 1]
18 | T_max_chrono_init: [4, 8, 16, 32]
19 | stem:
20 | patch_size: 4
21 | stage:
22 | downsample:
23 | type: patch
24 | overlap: True
25 | norm_affine: True
26 | attention:
27 | use_torch_mha: False
28 | partition_size: ???
29 | dim_head: 32
30 | attention_bias: True
31 | mlp_activation: gelu
32 | mlp_gated: False
33 | mlp_bias: True
34 | mlp_ratio: 4
35 | drop_mlp: 0
36 | drop_path: 0
37 | ls_init_value: 1e-5
38 | lstm:
39 | dws_conv: False
40 | dws_conv_only_hidden: True
41 | dws_conv_kernel_size: 3
42 | drop_cell_update: 0
43 | fpn:
44 | name: PAFPN
45 | compile:
46 | enable: False
47 | args:
48 | mode: reduce-overhead
49 | depth: 0.67 # round(depth * 3) == num bottleneck blocks
50 | # stage 1 is the first and len(num_layers) is the last
51 | in_stages: [2, 3, 4]
52 | depthwise: False
53 | act: "silu"
54 | head:
55 | name: YoloX
56 | compile:
57 | enable: False
58 | args:
59 | mode: reduce-overhead
60 | depthwise: False
61 | act: "silu"
62 | postprocess:
63 | confidence_threshold: 0.1
64 | nms_threshold: 0.45
65 |
--------------------------------------------------------------------------------
/config/model/pseudo_labeler-gen4-wsod.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | name: pseudo_labeler
5 | backbone:
6 | name: ???
7 | fpn:
8 | name: ???
9 | head:
10 | name: ???
11 | postprocess:
12 | confidence_threshold: 0.1
13 | nms_threshold: 0.45
14 | # SSOD-related configs
15 | pseudo_label:
16 | # if a sub-seq `is_first_sample`, we skip predictions at the first `skip_first_t` timesteps
17 | # because they don't have enough history information --> not accurate
18 | skip_first_t: 0
19 | # thresholds for filtering pseudo labels
20 | obj_thresh: [0.6, 0.5] # thresholds for each category
21 | cls_thresh: [0.6, 0.5] # gen1: ('car', 'ped'); gen4: ('ped', 'cyc', 'car')
22 | # by default we will use the same thresholds for ped and cyc
23 | # post-process using offline tracker, ignore bbox with short track length
24 | min_track_len: 6
25 | track_method: 'forward or backward'
26 | # 'forward or backward' (short in both directions --> ignore)
27 | inpaint: True # hallucinate bbox at missing frames
28 | ignore_label: 1024
29 |
--------------------------------------------------------------------------------
/config/model/pseudo_labeler.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - base
3 |
4 | name: pseudo_labeler
5 | backbone:
6 | name: ???
7 | fpn:
8 | name: ???
9 | head:
10 | name: ???
11 | postprocess:
12 | confidence_threshold: 0.1
13 | nms_threshold: 0.45
14 | # SSOD-related configs
15 | pseudo_label:
16 | # if a sub-seq `is_first_sample`, we skip predictions at the first `skip_first_t` timesteps
17 | # because they don't have enough history information --> not accurate
18 | skip_first_t: 0
19 | # thresholds for filtering pseudo labels
20 | obj_thresh: [0.6, 0.3] # thresholds for each category
21 | cls_thresh: [0.6, 0.3] # gen1: ('car', 'ped'); gen4: ('ped', 'cyc', 'car')
22 | # by default we will use the same thresholds for ped and cyc
23 | # post-process using offline tracker, ignore bbox with short track length
24 | min_track_len: 6
25 | track_method: 'forward or backward'
26 | # 'forward or backward' (short in both directions --> ignore)
27 | inpaint: True # hallucinate bbox at missing frames
28 | ignore_label: 1024
29 |
--------------------------------------------------------------------------------
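To make the `obj_thresh` / `cls_thresh` fields concrete, here is a minimal sketch of per-class confidence filtering. The detection arrays are invented for illustration, and the actual pseudo-label pipeline (including the offline tracker and inpainting) lives elsewhere in the code, presumably `modules/pseudo_labeler.py`.

```python
import numpy as np

# Hypothetical detections on one frame (gen1 classes: 0 = car, 1 = pedestrian).
obj_score = np.array([0.90, 0.55, 0.20])  # objectness scores
cls_score = np.array([0.80, 0.40, 0.90])  # class confidence scores
cls_id = np.array([0, 1, 0])

obj_thresh = np.array([0.6, 0.3])  # pseudo_label.obj_thresh, indexed by class id
cls_thresh = np.array([0.6, 0.3])  # pseudo_label.cls_thresh

keep = (obj_score >= obj_thresh[cls_id]) & (cls_score >= cls_thresh[cls_id])
print(keep)  # [ True  True False] -> the first two boxes survive as pseudo labels
```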
/config/model/rnndet-soft-gen4-wsod.yaml:
--------------------------------------------------------------------------------
1 | # This config is for self-training RVT in 1Mpx WSOD
2 |
3 | defaults:
4 | - base
5 |
6 | name: rnndet
7 | backbone:
8 | name: ???
9 | fpn:
10 | name: ???
11 | head:
12 | name: ???
13 | obj_focal_loss: False # FocalLoss or BCE on objectness score
14 | bbox_loss_weighting: '' # 'obj' or 'cls' or 'objxcls' or ''
15 | # support further transformations, e.g. 'cls-w**2' is using cls**2 as weights
16 | ignore_bbox_thresh: [0.7, 0.55] # ignore bbox with obj/cls lower than this
17 | ignore_label: 1024 # ignore during training
18 | # don't suppress BG pixels with the highest k% confidence scores
19 | ignore_bg_k: 0
20 | postprocess:
21 | confidence_threshold: 0.1
22 | nms_threshold: 0.45
23 | use_label_every: 1
24 | ignore_image: False # ignore images where all bbox are with `ignore_label`
25 |
--------------------------------------------------------------------------------
/config/model/rnndet-soft.yaml:
--------------------------------------------------------------------------------
1 | # This config is for self-training RVT in all settings except 1Mpx WSOD
2 |
3 | defaults:
4 | - base
5 |
6 | name: rnndet
7 | backbone:
8 | name: ???
9 | fpn:
10 | name: ???
11 | head:
12 | name: ???
13 | obj_focal_loss: False # FocalLoss or BCE on objectness score
14 | bbox_loss_weighting: '' # 'obj' or 'cls' or 'objxcls' or ''
15 | # support further transformations, e.g. 'cls-w**2' is using cls**2 as weights
16 | ignore_bbox_thresh: [0.7, 0.35] # ignore bbox with obj/cls lower than this
17 | ignore_label: 1024 # ignore during training
18 | # don't suppress BG pixels with the highest k% confidence scores
19 | ignore_bg_k: 0
20 | postprocess:
21 | confidence_threshold: 0.1
22 | nms_threshold: 0.45
23 | use_label_every: 1
24 | ignore_image: False # ignore images where all bbox are with `ignore_label`
25 |
--------------------------------------------------------------------------------
/config/model/rnndet.yaml:
--------------------------------------------------------------------------------
1 | # This config is for pre-training RVT on limited annotated data
2 |
3 | defaults:
4 | - base
5 |
6 | name: rnndet
7 | backbone:
8 | name: ???
9 | fpn:
10 | name: ???
11 | head:
12 | name: ???
13 | obj_focal_loss: False # FocalLoss or BCE on objectness score
14 | bbox_loss_weighting: '' # 'obj' or 'cls' or 'objxcls' or ''
15 | # support further transformations, e.g. 'cls-w**2' is using cls**2 as weights
16 | ignore_bbox_thresh: null # ignore bbox with obj/cls lower than this
17 | ignore_label: 1024 # ignore during training
18 | # don't suppress BG pixels with the highest k% confidence scores
19 | ignore_bg_k: 0
20 | postprocess:
21 | confidence_threshold: 0.1
22 | nms_threshold: 0.45
23 | use_label_every: 1
24 | ignore_image: False # ignore images where all bbox are with `ignore_label`
25 |
--------------------------------------------------------------------------------
/config/predict.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: ???
3 | - model: pseudo_labeler
4 | - _self_
5 |
6 | is_train: False
7 | checkpoint: ???
8 | save_dir: "" # dir to save the generated dataset
9 | hardware:
10 | num_workers:
11 | eval: 8
12 | gpus: 0 # GPU idx (multi-gpu not supported for validation)
13 | batch_size:
14 | eval: 16
15 | training:
16 | precision: 16
17 | tta:
18 | enable: False
19 | hflip: True
20 | tflip: True
21 | use_gt: True # take GT labels on labeled frames, or still use pseudo labels
22 |
--------------------------------------------------------------------------------
/config/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - general
3 | - dataset: ???
4 | - model: rnndet
5 | - optional model/dataset: ${model}_${dataset}
6 |
7 | is_train: True
8 |
--------------------------------------------------------------------------------
/config/val.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: ???
3 | - model: rnndet
4 | - _self_
5 |
6 | is_train: False
7 | checkpoint: ???
8 | use_test_set: False
9 | reverse: False
10 | hardware:
11 | num_workers:
12 | eval: 8
13 | gpus: 0 # GPU idx (multi-gpu not supported for validation)
14 | batch_size:
15 | eval: 16
16 | training:
17 | precision: 16
18 | tta:
19 | enable: False
20 | hflip: True
21 | tflip: True
22 |
--------------------------------------------------------------------------------
/config/vis.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - dataset: ???
3 | - model: rnndet
4 | - _self_
5 |
6 | is_train: False
7 | checkpoint: ???
8 | num_video: 5
9 | reverse: False
10 | hardware:
11 | num_workers:
12 | eval: 0
13 | gpus: 0 # GPU idx (multi-gpu not supported for validation)
14 | batch_size:
15 | eval: 1
16 | training:
17 | precision: 16
18 | tta:
19 | enable: False
20 | hflip: False
21 | tflip: False
22 |
--------------------------------------------------------------------------------
/data/genx_utils/collate.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Any, Callable, Dict, Optional, Type, Tuple, Union
3 |
4 | import torch
5 |
6 | from data.genx_utils.collate_from_pytorch import collate, default_collate_fn_map
7 | from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels
8 | from data.utils.augmentor import AugmentationState
9 | from modules.utils.detection import WORKER_ID_KEY, DATA_KEY
10 |
11 |
12 | def collate_object_labels(batch, *, collate_fn_map: Optional[
13 | Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
14 | return batch
15 |
16 |
17 | def collate_sparsely_batched_object_labels(batch, *, collate_fn_map: Optional[
18 | Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
19 | return SparselyBatchedObjectLabels.transpose_list(batch)
20 |
21 |
22 | def collate_augm_state(batch, *, collate_fn_map: Optional[
23 | Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
24 | """Collates a batch of AugmentationState objects into a dictionary where
25 | each field of AugmentationState maps to a list of values.
26 |
27 | Returns: {
28 | 'h_flip': {'active': `B`-len list of bool},
29 | 'zoom_out': {
30 | 'active': [`B`-len list of bool],
31 | 'x0': [`B`-len list of int],
32 | 'y0': [`B`-len list of int],
33 | 'factor': [`B`-len list of float],
34 | },
35 | 'zoom_in': {
36 | 'active': [`B`-len list of bool],
37 | 'x0': [`B`-len list of int],
38 | 'y0': [`B`-len list of int],
39 | 'factor': [`B`-len list of float],
40 | },
41 | 'rotation': {
42 | 'active': [`B`-len list of bool],
43 | 'angle_deg': [`B`-len list of float],
44 | },
45 | }
46 | """
47 | return AugmentationState.collate_augm_state(batch)
48 |
49 |
50 | custom_collate_fn_map = copy.deepcopy(default_collate_fn_map)
51 | custom_collate_fn_map[ObjectLabels] = collate_object_labels
52 | custom_collate_fn_map[SparselyBatchedObjectLabels] = collate_sparsely_batched_object_labels
53 | custom_collate_fn_map[AugmentationState] = collate_augm_state
54 |
55 |
56 | def custom_collate(batch: Any):
57 | return collate(batch, collate_fn_map=custom_collate_fn_map)
58 |
59 |
60 | def custom_collate_rnd(batch: Any):
61 | samples = batch
62 | # NOTE: We do not really need the worker id for map style datasets (rnd) but we still provide the id for consistency
63 | worker_info = torch.utils.data.get_worker_info()
64 | local_worker_id = 0 if worker_info is None else worker_info.id
65 | return {
66 | DATA_KEY: custom_collate(samples),
67 | WORKER_ID_KEY: local_worker_id,
68 | }
69 |
70 |
71 | def custom_collate_streaming(batch: Any):
72 | """We assume that we receive a batch collected by a worker of our streaming datapipe
73 | """
74 | samples = batch[0]
75 | worker_id = batch[1]
76 | assert isinstance(worker_id, int)
77 | return {
78 | DATA_KEY: custom_collate(samples),
79 | WORKER_ID_KEY: worker_id,
80 | }
81 |
--------------------------------------------------------------------------------
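A small, self-contained sketch of how `custom_collate_rnd` plugs into a PyTorch `DataLoader`. The toy dataset below is a stand-in for the real map-style dataset in `data/genx_utils/dataset_rnd.py`; its fields are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from data.genx_utils.collate import custom_collate_rnd
from modules.utils.detection import DATA_KEY, WORKER_ID_KEY


class ToyDataset(Dataset):
    """Stand-in for the real map-style dataset (hypothetical fields)."""

    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return {'ev_repr': torch.zeros(20, 240, 304), 'idx': idx}


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=custom_collate_rnd)
batch = next(iter(loader))
# Tensors are batched as usual; ObjectLabels-like entries would be kept as Python lists.
print(batch[DATA_KEY]['ev_repr'].shape, batch[WORKER_ID_KEY])  # torch.Size([4, 20, 240, 304]) 0
```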
/data/genx_utils/splits/gen1/ssod_0.010-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.010-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen1/ssod_0.020-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.020-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen1/ssod_0.050-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.050-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen1/ssod_0.100-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.100-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen4/ssod_0.010-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.010-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen4/ssod_0.020-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.020-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen4/ssod_0.050-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.050-off0.pkl
--------------------------------------------------------------------------------
/data/genx_utils/splits/gen4/ssod_0.100-off0.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.100-off0.pkl
--------------------------------------------------------------------------------
/data/utils/misc.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, List
2 |
3 | import os
4 |
5 | import h5py
6 | import numpy as np
7 |
8 | from data.genx_utils.labels import ObjectLabelFactory, ObjectLabels
9 |
10 |
11 | def get_labels_npz_fn(seq_dir: str) -> str:
12 | """Get labels npz file name."""
13 | # seq_dir: path/to/dataset/train/18-03-29_13-15-02_5_605
14 | # labels_npz_fn: .../18-03-29_13-15-02_5_605/labels_v2/labels.npz
15 | labels_npz_fn = os.path.join(seq_dir, 'labels_v2/labels.npz')
16 | return labels_npz_fn
17 |
18 |
19 | def read_npz_labels(label_fn: str) -> Tuple[np.ndarray, np.ndarray]:
20 | """Read labels from npz file."""
21 | # if it's just .../18-03-29_13-15-02_5_605, we get the true npz first
22 | if 'labels_v2' not in label_fn:
23 | label_fn = get_labels_npz_fn(label_fn)
24 | labels = np.load(label_fn)
25 | return labels['labels'], labels['objframe_idx_2_label_idx']
26 |
27 |
28 | def read_labels_as_list(seq_dir: str, dst_cfg,
29 | L: int, start_idx: int = 0) -> List[ObjectLabels]:
30 | """Read the original full-frequency GT labels."""
31 | # seq_dir: .../18-03-29_13-15-02_5_605
32 | labels, objframe_idx_2_label_idx = read_npz_labels(seq_dir)
33 | hw = tuple(dst_cfg.ev_repr_hw)
34 | if dst_cfg.downsample_by_factor_2:
35 | hw = tuple(s * 2 for s in hw)
36 | label_factory = ObjectLabelFactory.from_structured_array(
37 | labels, objframe_idx_2_label_idx, hw,
38 | 2 if dst_cfg.downsample_by_factor_2 else None)
39 | ev_dir = get_ev_dir(seq_dir)
40 | objframe_idx_2_repr_idx = np.load(
41 | os.path.join(ev_dir, 'objframe_idx_2_repr_idx.npy'))
42 | obj_labels = [None] * L
43 | for objframe_idx, repr_idx in enumerate(objframe_idx_2_repr_idx):
44 | if start_idx <= repr_idx < start_idx + L:
45 | obj_labels[repr_idx - start_idx] = label_factory[objframe_idx]
46 | return obj_labels
47 |
48 |
49 | def get_ev_dir(seq_dir: str) -> str:
50 | """Get event representation directory."""
51 | # seq_dir: path/to/dataset/train/18-03-29_13-15-02_5_605
52 | # ev_dir: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/
53 | ev_dir = os.path.join(seq_dir, 'event_representations_v2',
54 | 'stacked_histogram_dt=50_nbins=10')
55 | return ev_dir
56 |
57 |
58 | def get_objframe_idx_2_repr_idx_fn(ev_dir: str) -> str:
59 | """Get xxx/objframe_idx_2_repr_idx.npy file name."""
60 | if 'event_representations_v2' not in ev_dir:
61 | ev_dir = get_ev_dir(ev_dir)
62 | # ev_dir: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/
63 | # fn: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/objframe_idx_2_repr_idx.npy
64 | fn = os.path.join(ev_dir, 'objframe_idx_2_repr_idx.npy')
65 | return fn
66 |
67 |
68 | def get_ev_h5_fn(ev_dir: str, dst_name: str = None) -> str:
69 | """Get event representation h5 file name."""
70 | # if it's just .../18-03-29_13-15-02_5_605, we get the true ev_dir first
71 | if 'event_representations_v2' not in ev_dir:
72 | ev_dir = get_ev_dir(ev_dir)
73 | # ev_dir: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/
74 | # ev_h5_fn: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/event_representations.h5
75 | if dst_name is None:
76 | dst_name = 'gen1' if 'gen1' in ev_dir else 'gen4'
77 | ev_name = 'event_representations.h5' if dst_name == 'gen1' else \
78 | 'event_representations_ds2_nearest.h5'
79 | ev_h5_fn = os.path.join(ev_dir, ev_name)
80 | return ev_h5_fn
81 |
82 |
83 | def read_ev_repr(h5f: str) -> np.ndarray:
84 | if 'event_representations_v2' not in h5f:
85 | h5f = get_ev_h5_fn(h5f)
86 | with h5py.File(str(h5f), 'r') as h5f:
87 | ev_repr = h5f['data'][:]
88 | return ev_repr
89 |
90 |
91 | def read_objframe_idx_2_repr_idx(npy_fn: str) -> np.ndarray:
92 | if 'event_representations_v2' not in npy_fn:
93 | npy_fn = get_objframe_idx_2_repr_idx_fn(npy_fn)
94 | return np.load(npy_fn)
95 |
--------------------------------------------------------------------------------
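The helpers above resolve file paths relative to a preprocessed sequence directory. A usage sketch, assuming the Gen1 data has been prepared under `./datasets/gen1/` with the layout described in the comments (the sequence name is just an example):

```python
from data.utils.misc import (read_npz_labels, read_ev_repr,
                             read_objframe_idx_2_repr_idx)

# Hypothetical sequence directory following the documented layout.
seq_dir = './datasets/gen1/train/18-03-29_13-15-02_5_605'

labels, objframe_idx_2_label_idx = read_npz_labels(seq_dir)  # from labels_v2/labels.npz
ev_repr = read_ev_repr(seq_dir)                    # stacked-histogram event representations
frame_map = read_objframe_idx_2_repr_idx(seq_dir)  # labeled frame -> event repr index
print(labels.shape, ev_repr.shape, frame_map.shape)
```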
/data/utils/spatial.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 |
3 | from data.utils.types import DatasetType
4 |
5 | _type_2_hw = {
6 | DatasetType.GEN1: (240, 304),
7 | DatasetType.GEN4: (720, 1280),
8 | }
9 |
10 | _str_2_type = {
11 | 'gen1': DatasetType.GEN1,
12 | 'gen4': DatasetType.GEN4,
13 | }
14 |
15 |
16 | def get_original_hw(dataset_type: DatasetType):
17 | return _type_2_hw[dataset_type]
18 |
19 |
20 | def get_dataloading_hw(dataset_config: DictConfig):
21 | dataset_name = dataset_config.name
22 | hw = get_original_hw(dataset_type=_str_2_type[dataset_name])
23 | downsample_by_factor_2 = dataset_config.downsample_by_factor_2
24 | if downsample_by_factor_2:
25 | hw = tuple(x // 2 for x in hw)
26 | return hw
27 |
--------------------------------------------------------------------------------
/data/utils/stream_concat_datapipe.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterator, List, Optional, Type
2 |
3 | import torch as th
4 | import torch.distributed as dist
5 | from torch.utils.data import DataLoader
6 | from torchdata.datapipes.iter import (
7 | Concater,
8 | IterableWrapper,
9 | IterDataPipe,
10 | Zipper,
11 | )
12 | from torchdata.datapipes.map import MapDataPipe
13 |
14 |
15 | class DummyIterDataPipe(IterDataPipe):
16 | def __init__(self, source_dp: IterDataPipe):
17 | super().__init__()
18 | assert isinstance(source_dp, IterDataPipe)
19 | self.source_dp = source_dp
20 |
21 | def __iter__(self):
22 | yield from self.source_dp
23 |
24 |
25 | class ConcatStreamingDataPipe(IterDataPipe):
26 | """This Dataset avoids the sharding problem by instantiating randomized stream concatenation at the batch and
27 | worker level.
28 | Pros:
29 | - Every single batch has valid samples. Consequently, the batch size is always constant.
30 | Cons:
31 | - There might be repeated samples in a batch, although they should differ because of data augmentation.
32 | - Cannot be used for validation or testing because we repeat the dataset multiple times in an epoch.
33 |
34 | TLDR: preferred approach for training but not useful for validation or testing.
35 | """
36 |
37 | def __init__(self,
38 | datapipe_list: List[MapDataPipe],
39 | batch_size: int,
40 | num_workers: int,
41 | augmentation_pipeline: Optional[Type[IterDataPipe]] = None,
42 | print_seed_debug: bool = False):
43 | super().__init__()
44 | assert batch_size > 0
45 |
46 | if augmentation_pipeline is not None:
47 | self.augmentation_dp = augmentation_pipeline
48 | else:
49 | self.augmentation_dp = DummyIterDataPipe
50 |
51 | # We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker.
52 | # Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy.
53 | self.datapipe_list = datapipe_list
54 | self.batch_size = batch_size
55 |
56 | self.print_seed_debug = print_seed_debug
57 |
58 | @staticmethod
59 | def random_torch_shuffle_list(data: List[Any]) -> Iterator[Any]:
60 | assert isinstance(data, List)
61 | return (data[idx] for idx in th.randperm(len(data)).tolist())
62 |
63 | def _get_zipped_streams(self, datapipe_list: List[MapDataPipe], batch_size: int):
64 | """Use it only in the iter function of this class!
65 | Reason: randomized shuffling must happen within each worker. Otherwise, the same random order will be used
66 | for all workers.
67 | """
68 | assert isinstance(datapipe_list, List)
69 | assert batch_size > 0
70 | # randomly shuffle datapipes and concat them as one datapipe
71 | # repeat it `BS` times so that we load `BS` samples per `next()` call
72 | streams = Zipper(*(Concater(*(self.augmentation_dp(x.to_iter_datapipe())
73 | for x in self.random_torch_shuffle_list(datapipe_list)))
74 | for _ in range(batch_size)))
75 | return streams
76 |
77 | def _print_seed_debug_info(self):
78 | """Debug purpose only."""
79 | worker_info = th.utils.data.get_worker_info()
80 | local_worker_id = 0 if worker_info is None else worker_info.id
81 |
82 | worker_torch_seed = None if worker_info is None else worker_info.seed  # None when not in a worker process
83 | local_num_workers = 1 if worker_info is None else worker_info.num_workers
84 | if dist.is_available() and dist.is_initialized():
85 | global_rank = dist.get_rank()
86 | else:
87 | global_rank = 0
88 | global_worker_id = global_rank * local_num_workers + local_worker_id
89 |
90 | rnd_number = th.randn(1)
91 | print(f'{worker_torch_seed=},\t{global_worker_id=},\t{global_rank=},\t{local_worker_id=},\t{rnd_number=}',
92 | flush=True)
93 |
94 | def _get_zipped_streams_with_worker_id(self):
95 | """Use it only in the iter function of this class!
96 | """
97 | worker_info = th.utils.data.get_worker_info()
98 | local_worker_id = 0 if worker_info is None else worker_info.id
99 | # always return self's worker_id
100 | worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None)
101 | # get ONE datapipe that loads a batch of data at every `next()` call
102 | zipped_stream = self._get_zipped_streams(datapipe_list=self.datapipe_list, batch_size=self.batch_size)
103 | return zipped_stream.zip(worker_id_stream)
104 |
105 | def __iter__(self):
106 | if self.print_seed_debug:
107 | self._print_seed_debug_info()
108 | return iter(self._get_zipped_streams_with_worker_id())
109 |
--------------------------------------------------------------------------------
/data/utils/types.py:
--------------------------------------------------------------------------------
1 | from enum import auto, Enum
2 |
3 | try:
4 | from enum import StrEnum
5 | except ImportError:
6 | from strenum import StrEnum
7 | from typing import Dict, List, Optional, Tuple, Union, Any
8 |
9 | import torch as th
10 |
11 | from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels
12 | # from data.utils.augmentor import AugmentationState # avoid circular import
13 |
14 |
15 | class DataType(Enum):
16 | PATH = auto() # 'path/to/dst/test/17-06-97_12-14-33_244500000_304500000'
17 | EV_IDX = auto() # index of the loaded ev_repr in the entire sequence
18 | EV_REPR = auto()
19 | FLOW = auto()
20 | IMAGE = auto()
21 | OBJLABELS = auto()
22 | OBJLABELS_SEQ = auto()
23 | SKIPPED_OBJLABELS_SEQ = auto() # GT labels that are skipped in SSOD
24 | IS_PADDED_MASK = auto()
25 | IS_FIRST_SAMPLE = auto()
26 | IS_LAST_SAMPLE = auto()
27 | IS_REVERSED = auto() # whether the sequence is in reverse order
28 | TOKEN_MASK = auto()
29 | PRED_MASK = auto() # whether the teacher makes a prediction at this frame
30 | GT_MASK = auto() # which labels are from GT, others are pseudo labels
31 | PRED_PROBS = auto() # obj/cls_probs predicted by the teacher model
32 | AUGM_STATE = auto() # augmentation state
33 |
34 |
35 | class DatasetType(Enum):
36 | GEN1 = auto()
37 | GEN4 = auto()
38 |
39 |
40 | class DatasetMode(Enum):
41 | TRAIN = auto()
42 | VALIDATION = auto()
43 | TESTING = auto()
44 |
45 |
46 | class DatasetSamplingMode(StrEnum):
47 | RANDOM = 'random'
48 | STREAM = 'stream'
49 | MIXED = 'mixed'
50 |
51 |
52 | class ObjDetOutput(Enum):
53 | LABELS_PROPH = auto()
54 | PRED_PROPH = auto()
55 | EV_REPR = auto()
56 | SKIP_VIZ = auto()
57 |
58 |
59 | LoaderDataDictGenX = Dict[DataType, Union[List[th.Tensor], ObjectLabels,
60 | SparselyBatchedObjectLabels,
61 | # AugmentationState,
62 | List[bool], ]]
63 |
64 | LstmState = Optional[Tuple[th.Tensor, th.Tensor]]
65 | LstmStates = List[LstmState]
66 |
67 | FeatureMap = th.Tensor
68 | BackboneFeatures = Dict[int, th.Tensor]
69 | BatchAugmState = Dict[str, Dict[str, List[Any]]]
70 |
--------------------------------------------------------------------------------
/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/datasets/.gitkeep
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
1 | # Install
2 |
3 | Most of this section is the same as [RVT](https://github.com/uzh-rpg/RVT).
4 | We also heavily rely on a personal package [nerv](https://github.com/Wuziyi616/nerv) for utility functions.
5 |
6 | ## Environment Setup
7 |
8 | Please use [Anaconda](https://www.anaconda.com/) for package management.
9 | ```Bash
10 | conda create -y -n leod python=3.9
11 | conda activate leod
12 |
13 | conda install -y pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia
14 |
15 | python -m pip install tqdm numba hdf5plugin h5py==3.8.0 \
16 | pandas==1.5.3 plotly==5.13.1 opencv-python==4.6.0.66 tabulate==0.9.0 \
17 | pycocotools==2.0.6 bbox-visualizer==0.1.0 StrEnum==0.4.10 \
18 | hydra-core==1.3.2 einops==0.6.0 \
19 | pytorch-lightning==1.8.6 wandb==0.14.0 torchdata==0.6.0
20 |
21 | conda install -y blosc-hdf5-plugin -c conda-forge
22 |
23 | # install nerv: https://github.com/Wuziyi616/nerv
24 | git clone git@github.com:Wuziyi616/nerv.git
25 | cd nerv
26 | git checkout v0.4.0 # tested with v0.4.0 release
27 | pip install -e .
28 | cd .. # go back to the root directory of the project
29 |
30 | # (Optional) compile Detectron2 for faster evaluation
31 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
32 | ```
33 |
34 | We also provide an `environment.yml` file that lists all the packages in our final environment.
35 |
36 | I sometimes encounter a weird bug where `Detectron2` cannot run on a GPU model different from the one I compiled it on (e.g., if I compile it on RTX6000 GPUs, I cannot use it on A40 GPUs).
37 | To avoid this issue, go to [coco_eval.py](../utils/evaluation/prophesee/metrics/coco_eval.py#L17) and set `compile_gpu` to the GPU you compiled it on (the program will not import `Detectron2` when it detects a different GPU in use).
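
Conceptually, the guard works like the sketch below. This is only an illustration, not the actual contents of `coco_eval.py`; apart from the `compile_gpu` name, everything here (imports, fallback choice) is an assumption.

```Python
import torch

compile_gpu = 'RTX 6000'  # set to the GPU model you compiled Detectron2 on

# only use the compiled Detectron2 evaluator when running on the same GPU model
if torch.cuda.is_available() and compile_gpu in torch.cuda.get_device_name(0):
    from detectron2.evaluation.fast_eval_api import COCOeval_opt as COCOeval
else:
    from pycocotools.cocoeval import COCOeval  # slower pure-Python fallback
```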
38 |
39 | ## Dataset
40 |
41 | In this project, we use two datasets: Gen1 and 1Mpx.
42 | Following the convention of RVT, we name Gen1 as `gen1` and 1Mpx as `gen4` (because of the camera used to capture them).
43 | Please download the pre-processed datasets from RVT:
44 |
45 |
46 | |  | 1 Mpx | Gen1 |
47 | | :-- | :-- | :-- |
48 | | pre-processed dataset | download | download |
49 | | crc32 | c5ec7c38 | 5acab6f3 |
57 |
58 |
59 | After downloading and unzipping the datasets, soft link Gen1 to `./datasets/gen1` and 1Mpx to `./datasets/gen4`.
60 |
61 | ### Data Splits
62 |
63 | To simulate the weakly-/semi-supervised learning settings, we need to sub-sample labels from the original dataset.
64 | It is important to keep the data split identical across experiments.
65 | - For the semi-supervised setting, where we keep the labels for some sequences and leave other sequences completely unlabeled, this is relatively easy.
66 | We simply sort the event sequences by name so that their order is deterministic across runs, and pick the unlabeled sequences from that sorted list.
67 | - For the weakly-supervised setting, where we sub-sample the labels within every sequence, it is a bit trickier because the codebase has two data sampling modes that pre-process events in different ways.
68 | To keep the split consistent, we create a split file for each setting; they are stored [here](../data/genx_utils/splits/).
69 | If you want to explore new experimental settings, remember to create your own split files and read from them [here](../data/genx_utils/dataset_streaming.py#L62) (a rough sketch follows this list).
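
The snippet below is only a rough sketch of the sequence-selection idea from the first bullet; the real split files under [splits](../data/genx_utils/splits/) may store a different structure, and `make_ssod_split` plus all paths here are made-up names for illustration.

```Python
import os
import pickle


def make_ssod_split(train_dir, ratio=0.01, offset=0, out_fn='my_split.pkl'):
    # sort sequence names so the selection is deterministic across runs
    seqs = sorted(os.listdir(train_dir))
    step = round(1. / ratio)
    labeled = seqs[offset::step]  # keep labels only for every `step`-th sequence
    with open(out_fn, 'wb') as f:
        pickle.dump(labeled, f)


make_ssod_split('./datasets/gen1/train', ratio=0.01, offset=0)
```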
70 |
71 | All results in the paper are averaged over three different splits (we offset the index when sub-sampling the data).
72 | Overall, the performance variations are very small across different splits.
73 | Therefore, we only release the split files, config files, and pre-trained weights for the first variant we experimented with.
74 |
75 | ## Pre-trained Weights
76 |
77 | We provide checkpoints for all the models used to produce the final performance in the paper.
78 | In addition, we provide models pre-trained on the limited annotated data (the `Supervised Baseline` method in the paper) to ease your experiments.
79 |
80 | Please download the pre-trained weights from [Google Drive](https://drive.google.com/file/d/1xBzFovvNbrtBt0YwYcvvrjbV8ozAdCUK/view?usp=sharing) and unzip them to `./pretrained/`.
81 | The weights are grouped by the section of the paper in which they are presented.
82 | The naming follows the pattern `rvt-{$MODEL_SIZE}-{$DATASET}x{$RATIO_OF_DATA}_{$SETTING}.ckpt`.
83 |
84 | For example, `rvt-s-gen1x0.02_ss.ckpt` is the RVT-S pre-trained on 2% of Gen1 data under the weakly-supervised setting.
85 | `rvt-s-gen4x0.05_ss-final.ckpt` is the RVT-S trained on 5% of 1Mpx data under the weakly-supervised setting, and `-final` means it is the LEOD self-trained model (used to produce the results in the paper).
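
If it helps, the fields can be read off programmatically with a throwaway helper like the one below (not part of the codebase, just a worked example of the naming pattern):

```Python
import re

def parse_ckpt_name(fn):
    # rvt-{size}-{dataset}x{ratio}_{setting}[-final].ckpt
    m = re.match(r'rvt-(\w+)-(gen\d)x([\d.]+)_(\w+?)(-final)?\.ckpt', fn)
    size, dataset, ratio, setting, final = m.groups()
    return dict(size=size, dataset=dataset, ratio=float(ratio),
                setting=setting, self_trained=final is not None)

print(parse_ckpt_name('rvt-s-gen1x0.02_ss.ckpt'))
```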
86 |
87 | **Note:** it might be a bit confusing, but `ss` means weakly-supervised (all event sequences are sparsely labeled) and `seq` means semi-supervised (some event sequences are densely labeled, while others are completely unlabeled).
88 |
--------------------------------------------------------------------------------
/loggers/utils.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 | from pytorch_lightning.loggers.wandb import WandbLogger
3 |
4 |
5 | def get_wandb_logger(full_config: DictConfig) -> WandbLogger:
6 | """Build the native PyTorch Lightning WandB logger."""
7 | wandb_config = full_config.wandb
8 |
9 | logger = WandbLogger(
10 | project=wandb_config.project_name,
11 | name=wandb_config.wandb_name,
12 | id=wandb_config.wandb_id,
13 | # specify both to make sure it's saved in the right place
14 | save_dir=wandb_config.wandb_runpath,
15 | dir=wandb_config.wandb_runpath,
16 | # group=wandb_config.group_name, # not very useful
17 | # log_model=True, # don't log model weights
18 | # save_last_only_final=False,
19 | # save_code=True,
20 | # config_args=full_config_dict,
21 | )
22 |
23 | return logger
24 |
--------------------------------------------------------------------------------
/models/detection/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/detection/__init__.py
--------------------------------------------------------------------------------
/models/detection/recurrent_backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from omegaconf import DictConfig
2 |
3 | from .maxvit_rnn import RNNDetector as MaxViTRNNDetector
4 |
5 |
6 | def build_recurrent_backbone(backbone_cfg: DictConfig):
7 | name = backbone_cfg.name
8 | if name == 'MaxViTRNN':
9 | return MaxViTRNNDetector(backbone_cfg)
10 | else:
11 | raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/models/detection/recurrent_backbone/base.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class BaseDetector(nn.Module):
7 | def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
8 | raise NotImplementedError
9 |
10 | def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]:
11 | raise NotImplementedError
12 |
--------------------------------------------------------------------------------
/models/detection/yolox/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/detection/yolox/models/__init__.py
--------------------------------------------------------------------------------
/models/detection/yolox/models/losses.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Copyright (c) Megvii Inc. All rights reserved.
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.ops as ops
8 |
9 |
10 | class IOUloss(nn.Module):
11 | """IoU loss for bbox regression."""
12 |
13 | def __init__(self, reduction="none", loss_type="iou"):
14 | super(IOUloss, self).__init__()
15 | self.reduction = reduction
16 | self.loss_type = loss_type
17 |
18 | def forward(self, pred, target, weights=None):
19 | assert pred.shape[0] == target.shape[0]
20 | if pred.shape[0] == 0:
21 | return 0.
22 |
23 | pred = pred.view(-1, 4)
24 | target = target.view(-1, 4)
25 | tl = torch.max(
26 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
27 | )
28 | br = torch.min(
29 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
30 | )
31 |
32 | area_p = torch.prod(pred[:, 2:], 1)
33 | area_g = torch.prod(target[:, 2:], 1)
34 |
35 | en = (tl < br).type(tl.type()).prod(dim=1)
36 | area_i = torch.prod(br - tl, 1) * en
37 | area_u = area_p + area_g - area_i
38 | iou = (area_i) / (area_u + 1e-16)
39 |
40 | if self.loss_type == "iou":
41 | loss = 1 - iou ** 2
42 | elif self.loss_type == "giou":
43 | c_tl = torch.min(
44 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
45 | )
46 | c_br = torch.max(
47 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
48 | )
49 | area_c = torch.prod(c_br - c_tl, 1)
50 | giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
51 | loss = 1 - giou.clamp(min=-1.0, max=1.0)
52 | else:
53 | raise NotImplementedError
54 |
55 | # weighted IoU loss
56 | if weights is not None and isinstance(weights, torch.Tensor) and \
57 | (weights != 1.).any():
58 | assert weights.shape[0] == loss.shape[0]
59 | loss = loss * weights
60 |
61 | if self.reduction == "mean":
62 | loss = loss.mean()
63 | elif self.reduction == "sum":
64 | loss = loss.sum()
65 |
66 | return loss
67 |
68 |
69 | class FocalLoss(nn.Module):
70 | """Focal loss for foreground-background classification (objectness)."""
71 |
72 | def __init__(self, alpha=0.25, gamma=2, reduction='none'):
73 | super().__init__()
74 |
75 | self.alpha = alpha
76 | self.gamma = gamma
77 | self.reduction = reduction
78 |
79 | def forward(self, inputs, targets):
80 | return ops.sigmoid_focal_loss(
81 | inputs=inputs,
82 | targets=targets,
83 | alpha=self.alpha,
84 | gamma=self.gamma,
85 | reduction=self.reduction)
86 |
--------------------------------------------------------------------------------
/models/detection/yolox/models/network_blocks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- encoding: utf-8 -*-
3 | # Copyright (c) Megvii Inc. All rights reserved.
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class SiLU(nn.Module):
10 | """export-friendly version of nn.SiLU()"""
11 |
12 | @staticmethod
13 | def forward(x):
14 | return x * torch.sigmoid(x)
15 |
16 |
17 | def get_activation(name="silu", inplace=True):
18 | if name == "silu":
19 | module = nn.SiLU(inplace=inplace)
20 | elif name == "relu":
21 | module = nn.ReLU(inplace=inplace)
22 | elif name == "lrelu":
23 | module = nn.LeakyReLU(0.1, inplace=inplace)
24 | else:
25 | raise AttributeError("Unsupported act type: {}".format(name))
26 | return module
27 |
28 |
29 | class BaseConv(nn.Module):
30 | """A Conv2d -> Batchnorm -> silu/leaky relu block"""
31 |
32 | def __init__(
33 | self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"
34 | ):
35 | super().__init__()
36 | # same padding
37 | pad = (ksize - 1) // 2
38 | self.conv = nn.Conv2d(
39 | in_channels,
40 | out_channels,
41 | kernel_size=ksize,
42 | stride=stride,
43 | padding=pad,
44 | groups=groups,
45 | bias=bias,
46 | )
47 | self.bn = nn.BatchNorm2d(out_channels)
48 | self.act = get_activation(act, inplace=True)
49 |
50 | def forward(self, x):
51 | return self.act(self.bn(self.conv(x)))
52 |
53 | def fuseforward(self, x):
54 | return self.act(self.conv(x))
55 |
56 |
57 | class DWConv(nn.Module):
58 | """Depthwise Conv + Conv"""
59 |
60 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
61 | super().__init__()
62 | self.dconv = BaseConv(
63 | in_channels,
64 | in_channels,
65 | ksize=ksize,
66 | stride=stride,
67 | groups=in_channels,
68 | act=act,
69 | )
70 | self.pconv = BaseConv(
71 | in_channels, out_channels, ksize=1, stride=1, groups=1, act=act
72 | )
73 |
74 | def forward(self, x):
75 | x = self.dconv(x)
76 | return self.pconv(x)
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | # Standard bottleneck
81 | def __init__(
82 | self,
83 | in_channels,
84 | out_channels,
85 | shortcut=True,
86 | expansion=0.5,
87 | depthwise=False,
88 | act="silu",
89 | ):
90 | super().__init__()
91 | hidden_channels = int(out_channels * expansion)
92 | Conv = DWConv if depthwise else BaseConv
93 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
94 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act)
95 | self.use_add = shortcut and in_channels == out_channels
96 |
97 | def forward(self, x):
98 | y = self.conv2(self.conv1(x))
99 | if self.use_add:
100 | y = y + x
101 | return y
102 |
103 |
104 | class CSPLayer(nn.Module):
105 | """C3 in yolov5, CSP Bottleneck with 3 convolutions"""
106 |
107 | def __init__(
108 | self,
109 | in_channels,
110 | out_channels,
111 | n=1,
112 | shortcut=True,
113 | expansion=0.5,
114 | depthwise=False,
115 | act="silu",
116 | ):
117 | """
118 | Args:
119 | in_channels (int): input channels.
120 | out_channels (int): output channels.
121 | n (int): number of Bottlenecks. Default value: 1.
122 | """
123 | # ch_in, ch_out, number, shortcut, groups, expansion
124 | super().__init__()
125 | hidden_channels = int(out_channels * expansion) # hidden channels
126 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
127 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act)
128 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act)
129 | module_list = [
130 | Bottleneck(
131 | hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act
132 | )
133 | for _ in range(n)
134 | ]
135 | self.m = nn.Sequential(*module_list)
136 |
137 | def forward(self, x):
138 | x_1 = self.conv1(x)
139 | x_2 = self.conv2(x)
140 | x_1 = self.m(x_1)
141 | x = torch.cat((x_1, x_2), dim=1)
142 | return self.conv3(x)
143 |
--------------------------------------------------------------------------------
/models/detection/yolox/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | # Copyright (c) Megvii Inc. All rights reserved.
4 |
5 | from .boxes import *
6 | from .compat import meshgrid
7 |
--------------------------------------------------------------------------------
/models/detection/yolox/utils/compat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 |
4 | import torch
5 |
6 | _TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]]
7 |
8 | __all__ = ["meshgrid"]
9 |
10 |
11 | def meshgrid(*tensors):
12 | if _TORCH_VER >= [1, 10]:
13 | return torch.meshgrid(*tensors, indexing="ij")
14 | else:
15 | return torch.meshgrid(*tensors)
16 |
--------------------------------------------------------------------------------
/models/detection/yolox_extension/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/detection/yolox_extension/models/__init__.py
--------------------------------------------------------------------------------
/models/detection/yolox_extension/models/build.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from omegaconf import OmegaConf, DictConfig
4 |
5 | from .yolo_pafpn import YOLOPAFPN
6 | from ...yolox.models.yolo_head import YOLOXHead
7 |
8 |
9 | def build_yolox_head(head_cfg: DictConfig, in_channels: Tuple[int, ...], strides: Tuple[int, ...], ssod: bool = False):
10 | assert not ssod # legacy code already removed, the flag is useless now
11 | head_cfg_dict = OmegaConf.to_container(head_cfg, resolve=True, throw_on_missing=True)
12 | head_cfg_dict.pop('name')
13 | head_cfg_dict.pop('version', None)
14 | head_cfg_dict.update({"in_channels": in_channels})
15 | head_cfg_dict.update({"strides": strides})
16 | compile_cfg = head_cfg_dict.pop('compile', None)
17 | head_cfg_dict.update({"compile_cfg": compile_cfg})
18 | module = YOLOXHead
19 | return module(**head_cfg_dict)
20 |
21 |
22 | def build_yolox_fpn(fpn_cfg: DictConfig, in_channels: Tuple[int, ...]):
23 | fpn_cfg_dict = OmegaConf.to_container(fpn_cfg, resolve=True, throw_on_missing=True)
24 | fpn_name = fpn_cfg_dict.pop('name')
25 | fpn_cfg_dict.update({"in_channels": in_channels})
26 | if fpn_name in {'PAFPN', 'pafpn'}:
27 | compile_cfg = fpn_cfg_dict.pop('compile', None)
28 | fpn_cfg_dict.update({"compile_cfg": compile_cfg})
29 | return YOLOPAFPN(**fpn_cfg_dict)
30 | raise NotImplementedError
31 |
--------------------------------------------------------------------------------
/models/detection/yolox_extension/models/detector.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple, Union
2 |
3 | import torch as th
4 | from omegaconf import DictConfig
5 |
6 | try:
7 | from torch import compile as th_compile
8 | except ImportError:
9 | th_compile = None
10 |
11 | from ...recurrent_backbone import build_recurrent_backbone
12 | from .build import build_yolox_fpn, build_yolox_head
13 | from utils.timers import TimerDummy as CudaTimer
14 | # from utils.timers import CudaTimer
15 | from data.utils.types import BackboneFeatures, LstmStates
16 |
17 |
18 | class YoloXDetector(th.nn.Module):
19 | """RNN-based MaxViT backbone + YOLOX detection head."""
20 |
21 | def __init__(self, model_cfg: DictConfig, ssod: bool = False):
22 | super().__init__()
23 | backbone_cfg = model_cfg.backbone
24 | fpn_cfg = model_cfg.fpn
25 | head_cfg = model_cfg.head
26 |
27 | self.backbone = build_recurrent_backbone(backbone_cfg) # maxvit_rnn
28 |
29 | in_channels = self.backbone.get_stage_dims(fpn_cfg.in_stages)
30 | self.fpn = build_yolox_fpn(fpn_cfg, in_channels=in_channels)
31 |
32 | strides = self.backbone.get_strides(fpn_cfg.in_stages)
33 | self.yolox_head = build_yolox_head(head_cfg, in_channels=in_channels, strides=strides, ssod=ssod)
34 |
35 | def forward_backbone(self,
36 | x: th.Tensor,
37 | previous_states: Optional[LstmStates] = None,
38 | token_mask: Optional[th.Tensor] = None) -> \
39 | Tuple[BackboneFeatures, LstmStates]:
40 | """Extract multi-stage features from the backbone.
41 |
42 | Input:
43 | x: (B, C, H, W), image
44 | previous_states: List[(lstm_h, lstm_c)], RNN states from prev timestep
45 | token_mask: (B, H, W) or None, pixel padding mask
46 |
47 | Returns:
48 | backbone_features: Dict{stage_id: feats, [B, C, h, w]}, multi-stage
49 | states: List[(lstm_h, lstm_c), same shape], RNN state of each stage
50 | """
51 | with CudaTimer(device=x.device, timer_name="Backbone"):
52 | backbone_features, states = self.backbone(x, previous_states, token_mask)
53 | return backbone_features, states
54 |
55 | def forward_detect(self,
56 | backbone_features: BackboneFeatures,
57 | targets: Optional[th.Tensor] = None,
58 | soft_targets: Optional[th.Tensor] = None) -> \
59 | Tuple[th.Tensor, Union[Dict[str, th.Tensor], None]]:
60 | """Predict object bbox from multi-stage features.
61 |
62 | Returns:
63 | outputs: (B, N, 4 + 1 + num_cls), [(x, y, w, h), obj_conf, cls]
64 | losses: Dict{loss_name: loss, torch.scalar tensor}
65 | """
66 | device = next(iter(backbone_features.values())).device
67 | with CudaTimer(device=device, timer_name="FPN"):
68 | fpn_features = self.fpn(backbone_features) # Tuple(feats, [B, C, h, w])
69 | if self.training:
70 | assert targets is not None
71 | with CudaTimer(device=device, timer_name="HEAD + Loss"):
72 | outputs, losses = self.yolox_head(fpn_features, targets, soft_targets)
73 | return outputs, losses
74 | with CudaTimer(device=device, timer_name="HEAD"):
75 | outputs, losses = self.yolox_head(fpn_features)
76 | assert losses is None
77 | return outputs, losses
78 |
79 | def forward(self,
80 | x: th.Tensor,
81 | previous_states: Optional[LstmStates] = None,
82 | retrieve_detections: bool = True,
83 | targets: Optional[th.Tensor] = None) -> \
84 | Tuple[Union[th.Tensor, None], Union[Dict[str, th.Tensor], None], LstmStates]:
85 | backbone_features, states = self.forward_backbone(x, previous_states)
86 | outputs, losses = None, None
87 | if not retrieve_detections:
88 | assert targets is None
89 | return outputs, losses, states
90 | outputs, losses = self.forward_detect(backbone_features=backbone_features, targets=targets)
91 | return outputs, losses, states
92 |
--------------------------------------------------------------------------------
/models/detection/yolox_extension/models/yolo_pafpn.py:
--------------------------------------------------------------------------------
1 | """
2 | Original Yolox PAFPN code with slight modifications
3 | """
4 | from typing import Dict, Optional, Tuple
5 |
6 | import torch as th
7 | import torch.nn as nn
8 |
9 | try:
10 | from torch import compile as th_compile
11 | except ImportError:
12 | th_compile = None
13 |
14 | from ...yolox.models.network_blocks import BaseConv, CSPLayer, DWConv
15 | from data.utils.types import BackboneFeatures
16 |
17 |
18 | class YOLOPAFPN(nn.Module):
19 | """
20 | Removed the direct dependency on the backbone.
21 | """
22 |
23 | def __init__(
24 | self,
25 | depth: float = 1.0,
26 | in_stages: Tuple[int, ...] = (2, 3, 4),
27 | in_channels: Tuple[int, ...] = (256, 512, 1024),
28 | depthwise: bool = False,
29 | act: str = "silu",
30 | compile_cfg: Optional[Dict] = None,
31 | ):
32 | super().__init__()
33 | assert len(in_stages) == len(in_channels)
34 | assert len(in_channels) == 3, 'Current implementation only for 3 feature maps'
35 | self.in_features = in_stages
36 | self.in_channels = in_channels
37 | Conv = DWConv if depthwise else BaseConv
38 |
39 | ###### Compile if requested ######
40 | if compile_cfg is not None:
41 | compile_mdl = compile_cfg['enable']
42 | if compile_mdl and th_compile is not None:
43 | self.forward = th_compile(self.forward, **compile_cfg['args'])
44 | elif compile_mdl:
45 | print('Could not compile PAFPN because torch.compile is not available')
46 |
47 | ##################################
48 |
49 | self.upsample = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest-exact')
50 | self.lateral_conv0 = BaseConv(
51 | in_channels[2], in_channels[1], 1, 1, act=act
52 | )
53 | self.C3_p4 = CSPLayer(
54 | 2 * in_channels[1],
55 | in_channels[1],
56 | round(3 * depth),
57 | False,
58 | depthwise=depthwise,
59 | act=act,
60 | ) # cat
61 |
62 | self.reduce_conv1 = BaseConv(
63 | in_channels[1], in_channels[0], 1, 1, act=act
64 | )
65 | self.C3_p3 = CSPLayer(
66 | 2 * in_channels[0],
67 | in_channels[0],
68 | round(3 * depth),
69 | False,
70 | depthwise=depthwise,
71 | act=act,
72 | )
73 |
74 | # bottom-up conv
75 | self.bu_conv2 = Conv(
76 | in_channels[0], in_channels[0], 3, 2, act=act
77 | )
78 | self.C3_n3 = CSPLayer(
79 | 2 * in_channels[0],
80 | in_channels[1],
81 | round(3 * depth),
82 | False,
83 | depthwise=depthwise,
84 | act=act,
85 | )
86 |
87 | # bottom-up conv
88 | self.bu_conv1 = Conv(
89 | in_channels[1], in_channels[1], 3, 2, act=act
90 | )
91 | self.C3_n4 = CSPLayer(
92 | 2 * in_channels[1],
93 | in_channels[2],
94 | round(3 * depth),
95 | False,
96 | depthwise=depthwise,
97 | act=act,
98 | )
99 |
100 | ###### Compile if requested ######
101 | if compile_cfg is not None:
102 | compile_mdl = compile_cfg['enable']
103 | if compile_mdl and th_compile is not None:
104 | self.forward = th_compile(self.forward, **compile_cfg['args'])
105 | elif compile_mdl:
106 | print('Could not compile PAFPN because torch.compile is not available')
107 | ##################################
108 |
109 | def forward(self, input: BackboneFeatures):
110 | """
111 | Args:
112 | input: Feature maps from the backbone
113 |
114 | Returns:
115 | Tuple[Tensor]: FPN features.
116 | """
117 | features = [input[f] for f in self.in_features]
118 | x2, x1, x0 = features
119 |
120 | # channel/downsample_stride
121 | fpn_out0 = self.lateral_conv0(x0) # 1024->512/32
122 | f_out0 = self.upsample(fpn_out0) # 512/16
123 | f_out0 = th.cat([f_out0, x1], 1) # 512->1024/16
124 | f_out0 = self.C3_p4(f_out0) # 1024->512/16
125 |
126 | fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16
127 | f_out1 = self.upsample(fpn_out1) # 256/8
128 | f_out1 = th.cat([f_out1, x2], 1) # 256->512/8
129 | pan_out2 = self.C3_p3(f_out1) # 512->256/8
130 |
131 | p_out1 = self.bu_conv2(pan_out2) # 256->256/16
132 | p_out1 = th.cat([p_out1, fpn_out1], 1) # 256->512/16
133 | pan_out1 = self.C3_n3(p_out1) # 512->512/16
134 |
135 | p_out0 = self.bu_conv1(pan_out1) # 512->512/32
136 | p_out0 = th.cat([p_out0, fpn_out0], 1) # 512->1024/32
137 | pan_out0 = self.C3_n4(p_out0) # 1024->1024/32
138 |
139 | outputs = (pan_out2, pan_out1, pan_out0)
140 | return outputs
141 |
--------------------------------------------------------------------------------
/models/layers/maxvit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/layers/maxvit/__init__.py
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .activations import *
2 | from .adaptive_avgmax_pool import \
3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
4 | from .blur_pool import BlurPool2d
5 | from .classifier import ClassifierHead, create_classifier
6 | from .cond_conv2d import CondConv2d, get_condconv_initializer
7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
8 | set_layer_config
9 | from .conv2d_same import Conv2dSame, conv2d_same
10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
11 | from .create_act import create_act_layer, get_act_layer, get_act_fn
12 | from .create_attn import get_attn, create_attn
13 | from .create_conv2d import create_conv2d
14 | from .create_norm import get_norm_layer, create_norm_layer
15 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
17 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
18 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
19 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
20 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
21 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
22 | from .gather_excite import GatherExcite
23 | from .global_context import GlobalContext
24 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
25 | from .inplace_abn import InplaceAbn
26 | from .linear import Linear
27 | from .mixed_conv2d import MixedConv2d
28 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
29 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
30 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
31 | from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
32 | from .padding import get_padding, get_same_padding, pad_same
33 | from .patch_embed import PatchEmbed
34 | from .pool2d_same import AvgPool2dSame, create_pool2d
35 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
36 | from .selective_kernel import SelectiveKernel
37 | from .separable_conv import SeparableConv2d, SeparableConvNormAct
38 | from .space_to_depth import SpaceToDepthModule
39 | from .split_attn import SplitAttn
40 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
41 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
42 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool
43 | from .trace_utils import _assert, _float_to_int
44 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
45 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/activations.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 |
9 | import torch
10 | from torch import nn as nn
11 | from torch.nn import functional as F
12 |
13 |
14 | def swish(x, inplace: bool = False):
15 | """Swish - Described in: https://arxiv.org/abs/1710.05941
16 | """
17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
18 |
19 |
20 | class Swish(nn.Module):
21 | def __init__(self, inplace: bool = False):
22 | super(Swish, self).__init__()
23 | self.inplace = inplace
24 |
25 | def forward(self, x):
26 | return swish(x, self.inplace)
27 |
28 |
29 | def mish(x, inplace: bool = False):
30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
31 | NOTE: I don't have a working inplace variant
32 | """
33 | return x.mul(F.softplus(x).tanh())
34 |
35 |
36 | class Mish(nn.Module):
37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
38 | """
39 | def __init__(self, inplace: bool = False):
40 | super(Mish, self).__init__()
41 |
42 | def forward(self, x):
43 | return mish(x)
44 |
45 |
46 | def sigmoid(x, inplace: bool = False):
47 | return x.sigmoid_() if inplace else x.sigmoid()
48 |
49 |
50 | # PyTorch has this, but not with a consistent inplace argument interface
51 | class Sigmoid(nn.Module):
52 | def __init__(self, inplace: bool = False):
53 | super(Sigmoid, self).__init__()
54 | self.inplace = inplace
55 |
56 | def forward(self, x):
57 | return x.sigmoid_() if self.inplace else x.sigmoid()
58 |
59 |
60 | def tanh(x, inplace: bool = False):
61 | return x.tanh_() if inplace else x.tanh()
62 |
63 |
64 | # PyTorch has this, but not with a consistent inplace argument interface
65 | class Tanh(nn.Module):
66 | def __init__(self, inplace: bool = False):
67 | super(Tanh, self).__init__()
68 | self.inplace = inplace
69 |
70 | def forward(self, x):
71 | return x.tanh_() if self.inplace else x.tanh()
72 |
73 |
74 | def hard_swish(x, inplace: bool = False):
75 | inner = F.relu6(x + 3.).div_(6.)
76 | return x.mul_(inner) if inplace else x.mul(inner)
77 |
78 |
79 | class HardSwish(nn.Module):
80 | def __init__(self, inplace: bool = False):
81 | super(HardSwish, self).__init__()
82 | self.inplace = inplace
83 |
84 | def forward(self, x):
85 | return hard_swish(x, self.inplace)
86 |
87 |
88 | def hard_sigmoid(x, inplace: bool = False):
89 | if inplace:
90 | return x.add_(3.).clamp_(0., 6.).div_(6.)
91 | else:
92 | return F.relu6(x + 3.) / 6.
93 |
94 |
95 | class HardSigmoid(nn.Module):
96 | def __init__(self, inplace: bool = False):
97 | super(HardSigmoid, self).__init__()
98 | self.inplace = inplace
99 |
100 | def forward(self, x):
101 | return hard_sigmoid(x, self.inplace)
102 |
103 |
104 | def hard_mish(x, inplace: bool = False):
105 | """ Hard Mish
106 | Experimental, based on notes by Mish author Diganta Misra at
107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
108 | """
109 | if inplace:
110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2))
111 | else:
112 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
113 |
114 |
115 | class HardMish(nn.Module):
116 | def __init__(self, inplace: bool = False):
117 | super(HardMish, self).__init__()
118 | self.inplace = inplace
119 |
120 | def forward(self, x):
121 | return hard_mish(x, self.inplace)
122 |
123 |
124 | class PReLU(nn.PReLU):
125 | """Applies PReLU (w/ dummy inplace arg)
126 | """
127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
129 |
130 | def forward(self, input: torch.Tensor) -> torch.Tensor:
131 | return F.prelu(input, self.weight)
132 |
133 |
134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
135 | return F.gelu(x)
136 |
137 |
138 | class GELU(nn.Module):
139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
140 | """
141 | def __init__(self, inplace: bool = False):
142 | super(GELU, self).__init__()
143 |
144 | def forward(self, input: torch.Tensor) -> torch.Tensor:
145 | return F.gelu(input)
146 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/activations_jit.py:
--------------------------------------------------------------------------------
1 | """ Activations
2 |
3 | A collection of jit-scripted activations fn and modules with a common interface so that they can
4 | easily be swapped. All have an `inplace` arg even if not used.
5 |
6 | All jit-scripted activations intentionally lack in-place variants: scripted kernel fusion does not
7 | currently work across in-place op boundaries, so performance would be equal to or worse than the
8 | non-scripted versions if they contained in-place ops.
9 |
10 | Hacked together by / Copyright 2020 Ross Wightman
11 | """
12 |
13 | import torch
14 | from torch import nn as nn
15 | from torch.nn import functional as F
16 |
17 |
18 | @torch.jit.script
19 | def swish_jit(x, inplace: bool = False):
20 | """Swish - Described in: https://arxiv.org/abs/1710.05941
21 | """
22 | return x.mul(x.sigmoid())
23 |
24 |
25 | @torch.jit.script
26 | def mish_jit(x, _inplace: bool = False):
27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
28 | """
29 | return x.mul(F.softplus(x).tanh())
30 |
31 |
32 | class SwishJit(nn.Module):
33 | def __init__(self, inplace: bool = False):
34 | super(SwishJit, self).__init__()
35 |
36 | def forward(self, x):
37 | return swish_jit(x)
38 |
39 |
40 | class MishJit(nn.Module):
41 | def __init__(self, inplace: bool = False):
42 | super(MishJit, self).__init__()
43 |
44 | def forward(self, x):
45 | return mish_jit(x)
46 |
47 |
48 | @torch.jit.script
49 | def hard_sigmoid_jit(x, inplace: bool = False):
50 | # return F.relu6(x + 3.) / 6.
51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
52 |
53 |
54 | class HardSigmoidJit(nn.Module):
55 | def __init__(self, inplace: bool = False):
56 | super(HardSigmoidJit, self).__init__()
57 |
58 | def forward(self, x):
59 | return hard_sigmoid_jit(x)
60 |
61 |
62 | @torch.jit.script
63 | def hard_swish_jit(x, inplace: bool = False):
64 | # return x * (F.relu6(x + 3.) / 6)
65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster?
66 |
67 |
68 | class HardSwishJit(nn.Module):
69 | def __init__(self, inplace: bool = False):
70 | super(HardSwishJit, self).__init__()
71 |
72 | def forward(self, x):
73 | return hard_swish_jit(x)
74 |
75 |
76 | @torch.jit.script
77 | def hard_mish_jit(x, inplace: bool = False):
78 | """ Hard Mish
79 | Experimental, based on notes by Mish author Diganta Misra at
80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
81 | """
82 | return 0.5 * x * (x + 2).clamp(min=0, max=2)
83 |
84 |
85 | class HardMishJit(nn.Module):
86 | def __init__(self, inplace: bool = False):
87 | super(HardMishJit, self).__init__()
88 |
89 | def forward(self, x):
90 | return hard_mish_jit(x)
91 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/adaptive_avgmax_pool.py:
--------------------------------------------------------------------------------
1 | """ PyTorch selectable adaptive pooling
2 | Adaptive pooling with the ability to select the type of pooling from:
3 | * 'avg' - Average pooling
4 | * 'max' - Max pooling
5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5
6 | * 'catavgmax' - Concatenation of average and max pooling along feature dim, doubles feature dim
7 |
8 | Both a functional and a nn.Module version of the pooling is provided.
9 |
10 | Hacked together by / Copyright 2020 Ross Wightman
11 | """
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 |
17 | def adaptive_pool_feat_mult(pool_type='avg'):
18 | if pool_type == 'catavgmax':
19 | return 2
20 | else:
21 | return 1
22 |
23 |
24 | def adaptive_avgmax_pool2d(x, output_size=1):
25 | x_avg = F.adaptive_avg_pool2d(x, output_size)
26 | x_max = F.adaptive_max_pool2d(x, output_size)
27 | return 0.5 * (x_avg + x_max)
28 |
29 |
30 | def adaptive_catavgmax_pool2d(x, output_size=1):
31 | x_avg = F.adaptive_avg_pool2d(x, output_size)
32 | x_max = F.adaptive_max_pool2d(x, output_size)
33 | return torch.cat((x_avg, x_max), 1)
34 |
35 |
36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1):
37 | """Selectable global pooling function with dynamic input kernel size
38 | """
39 | if pool_type == 'avg':
40 | x = F.adaptive_avg_pool2d(x, output_size)
41 | elif pool_type == 'avgmax':
42 | x = adaptive_avgmax_pool2d(x, output_size)
43 | elif pool_type == 'catavgmax':
44 | x = adaptive_catavgmax_pool2d(x, output_size)
45 | elif pool_type == 'max':
46 | x = F.adaptive_max_pool2d(x, output_size)
47 | else:
48 | assert False, 'Invalid pool type: %s' % pool_type
49 | return x
50 |
51 |
52 | class FastAdaptiveAvgPool2d(nn.Module):
53 | def __init__(self, flatten=False):
54 | super(FastAdaptiveAvgPool2d, self).__init__()
55 | self.flatten = flatten
56 |
57 | def forward(self, x):
58 | return x.mean((2, 3), keepdim=not self.flatten)
59 |
60 |
61 | class AdaptiveAvgMaxPool2d(nn.Module):
62 | def __init__(self, output_size=1):
63 | super(AdaptiveAvgMaxPool2d, self).__init__()
64 | self.output_size = output_size
65 |
66 | def forward(self, x):
67 | return adaptive_avgmax_pool2d(x, self.output_size)
68 |
69 |
70 | class AdaptiveCatAvgMaxPool2d(nn.Module):
71 | def __init__(self, output_size=1):
72 | super(AdaptiveCatAvgMaxPool2d, self).__init__()
73 | self.output_size = output_size
74 |
75 | def forward(self, x):
76 | return adaptive_catavgmax_pool2d(x, self.output_size)
77 |
78 |
79 | class SelectAdaptivePool2d(nn.Module):
80 | """Selectable global pooling layer with dynamic input kernel size
81 | """
82 | def __init__(self, output_size=1, pool_type='fast', flatten=False):
83 | super(SelectAdaptivePool2d, self).__init__()
84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
85 | self.flatten = nn.Flatten(1) if flatten else nn.Identity()
86 | if pool_type == '':
87 | self.pool = nn.Identity() # pass through
88 | elif pool_type == 'fast':
89 | assert output_size == 1
90 | self.pool = FastAdaptiveAvgPool2d(flatten)
91 | self.flatten = nn.Identity()
92 | elif pool_type == 'avg':
93 | self.pool = nn.AdaptiveAvgPool2d(output_size)
94 | elif pool_type == 'avgmax':
95 | self.pool = AdaptiveAvgMaxPool2d(output_size)
96 | elif pool_type == 'catavgmax':
97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size)
98 | elif pool_type == 'max':
99 | self.pool = nn.AdaptiveMaxPool2d(output_size)
100 | else:
101 | assert False, 'Invalid pool type: %s' % pool_type
102 |
103 | def is_identity(self):
104 | return not self.pool_type
105 |
106 | def forward(self, x):
107 | x = self.pool(x)
108 | x = self.flatten(x)
109 | return x
110 |
111 | def feat_mult(self):
112 | return adaptive_pool_feat_mult(self.pool_type)
113 |
114 | def __repr__(self):
115 | return self.__class__.__name__ + ' (' \
116 | + 'pool_type=' + self.pool_type \
117 | + ', flatten=' + str(self.flatten) + ')'
118 |
119 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/blur_pool.py:
--------------------------------------------------------------------------------
1 | """
2 | BlurPool layer inspired by
3 | - Kornia's Max_BlurPool2d
4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar`
5 |
6 | Hacked together by Chris Ha and Ross Wightman
7 | """
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import numpy as np
13 | from .padding import get_padding
14 |
15 |
16 | class BlurPool2d(nn.Module):
17 | r"""Creates a module that computes blurs and downsample a given feature map.
18 | See :cite:`zhang2019shiftinvar` for more details.
19 | Corresponds to the Downsample class, which does blurring and subsampling
20 |
21 | Args:
22 | channels = Number of input channels
23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5.
24 | stride (int): downsampling filter stride
25 |
26 | Returns:
27 | torch.Tensor: the transformed tensor.
28 | """
29 | def __init__(self, channels, filt_size=3, stride=2) -> None:
30 | super(BlurPool2d, self).__init__()
31 | assert filt_size > 1
32 | self.channels = channels
33 | self.filt_size = filt_size
34 | self.stride = stride
35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4
36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32))
37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1)
38 | self.register_buffer('filt', blur_filter, persistent=False)
39 |
40 | def forward(self, x: torch.Tensor) -> torch.Tensor:
41 | x = F.pad(x, self.padding, 'reflect')
42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
43 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/cbam.py:
--------------------------------------------------------------------------------
1 | """ CBAM (sort-of) Attention
2 |
3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521
4 |
5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on
6 | some tasks, especially fine-grained it seems. I may end up removing this impl.
7 |
8 | Hacked together by / Copyright 2020 Ross Wightman
9 | """
10 | import torch
11 | from torch import nn as nn
12 | import torch.nn.functional as F
13 |
14 | from .conv_bn_act import ConvNormAct
15 | from .create_act import create_act_layer, get_act_layer
16 | from .helpers import make_divisible
17 |
18 |
19 | class ChannelAttn(nn.Module):
20 | """ Original CBAM channel attention module, currently avg + max pool variant only.
21 | """
22 | def __init__(
23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
25 | super(ChannelAttn, self).__init__()
26 | if not rd_channels:
27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
29 | self.act = act_layer(inplace=True)
30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
31 | self.gate = create_act_layer(gate_layer)
32 |
33 | def forward(self, x):
34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
36 | return x * self.gate(x_avg + x_max)
37 |
38 |
39 | class LightChannelAttn(ChannelAttn):
40 | """An experimental 'lightweight' that sums avg + max pool first
41 | """
42 | def __init__(
43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
45 | super(LightChannelAttn, self).__init__(
46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
47 |
48 | def forward(self, x):
49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
50 | x_attn = self.fc2(self.act(self.fc1(x_pool)))
51 | return x * F.sigmoid(x_attn)
52 |
53 |
54 | class SpatialAttn(nn.Module):
55 | """ Original CBAM spatial attention module
56 | """
57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'):
58 | super(SpatialAttn, self).__init__()
59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
60 | self.gate = create_act_layer(gate_layer)
61 |
62 | def forward(self, x):
63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
64 | x_attn = self.conv(x_attn)
65 | return x * self.gate(x_attn)
66 |
67 |
68 | class LightSpatialAttn(nn.Module):
69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
70 | """
71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'):
72 | super(LightSpatialAttn, self).__init__()
73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
74 | self.gate = create_act_layer(gate_layer)
75 |
76 | def forward(self, x):
77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
78 | x_attn = self.conv(x_attn)
79 | return x * self.gate(x_attn)
80 |
81 |
82 | class CbamModule(nn.Module):
83 | def __init__(
84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
86 | super(CbamModule, self).__init__()
87 | self.channel = ChannelAttn(
88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
91 |
92 | def forward(self, x):
93 | x = self.channel(x)
94 | x = self.spatial(x)
95 | return x
96 |
97 |
98 | class LightCbamModule(nn.Module):
99 | def __init__(
100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
102 | super(LightCbamModule, self).__init__()
103 | self.channel = LightChannelAttn(
104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
106 | self.spatial = LightSpatialAttn(spatial_kernel_size)
107 |
108 | def forward(self, x):
109 | x = self.channel(x)
110 | x = self.spatial(x)
111 | return x
112 |
113 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/classifier.py:
--------------------------------------------------------------------------------
1 | """ Classifier head and layer factory
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from torch import nn as nn
6 | from torch.nn import functional as F
7 |
8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d
9 |
10 |
11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
13 | if not pool_type:
14 | assert num_classes == 0 or use_conv,\
15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used'
16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
18 | num_pooled_features = num_features * global_pool.feat_mult()
19 | return global_pool, num_pooled_features
20 |
21 |
22 | def _create_fc(num_features, num_classes, use_conv=False):
23 | if num_classes <= 0:
24 | fc = nn.Identity() # pass-through (no classifier)
25 | elif use_conv:
26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
27 | else:
28 | fc = nn.Linear(num_features, num_classes, bias=True)
29 | return fc
30 |
31 |
32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
35 | return global_pool, fc
36 |
37 |
38 | class ClassifierHead(nn.Module):
39 | """Classifier head w/ configurable global pooling and dropout."""
40 |
41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
42 | super(ClassifierHead, self).__init__()
43 | self.drop_rate = drop_rate
44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
47 |
48 | def forward(self, x, pre_logits: bool = False):
49 | x = self.global_pool(x)
50 | if self.drop_rate:
51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training)
52 | if pre_logits:
53 | return x.flatten(1)
54 | else:
55 | x = self.fc(x)
56 | return self.flatten(x)
57 |
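A short sketch of how the pooling/classifier factories compose; the feature and class counts are assumed example values:

    import torch
    from models.layers.maxvit.layers.classifier import create_classifier, ClassifierHead

    feats = torch.randn(2, 512, 7, 7)                          # backbone features (B, C, H, W)
    global_pool, fc = create_classifier(512, num_classes=10, pool_type='avg')
    logits = fc(global_pool(feats))                            # pooled + flattened -> (2, 10)

    head = ClassifierHead(512, num_classes=10, pool_type='avg', drop_rate=0.1)
    logits = head(feats)                                       # pool -> dropout -> fc -> (2, 10)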
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/config.py:
--------------------------------------------------------------------------------
1 | """ Model / Layer Config singleton state
2 | """
3 | from typing import Any, Optional
4 |
5 | __all__ = [
6 | 'is_exportable', 'is_scriptable', 'is_no_jit',
7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config'
8 | ]
9 |
10 | # Set to True if prefer to have layers with no jit optimization (includes activations)
11 | _NO_JIT = False
12 |
13 | # Set to True if prefer to have activation layers with no jit optimization
14 | # NOTE not currently used, as there is no difference between no_jit and no_activation_jit while the only
15 | # layers obeying the jit flags so far are activations. This will change as more layers are updated and/or added.
16 | _NO_ACTIVATION_JIT = False
17 |
18 | # Set to True if exporting a model with Same padding via ONNX
19 | _EXPORTABLE = False
20 |
21 | # Set to True if wanting to use torch.jit.script on a model
22 | _SCRIPTABLE = False
23 |
24 |
25 | def is_no_jit():
26 | return _NO_JIT
27 |
28 |
29 | class set_no_jit:
30 | def __init__(self, mode: bool) -> None:
31 | global _NO_JIT
32 | self.prev = _NO_JIT
33 | _NO_JIT = mode
34 |
35 | def __enter__(self) -> None:
36 | pass
37 |
38 | def __exit__(self, *args: Any) -> bool:
39 | global _NO_JIT
40 | _NO_JIT = self.prev
41 | return False
42 |
43 |
44 | def is_exportable():
45 | return _EXPORTABLE
46 |
47 |
48 | class set_exportable:
49 | def __init__(self, mode: bool) -> None:
50 | global _EXPORTABLE
51 | self.prev = _EXPORTABLE
52 | _EXPORTABLE = mode
53 |
54 | def __enter__(self) -> None:
55 | pass
56 |
57 | def __exit__(self, *args: Any) -> bool:
58 | global _EXPORTABLE
59 | _EXPORTABLE = self.prev
60 | return False
61 |
62 |
63 | def is_scriptable():
64 | return _SCRIPTABLE
65 |
66 |
67 | class set_scriptable:
68 | def __init__(self, mode: bool) -> None:
69 | global _SCRIPTABLE
70 | self.prev = _SCRIPTABLE
71 | _SCRIPTABLE = mode
72 |
73 | def __enter__(self) -> None:
74 | pass
75 |
76 | def __exit__(self, *args: Any) -> bool:
77 | global _SCRIPTABLE
78 | _SCRIPTABLE = self.prev
79 | return False
80 |
81 |
82 | class set_layer_config:
83 | """ Layer config context manager that allows setting all layer config flags at once.
84 | If a flag arg is None, it will not change the current value.
85 | """
86 | def __init__(
87 | self,
88 | scriptable: Optional[bool] = None,
89 | exportable: Optional[bool] = None,
90 | no_jit: Optional[bool] = None,
91 | no_activation_jit: Optional[bool] = None):
92 | global _SCRIPTABLE
93 | global _EXPORTABLE
94 | global _NO_JIT
95 | global _NO_ACTIVATION_JIT
96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
97 | if scriptable is not None:
98 | _SCRIPTABLE = scriptable
99 | if exportable is not None:
100 | _EXPORTABLE = exportable
101 | if no_jit is not None:
102 | _NO_JIT = no_jit
103 | if no_activation_jit is not None:
104 | _NO_ACTIVATION_JIT = no_activation_jit
105 |
106 | def __enter__(self) -> None:
107 | pass
108 |
109 | def __exit__(self, *args: Any) -> bool:
110 | global _SCRIPTABLE
111 | global _EXPORTABLE
112 | global _NO_JIT
113 | global _NO_ACTIVATION_JIT
114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
115 | return False
116 |
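A brief sketch of the context-manager usage; since the flags are process-wide globals, the previous values are restored when the block exits (import path assumes the repo root is importable):

    from models.layers.maxvit.layers import config

    print(config.is_scriptable(), config.is_exportable())        # False False
    with config.set_layer_config(scriptable=True, exportable=True):
        print(config.is_scriptable(), config.is_exportable())    # True True
    print(config.is_scriptable(), config.is_exportable())        # restored: False False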
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/conv2d_same.py:
--------------------------------------------------------------------------------
1 | """ Conv2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import Tuple, Optional
9 |
10 | from .padding import pad_same, get_padding_value
11 |
12 |
13 | def conv2d_same(
14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1),
15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1):
16 | x = pad_same(x, weight.shape[-2:], stride, dilation)
17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups)
18 |
19 |
20 | class Conv2dSame(nn.Conv2d):
21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions
22 | """
23 |
24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
25 | padding=0, dilation=1, groups=1, bias=True):
26 | super(Conv2dSame, self).__init__(
27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
28 |
29 | def forward(self, x):
30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
31 |
32 |
33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs):
34 | padding = kwargs.pop('padding', '')
35 | kwargs.setdefault('bias', False)
36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs)
37 | if is_dynamic:
38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs)
39 | else:
40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs)
41 |
42 |
43 |
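A small sketch of the dynamic 'SAME' behaviour on an odd input size (all sizes are illustrative):

    import torch
    from models.layers.maxvit.layers.conv2d_same import Conv2dSame, create_conv2d_pad

    x = torch.randn(1, 3, 15, 15)
    conv = Conv2dSame(3, 8, kernel_size=3, stride=2)      # pads asymmetrically at runtime
    print(conv(x).shape)                                   # (1, 8, 8, 8), i.e. ceil(15 / 2)

    # the factory picks static symmetric padding when possible, dynamic 'SAME' otherwise
    m = create_conv2d_pad(3, 8, kernel_size=3, padding='same', stride=2)
    print(type(m).__name__)                                # Conv2dSame: stride 2 cannot be padded statically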
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/conv_bn_act.py:
--------------------------------------------------------------------------------
1 | """ Conv2d + BN + Act
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import functools
6 | from torch import nn as nn
7 |
8 | from .create_conv2d import create_conv2d
9 | from .create_norm_act import get_norm_act_layer
10 |
11 |
12 | class ConvNormAct(nn.Module):
13 | def __init__(
14 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
15 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
16 | super(ConvNormAct, self).__init__()
17 | self.conv = create_conv2d(
18 | in_channels, out_channels, kernel_size, stride=stride,
19 | padding=padding, dilation=dilation, groups=groups, bias=bias)
20 |
21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
22 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
23 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
24 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
25 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
26 |
27 | @property
28 | def in_channels(self):
29 | return self.conv.in_channels
30 |
31 | @property
32 | def out_channels(self):
33 | return self.conv.out_channels
34 |
35 | def forward(self, x):
36 | x = self.conv(x)
37 | x = self.bn(x)
38 | return x
39 |
40 |
41 | ConvBnAct = ConvNormAct
42 |
43 |
44 | def create_aa(aa_layer, channels, stride=2, enable=True):
45 | if not aa_layer or not enable:
46 | return nn.Identity()
47 | if isinstance(aa_layer, functools.partial):
48 | if issubclass(aa_layer.func, nn.AvgPool2d):
49 | return aa_layer()
50 | else:
51 | return aa_layer(channels)
52 | elif issubclass(aa_layer, nn.AvgPool2d):
53 | return aa_layer(stride)
54 | else:
55 | return aa_layer(channels=channels, stride=stride)
56 |
57 |
58 | class ConvNormActAa(nn.Module):
59 | def __init__(
60 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
61 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
62 | super(ConvNormActAa, self).__init__()
63 | use_aa = aa_layer is not None and stride == 2
64 |
65 | self.conv = create_conv2d(
66 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride,
67 | padding=padding, dilation=dilation, groups=groups, bias=bias)
68 |
69 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions
70 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
71 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
72 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
73 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
74 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
75 |
76 | @property
77 | def in_channels(self):
78 | return self.conv.in_channels
79 |
80 | @property
81 | def out_channels(self):
82 | return self.conv.out_channels
83 |
84 | def forward(self, x):
85 | x = self.conv(x)
86 | x = self.bn(x)
87 | x = self.aa(x)
88 | return x
89 |
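A sketch contrasting the two blocks; the anti-aliased variant is assumed to use the vendored BlurPool2d from blur_pool.py, and the channel sizes are example values:

    import torch
    from models.layers.maxvit.layers.conv_bn_act import ConvNormAct, ConvNormActAa
    from models.layers.maxvit.layers.blur_pool import BlurPool2d

    x = torch.randn(2, 16, 32, 32)
    block = ConvNormAct(16, 32, kernel_size=3, stride=2)               # conv -> BN -> ReLU
    print(block(x).shape)                                              # (2, 32, 16, 16)

    # with aa_layer set and stride 2, the conv runs at stride 1 and the blur-pool layer downsamples
    block_aa = ConvNormActAa(16, 32, kernel_size=3, stride=2, aa_layer=BlurPool2d)
    print(block_aa(x).shape)                                           # (2, 32, 16, 16)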
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/create_attn.py:
--------------------------------------------------------------------------------
1 | """ Attention Factory
2 |
3 | Hacked together by / Copyright 2021 Ross Wightman
4 | """
5 | import torch
6 | from functools import partial
7 |
8 | from .bottleneck_attn import BottleneckAttn
9 | from .cbam import CbamModule, LightCbamModule
10 | from .eca import EcaModule, CecaModule
11 | from .gather_excite import GatherExcite
12 | from .global_context import GlobalContext
13 | from .halo_attn import HaloAttn
14 | from .lambda_layer import LambdaLayer
15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn
16 | from .selective_kernel import SelectiveKernel
17 | from .split_attn import SplitAttn
18 | from .squeeze_excite import SEModule, EffectiveSEModule
19 |
20 |
21 | def get_attn(attn_type):
22 | if isinstance(attn_type, torch.nn.Module):
23 | return attn_type
24 | module_cls = None
25 | if attn_type:
26 | if isinstance(attn_type, str):
27 | attn_type = attn_type.lower()
28 | # Lightweight attention modules (channel and/or coarse spatial).
29 | # Typically added to existing network architecture blocks in addition to existing convolutions.
30 | if attn_type == 'se':
31 | module_cls = SEModule
32 | elif attn_type == 'ese':
33 | module_cls = EffectiveSEModule
34 | elif attn_type == 'eca':
35 | module_cls = EcaModule
36 | elif attn_type == 'ecam':
37 | module_cls = partial(EcaModule, use_mlp=True)
38 | elif attn_type == 'ceca':
39 | module_cls = CecaModule
40 | elif attn_type == 'ge':
41 | module_cls = GatherExcite
42 | elif attn_type == 'gc':
43 | module_cls = GlobalContext
44 | elif attn_type == 'gca':
45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False)
46 | elif attn_type == 'cbam':
47 | module_cls = CbamModule
48 | elif attn_type == 'lcbam':
49 | module_cls = LightCbamModule
50 |
51 | # Attention / attention-like modules w/ significant params
52 | # Typically replace some of the existing workhorse convs in a network architecture.
53 | # All of these accept a stride argument and can spatially downsample the input.
54 | elif attn_type == 'sk':
55 | module_cls = SelectiveKernel
56 | elif attn_type == 'splat':
57 | module_cls = SplitAttn
58 |
59 | # Self-attention / attention-like modules w/ significant compute and/or params
60 | # Typically replace some of the existing workhorse convs in a network architecture.
61 | # All of these accept a stride argument and can spatially downsample the input.
62 | elif attn_type == 'lambda':
63 | return LambdaLayer
64 | elif attn_type == 'bottleneck':
65 | return BottleneckAttn
66 | elif attn_type == 'halo':
67 | return HaloAttn
68 | elif attn_type == 'nl':
69 | module_cls = NonLocalAttn
70 | elif attn_type == 'bat':
71 | module_cls = BatNonLocalAttn
72 |
73 | # Woops!
74 | else:
75 | assert False, "Invalid attn module (%s)" % attn_type
76 | elif isinstance(attn_type, bool):
77 | if attn_type:
78 | module_cls = SEModule
79 | else:
80 | module_cls = attn_type
81 | return module_cls
82 |
83 |
84 | def create_attn(attn_type, channels, **kwargs):
85 | module_cls = get_attn(attn_type)
86 | if module_cls is not None:
87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
88 | return module_cls(channels, **kwargs)
89 | return None
90 |
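A minimal sketch of the attention factory; the channel count is an example value:

    import torch
    from models.layers.maxvit.layers.create_attn import create_attn

    se = create_attn('se', 64)         # SEModule instance; channels is the first positional arg
    print(type(se).__name__)           # SEModule
    print(create_attn(None, 64))       # None -> callers typically skip the attention block

    x = torch.randn(2, 64, 16, 16)
    print(se(x).shape)                 # (2, 64, 16, 16): these lightweight modules preserve the input shape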
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/create_conv2d.py:
--------------------------------------------------------------------------------
1 | """ Create Conv2d Factory Method
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | from .mixed_conv2d import MixedConv2d
7 | from .cond_conv2d import CondConv2d
8 | from .conv2d_same import create_conv2d_pad
9 |
10 |
11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
12 | """ Select a 2d convolution implementation based on arguments
13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d.
14 |
15 | Used extensively by EfficientNet, MobileNetv3 and related networks.
16 | """
17 | if isinstance(kernel_size, list):
18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
19 | if 'groups' in kwargs:
20 | groups = kwargs.pop('groups')
21 | if groups == in_channels:
22 | kwargs['depthwise'] = True
23 | else:
24 | assert groups == 1
25 |         # We're going to use only lists for defining the MixedConv2d kernel groups;
26 |         # ints, tuples, and other iterables will continue to pass to normal conv and specify h, w.
27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
28 | else:
29 | depthwise = kwargs.pop('depthwise', False)
30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0
31 | groups = in_channels if depthwise else kwargs.pop('groups', 1)
32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
34 | else:
35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
36 | return m
37 |
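A short sketch of the dispatch logic described in the docstring (all argument values are illustrative):

    from models.layers.maxvit.layers.create_conv2d import create_conv2d

    m = create_conv2d(32, 64, kernel_size=3)                    # plain nn.Conv2d w/ symmetric padding
    dw = create_conv2d(32, 32, kernel_size=3, depthwise=True)   # groups == in_channels
    mix = create_conv2d(32, 64, kernel_size=[3, 5, 7])          # list of kernel sizes -> MixedConv2d
    print(type(m).__name__, dw.groups, type(mix).__name__)      # Conv2d 32 MixedConv2d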
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/create_norm.py:
--------------------------------------------------------------------------------
1 | """ Norm Layer Factory
2 |
3 | Create norm modules by string (to mirror create_act and create_norm_act fns)
4 |
5 | Copyright 2022 Ross Wightman
6 | """
7 | import types
8 | import functools
9 |
10 | import torch.nn as nn
11 |
12 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
13 |
14 | _NORM_MAP = dict(
15 | batchnorm=nn.BatchNorm2d,
16 | batchnorm2d=nn.BatchNorm2d,
17 | batchnorm1d=nn.BatchNorm1d,
18 | groupnorm=GroupNorm,
19 | groupnorm1=GroupNorm1,
20 | layernorm=LayerNorm,
21 | layernorm2d=LayerNorm2d,
22 | )
23 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
24 |
25 |
26 | def create_norm_layer(layer_name, num_features, **kwargs):
27 |     layer = get_norm_layer(layer_name)
28 |     layer_instance = layer(num_features, **kwargs)
29 |     return layer_instance
30 |
31 |
32 | def get_norm_layer(norm_layer):
33 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
34 | norm_kwargs = {}
35 |
36 | # unbind partial fn, so args can be rebound later
37 | if isinstance(norm_layer, functools.partial):
38 | norm_kwargs.update(norm_layer.keywords)
39 | norm_layer = norm_layer.func
40 |
41 | if isinstance(norm_layer, str):
42 | layer_name = norm_layer.replace('_', '')
43 | norm_layer = _NORM_MAP.get(layer_name, None)
44 | elif norm_layer in _NORM_TYPES:
45 | norm_layer = norm_layer
46 | elif isinstance(norm_layer, types.FunctionType):
47 | # if function type, assume it is a lambda/fn that creates a norm layer
48 | norm_layer = norm_layer
49 | else:
50 | type_name = norm_layer.__name__.lower().replace('_', '')
51 | norm_layer = _NORM_MAP.get(type_name, None)
52 | assert norm_layer is not None, f"No equivalent norm layer for {type_name}"
53 |
54 | if norm_kwargs:
55 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args
56 | return norm_layer
57 |
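A brief sketch of the norm-layer lookup by string, type, or partial; the 64-channel value is just an example:

    import functools
    import torch.nn as nn
    from models.layers.maxvit.layers.create_norm import get_norm_layer, create_norm_layer

    print(get_norm_layer('layernorm2d'))                                 # LayerNorm2d from .norm
    print(get_norm_layer(nn.BatchNorm2d))                                # already a known type, passes through
    gn = get_norm_layer(functools.partial(nn.GroupNorm, num_groups=8))   # keyword args are rebound
    norm = create_norm_layer('groupnorm', 64)                            # instantiates GroupNorm(64)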
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/create_norm_act.py:
--------------------------------------------------------------------------------
1 | """ NormAct (Normalizaiton + Activation Layer) Factory
2 |
3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act
4 | instances in models. Where these are used it will be possible to swap separate BN + act layers with
5 | combined modules like IABN or EvoNorms.
6 |
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 | import types
10 | import functools
11 |
12 | from .evo_norm import *
13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
15 | from .inplace_abn import InplaceAbn
16 |
17 | _NORM_ACT_MAP = dict(
18 | batchnorm=BatchNormAct2d,
19 | batchnorm2d=BatchNormAct2d,
20 | groupnorm=GroupNormAct,
21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1),
22 | layernorm=LayerNormAct,
23 | layernorm2d=LayerNormAct2d,
24 | evonormb0=EvoNorm2dB0,
25 | evonormb1=EvoNorm2dB1,
26 | evonormb2=EvoNorm2dB2,
27 | evonorms0=EvoNorm2dS0,
28 | evonorms0a=EvoNorm2dS0a,
29 | evonorms1=EvoNorm2dS1,
30 | evonorms1a=EvoNorm2dS1a,
31 | evonorms2=EvoNorm2dS2,
32 | evonorms2a=EvoNorm2dS2a,
33 | frn=FilterResponseNormAct2d,
34 | frntlu=FilterResponseNormTlu2d,
35 | inplaceabn=InplaceAbn,
36 | iabn=InplaceAbn,
37 | )
38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
39 | # has act_layer arg to define act type
40 | _NORM_ACT_REQUIRES_ARG = {
41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
42 |
43 |
44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer)
46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
47 | if jit:
48 | layer_instance = torch.jit.script(layer_instance)
49 | return layer_instance
50 |
51 |
52 | def get_norm_act_layer(norm_layer, act_layer=None):
53 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
54 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
55 | norm_act_kwargs = {}
56 |
57 | # unbind partial fn, so args can be rebound later
58 | if isinstance(norm_layer, functools.partial):
59 | norm_act_kwargs.update(norm_layer.keywords)
60 | norm_layer = norm_layer.func
61 |
62 | if isinstance(norm_layer, str):
63 | layer_name = norm_layer.replace('_', '').lower().split('-')[0]
64 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
65 | elif norm_layer in _NORM_ACT_TYPES:
66 | norm_act_layer = norm_layer
67 | elif isinstance(norm_layer, types.FunctionType):
68 | # if function type, must be a lambda/fn that creates a norm_act layer
69 | norm_act_layer = norm_layer
70 | else:
71 | type_name = norm_layer.__name__.lower()
72 | if type_name.startswith('batchnorm'):
73 | norm_act_layer = BatchNormAct2d
74 |         elif type_name.startswith('groupnorm1'):
75 |             norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
76 |         elif type_name.startswith('groupnorm'):
77 |             norm_act_layer = GroupNormAct
78 | elif type_name.startswith('layernorm2d'):
79 | norm_act_layer = LayerNormAct2d
80 | elif type_name.startswith('layernorm'):
81 | norm_act_layer = LayerNormAct
82 | else:
83 | assert False, f"No equivalent norm_act layer for {type_name}"
84 |
85 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
86 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
87 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
88 | norm_act_kwargs.setdefault('act_layer', act_layer)
89 | if norm_act_kwargs:
90 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args
91 | return norm_act_layer
92 |
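A minimal sketch of resolving a combined norm+act layer; the 64-channel input is illustrative:

    import torch
    import torch.nn as nn
    from models.layers.maxvit.layers.create_norm_act import get_norm_act_layer

    norm_act = get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU)   # -> partial(BatchNormAct2d, act_layer=ReLU)
    layer = norm_act(64)                                               # BN + ReLU fused into a single module
    x = torch.randn(2, 64, 8, 8)
    print(layer(x).shape)                                              # (2, 64, 8, 8)
    silent = norm_act(64, apply_act=False)                             # same layout, activation becomes Identity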
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/fast_norm.py:
--------------------------------------------------------------------------------
1 | """ 'Fast' Normalization Functions
2 |
3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32.
4 |
5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast)
6 |
7 | Hacked together by / Copyright 2022 Ross Wightman
8 | """
9 | from typing import List, Optional
10 |
11 | import torch
12 | from torch.nn import functional as F
13 |
14 | try:
15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine
16 | has_apex = True
17 | except ImportError:
18 | has_apex = False
19 |
20 |
21 | # fast (ie lower precision LN) can be disabled with this flag if issues crop up
22 | _USE_FAST_NORM = False # defaulting to False for now
23 |
24 |
25 | def is_fast_norm():
26 | return _USE_FAST_NORM
27 |
28 |
29 | def set_fast_norm(enable=True):
30 | global _USE_FAST_NORM
31 | _USE_FAST_NORM = enable
32 |
33 |
34 | def fast_group_norm(
35 | x: torch.Tensor,
36 | num_groups: int,
37 | weight: Optional[torch.Tensor] = None,
38 | bias: Optional[torch.Tensor] = None,
39 | eps: float = 1e-5
40 | ) -> torch.Tensor:
41 | if torch.jit.is_scripting():
42 | # currently cannot use is_autocast_enabled within torchscript
43 | return F.group_norm(x, num_groups, weight, bias, eps)
44 |
45 | if torch.is_autocast_enabled():
46 | # normally native AMP casts GN inputs to float32
47 | # here we use the low precision autocast dtype
48 | # FIXME what to do re CPU autocast?
49 | dt = torch.get_autocast_gpu_dtype()
50 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
51 |
52 | with torch.cuda.amp.autocast(enabled=False):
53 | return F.group_norm(x, num_groups, weight, bias, eps)
54 |
55 |
56 | def fast_layer_norm(
57 | x: torch.Tensor,
58 | normalized_shape: List[int],
59 | weight: Optional[torch.Tensor] = None,
60 | bias: Optional[torch.Tensor] = None,
61 | eps: float = 1e-5
62 | ) -> torch.Tensor:
63 | if torch.jit.is_scripting():
64 | # currently cannot use is_autocast_enabled within torchscript
65 | return F.layer_norm(x, normalized_shape, weight, bias, eps)
66 |
67 | if has_apex:
68 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
69 |
70 | if torch.is_autocast_enabled():
71 | # normally native AMP casts LN inputs to float32
72 | # apex LN does not, this is behaving like Apex
73 | dt = torch.get_autocast_gpu_dtype()
74 | # FIXME what to do re CPU autocast?
75 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt)
76 |
77 | with torch.cuda.amp.autocast(enabled=False):
78 | return F.layer_norm(x, normalized_shape, weight, bias, eps)
79 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/filter_response_norm.py:
--------------------------------------------------------------------------------
1 | """ Filter Response Norm in PyTorch
2 |
3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
4 |
5 | Hacked together by / Copyright 2021 Ross Wightman
6 | """
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .create_act import create_act_layer
11 | from .trace_utils import _assert
12 |
13 |
14 | def inv_instance_rms(x, eps: float = 1e-5):
15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
16 | return rms.expand(x.shape)
17 |
18 |
19 | class FilterResponseNormTlu2d(nn.Module):
20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
21 | super(FilterResponseNormTlu2d, self).__init__()
22 | self.apply_act = apply_act # apply activation (non-linearity)
23 | self.rms = rms
24 | self.eps = eps
25 | self.weight = nn.Parameter(torch.ones(num_features))
26 | self.bias = nn.Parameter(torch.zeros(num_features))
27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
28 | self.reset_parameters()
29 |
30 | def reset_parameters(self):
31 | nn.init.ones_(self.weight)
32 | nn.init.zeros_(self.bias)
33 | if self.tau is not None:
34 | nn.init.zeros_(self.tau)
35 |
36 | def forward(self, x):
37 | _assert(x.dim() == 4, 'expected 4D input')
38 | x_dtype = x.dtype
39 | v_shape = (1, -1, 1, 1)
40 | x = x * inv_instance_rms(x, self.eps)
41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
43 |
44 |
45 | class FilterResponseNormAct2d(nn.Module):
46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
47 | super(FilterResponseNormAct2d, self).__init__()
48 | if act_layer is not None and apply_act:
49 | self.act = create_act_layer(act_layer, inplace=inplace)
50 | else:
51 | self.act = nn.Identity()
52 | self.rms = rms
53 | self.eps = eps
54 | self.weight = nn.Parameter(torch.ones(num_features))
55 | self.bias = nn.Parameter(torch.zeros(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.ones_(self.weight)
60 | nn.init.zeros_(self.bias)
61 |
62 | def forward(self, x):
63 | _assert(x.dim() == 4, 'expected 4D input')
64 | x_dtype = x.dtype
65 | v_shape = (1, -1, 1, 1)
66 | x = x * inv_instance_rms(x, self.eps)
67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
68 | return self.act(x)
69 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/gather_excite.py:
--------------------------------------------------------------------------------
1 | """ Gather-Excite Attention Block
2 |
3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
4 |
5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
6 |
7 | I've tried to support all of the extent settings, both w/ and w/o params. I don't believe I've seen
8 | another impl that covers all of the cases.
9 |
10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
11 |
12 | Hacked together by / Copyright 2021 Ross Wightman
13 | """
14 | import math
15 |
16 | from torch import nn as nn
17 | import torch.nn.functional as F
18 |
19 | from .create_act import create_act_layer, get_act_layer
20 | from .create_conv2d import create_conv2d
21 | from .helpers import make_divisible
22 | from .mlp import ConvMlp
23 |
24 |
25 | class GatherExcite(nn.Module):
26 | """ Gather-Excite Attention Module
27 | """
28 | def __init__(
29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
32 | super(GatherExcite, self).__init__()
33 | self.add_maxpool = add_maxpool
34 | act_layer = get_act_layer(act_layer)
35 | self.extent = extent
36 | if extra_params:
37 | self.gather = nn.Sequential()
38 | if extent == 0:
39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
40 | self.gather.add_module(
41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
42 | if norm_layer:
43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
44 | else:
45 | assert extent % 2 == 0
46 | num_conv = int(math.log2(extent))
47 | for i in range(num_conv):
48 | self.gather.add_module(
49 | f'conv{i + 1}',
50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
51 | if norm_layer:
52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
53 | if i != num_conv - 1:
54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
55 | else:
56 | self.gather = None
57 | if self.extent == 0:
58 | self.gk = 0
59 | self.gs = 0
60 | else:
61 | assert extent % 2 == 0
62 | self.gk = self.extent * 2 - 1
63 | self.gs = self.extent
64 |
65 | if not rd_channels:
66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
68 | self.gate = create_act_layer(gate_layer)
69 |
70 | def forward(self, x):
71 | size = x.shape[-2:]
72 | if self.gather is not None:
73 | x_ge = self.gather(x)
74 | else:
75 | if self.extent == 0:
76 | # global extent
77 |                 x_ge = x.mean(dim=(2, 3), keepdim=True)
78 | if self.add_maxpool:
79 | # experimental codepath, may remove or change
80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
81 | else:
82 | x_ge = F.avg_pool2d(
83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
84 | if self.add_maxpool:
85 | # experimental codepath, may remove or change
86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
87 | x_ge = self.mlp(x_ge)
88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
89 | x_ge = F.interpolate(x_ge, size=size)
90 | return x * self.gate(x_ge)
91 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/global_context.py:
--------------------------------------------------------------------------------
1 | """ Global Context Attention Block
2 |
3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
4 | - https://arxiv.org/abs/1904.11492
5 |
6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet
7 |
8 | Hacked together by / Copyright 2021 Ross Wightman
9 | """
10 | from torch import nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .create_act import create_act_layer, get_act_layer
14 | from .helpers import make_divisible
15 | from .mlp import ConvMlp
16 | from .norm import LayerNorm2d
17 |
18 |
19 | class GlobalContext(nn.Module):
20 |
21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False,
22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
23 | super(GlobalContext, self).__init__()
24 | act_layer = get_act_layer(act_layer)
25 |
26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
27 |
28 | if rd_channels is None:
29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
30 | if fuse_add:
31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
32 | else:
33 | self.mlp_add = None
34 | if fuse_scale:
35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
36 | else:
37 | self.mlp_scale = None
38 |
39 | self.gate = create_act_layer(gate_layer)
40 | self.init_last_zero = init_last_zero
41 | self.reset_parameters()
42 |
43 | def reset_parameters(self):
44 | if self.conv_attn is not None:
45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
46 | if self.mlp_add is not None:
47 | nn.init.zeros_(self.mlp_add.fc2.weight)
48 |
49 | def forward(self, x):
50 | B, C, H, W = x.shape
51 |
52 | if self.conv_attn is not None:
53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
56 | context = context.view(B, C, 1, 1)
57 | else:
58 | context = x.mean(dim=(2, 3), keepdim=True)
59 |
60 | if self.mlp_scale is not None:
61 | mlp_x = self.mlp_scale(context)
62 | x = x * self.gate(mlp_x)
63 | if self.mlp_add is not None:
64 | mlp_x = self.mlp_add(context)
65 | x = x + mlp_x
66 |
67 | return x
68 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/helpers.py:
--------------------------------------------------------------------------------
1 | """ Layer/Module Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from itertools import repeat
6 | import collections.abc
7 |
8 |
9 | # From PyTorch internals
10 | def _ntuple(n):
11 | def parse(x):
12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
13 | return x
14 | return tuple(repeat(x, n))
15 | return parse
16 |
17 |
18 | to_1tuple = _ntuple(1)
19 | to_2tuple = _ntuple(2)
20 | to_3tuple = _ntuple(3)
21 | to_4tuple = _ntuple(4)
22 | to_ntuple = _ntuple
23 |
24 |
25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
26 | min_value = min_value or divisor
27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28 | # Make sure that round down does not go down by more than 10%.
29 | if new_v < round_limit * v:
30 | new_v += divisor
31 | return new_v
32 |
33 |
34 | def extend_tuple(x, n):
35 |     # pads a tuple to the specified length n by padding with the last value
36 | if not isinstance(x, (tuple, list)):
37 | x = (x,)
38 | else:
39 | x = tuple(x)
40 | pad_n = n - len(x)
41 | if pad_n <= 0:
42 | return x[:n]
43 | return x + (x[-1],) * pad_n
44 |
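A worked example of the helpers above; the numbers are chosen to show the rounding behaviour, including the round_limit correction:

    from models.layers.maxvit.layers.helpers import make_divisible, to_2tuple, extend_tuple

    print(make_divisible(100 * 1. / 16, divisor=8))   # 8: int(6.25 + 4) // 8 * 8 = 8, above 0.9 * 6.25
    print(make_divisible(10, divisor=8))              # 16: nearest multiple 8 < 0.9 * 10, so bump by one divisor
    print(to_2tuple(3), to_2tuple((3, 5)))            # (3, 3) (3, 5)
    print(extend_tuple((1, 2), 4))                    # (1, 2, 2, 2): padded with the last value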
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/inplace_abn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 |
4 | try:
5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync
6 | has_iabn = True
7 | except ImportError:
8 | has_iabn = False
9 |
10 | def inplace_abn(x, weight, bias, running_mean, running_var,
11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01):
12 | raise ImportError(
13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'")
14 |
15 | def inplace_abn_sync(**kwargs):
16 | inplace_abn(**kwargs)
17 |
18 |
19 | class InplaceAbn(nn.Module):
20 | """Activated Batch Normalization
21 |
22 | This gathers a BatchNorm and an activation function in a single module
23 |
24 | Parameters
25 | ----------
26 | num_features : int
27 | Number of feature channels in the input and output.
28 | eps : float
29 | Small constant to prevent numerical issues.
30 | momentum : float
31 | Momentum factor applied to compute running statistics.
32 | affine : bool
33 | If `True` apply learned scale and shift transformation after normalization.
34 | act_layer : str or nn.Module type
35 | Name or type of the activation functions, one of: `leaky_relu`, `elu`
36 | act_param : float
37 | Negative slope for the `leaky_relu` activation.
38 | """
39 |
40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None):
42 | super(InplaceAbn, self).__init__()
43 | self.num_features = num_features
44 | self.affine = affine
45 | self.eps = eps
46 | self.momentum = momentum
47 | if apply_act:
48 | if isinstance(act_layer, str):
49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '')
50 | self.act_name = act_layer if act_layer else 'identity'
51 | else:
52 | # convert act layer passed as type to string
53 | if act_layer == nn.ELU:
54 | self.act_name = 'elu'
55 | elif act_layer == nn.LeakyReLU:
56 | self.act_name = 'leaky_relu'
57 | elif act_layer is None or act_layer == nn.Identity:
58 | self.act_name = 'identity'
59 | else:
60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN'
61 | else:
62 | self.act_name = 'identity'
63 | self.act_param = act_param
64 | if self.affine:
65 | self.weight = nn.Parameter(torch.ones(num_features))
66 | self.bias = nn.Parameter(torch.zeros(num_features))
67 | else:
68 | self.register_parameter('weight', None)
69 | self.register_parameter('bias', None)
70 | self.register_buffer('running_mean', torch.zeros(num_features))
71 | self.register_buffer('running_var', torch.ones(num_features))
72 | self.reset_parameters()
73 |
74 | def reset_parameters(self):
75 | nn.init.constant_(self.running_mean, 0)
76 | nn.init.constant_(self.running_var, 1)
77 | if self.affine:
78 | nn.init.constant_(self.weight, 1)
79 | nn.init.constant_(self.bias, 0)
80 |
81 | def forward(self, x):
82 | output = inplace_abn(
83 | x, self.weight, self.bias, self.running_mean, self.running_var,
84 | self.training, self.momentum, self.eps, self.act_name, self.act_param)
85 | if isinstance(output, tuple):
86 | output = output[0]
87 | return output
88 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/linear.py:
--------------------------------------------------------------------------------
1 | """ Linear layer (alternate definition)
2 | """
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn as nn
6 |
7 |
8 | class Linear(nn.Linear):
9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
10 |
11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting
12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case.
13 | """
14 | def forward(self, input: torch.Tensor) -> torch.Tensor:
15 | if torch.jit.is_scripting():
16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None
17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias)
18 | else:
19 | return F.linear(input, self.weight, self.bias)
20 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/median_pool.py:
--------------------------------------------------------------------------------
1 | """ Median Pool
2 | Hacked together by / Copyright 2020 Ross Wightman
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from .helpers import to_2tuple, to_4tuple
7 |
8 |
9 | class MedianPool2d(nn.Module):
10 | """ Median pool (usable as median filter when stride=1) module.
11 |
12 | Args:
13 | kernel_size: size of pooling kernel, int or 2-tuple
14 | stride: pool stride, int or 2-tuple
15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
16 | same: override padding and enforce same padding, boolean
17 | """
18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
19 | super(MedianPool2d, self).__init__()
20 | self.k = to_2tuple(kernel_size)
21 | self.stride = to_2tuple(stride)
22 | self.padding = to_4tuple(padding) # convert to l, r, t, b
23 | self.same = same
24 |
25 | def _padding(self, x):
26 | if self.same:
27 | ih, iw = x.size()[2:]
28 | if ih % self.stride[0] == 0:
29 | ph = max(self.k[0] - self.stride[0], 0)
30 | else:
31 | ph = max(self.k[0] - (ih % self.stride[0]), 0)
32 | if iw % self.stride[1] == 0:
33 | pw = max(self.k[1] - self.stride[1], 0)
34 | else:
35 | pw = max(self.k[1] - (iw % self.stride[1]), 0)
36 | pl = pw // 2
37 | pr = pw - pl
38 | pt = ph // 2
39 | pb = ph - pt
40 | padding = (pl, pr, pt, pb)
41 | else:
42 | padding = self.padding
43 | return padding
44 |
45 | def forward(self, x):
46 | x = F.pad(x, self._padding(x), mode='reflect')
47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
49 | return x
50 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/mixed_conv2d.py:
--------------------------------------------------------------------------------
1 | """ PyTorch Mixed Convolution
2 |
3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595)
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import torch
9 | from torch import nn as nn
10 |
11 | from .conv2d_same import create_conv2d_pad
12 |
13 |
14 | def _split_channels(num_chan, num_groups):
15 | split = [num_chan // num_groups for _ in range(num_groups)]
16 | split[0] += num_chan - sum(split)
17 | return split
18 |
19 |
20 | class MixedConv2d(nn.ModuleDict):
21 | """ Mixed Grouped Convolution
22 |
23 | Based on MDConv and GroupedConv in MixNet impl:
24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
25 | """
26 | def __init__(self, in_channels, out_channels, kernel_size=3,
27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs):
28 | super(MixedConv2d, self).__init__()
29 |
30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size]
31 | num_groups = len(kernel_size)
32 | in_splits = _split_channels(in_channels, num_groups)
33 | out_splits = _split_channels(out_channels, num_groups)
34 | self.in_channels = sum(in_splits)
35 | self.out_channels = sum(out_splits)
36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)):
37 | conv_groups = in_ch if depthwise else 1
38 | # use add_module to keep key space clean
39 | self.add_module(
40 | str(idx),
41 | create_conv2d_pad(
42 | in_ch, out_ch, k, stride=stride,
43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs)
44 | )
45 | self.splits = in_splits
46 |
47 | def forward(self, x):
48 | x_split = torch.split(x, self.splits, 1)
49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())]
50 | x = torch.cat(x_out, 1)
51 | return x
52 |
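A short sketch of how channels are split across kernel sizes; 48 channels and three kernels are assumed example values:

    import torch
    from models.layers.maxvit.layers.mixed_conv2d import MixedConv2d

    m = MixedConv2d(48, 48, kernel_size=[3, 5, 7], depthwise=True)
    print(m.splits)                  # [16, 16, 16]: input channels split across the three kernel groups
    x = torch.randn(2, 48, 32, 32)
    print(m(x).shape)                # (2, 48, 32, 32)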
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/mlp.py:
--------------------------------------------------------------------------------
1 | """ MLP module w/ dropout and configurable activation layer
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | from torch import nn as nn
6 |
7 | from .helpers import to_2tuple
8 |
9 |
10 | class Mlp(nn.Module):
11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
12 | """
13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
14 | super().__init__()
15 | out_features = out_features or in_features
16 | hidden_features = hidden_features or in_features
17 | bias = to_2tuple(bias)
18 | drop_probs = to_2tuple(drop)
19 |
20 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
21 | self.act = act_layer()
22 | self.drop1 = nn.Dropout(drop_probs[0])
23 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
24 | self.drop2 = nn.Dropout(drop_probs[1])
25 |
26 | def forward(self, x):
27 | x = self.fc1(x)
28 | x = self.act(x)
29 | x = self.drop1(x)
30 | x = self.fc2(x)
31 | x = self.drop2(x)
32 | return x
33 |
34 |
35 | class GluMlp(nn.Module):
36 | """ MLP w/ GLU style gating
37 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
38 | """
39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
40 | super().__init__()
41 | out_features = out_features or in_features
42 | hidden_features = hidden_features or in_features
43 | assert hidden_features % 2 == 0
44 | bias = to_2tuple(bias)
45 | drop_probs = to_2tuple(drop)
46 |
47 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
48 | self.act = act_layer()
49 | self.drop1 = nn.Dropout(drop_probs[0])
50 | self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
51 | self.drop2 = nn.Dropout(drop_probs[1])
52 |
53 | def init_weights(self):
54 | # override init of fc1 w/ gate portion set to weight near zero, bias=1
55 | fc1_mid = self.fc1.bias.shape[0] // 2
56 | nn.init.ones_(self.fc1.bias[fc1_mid:])
57 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
58 |
59 | def forward(self, x):
60 | x = self.fc1(x)
61 | x, gates = x.chunk(2, dim=-1)
62 | x = x * self.act(gates)
63 | x = self.drop1(x)
64 | x = self.fc2(x)
65 | x = self.drop2(x)
66 | return x
67 |
68 |
69 | class GatedMlp(nn.Module):
70 | """ MLP as used in gMLP
71 | """
72 | def __init__(
73 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
74 | gate_layer=None, bias=True, drop=0.):
75 | super().__init__()
76 | out_features = out_features or in_features
77 | hidden_features = hidden_features or in_features
78 | bias = to_2tuple(bias)
79 | drop_probs = to_2tuple(drop)
80 |
81 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
82 | self.act = act_layer()
83 | self.drop1 = nn.Dropout(drop_probs[0])
84 | if gate_layer is not None:
85 | assert hidden_features % 2 == 0
86 | self.gate = gate_layer(hidden_features)
87 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
88 | else:
89 | self.gate = nn.Identity()
90 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
91 | self.drop2 = nn.Dropout(drop_probs[1])
92 |
93 | def forward(self, x):
94 | x = self.fc1(x)
95 | x = self.act(x)
96 | x = self.drop1(x)
97 | x = self.gate(x)
98 | x = self.fc2(x)
99 | x = self.drop2(x)
100 | return x
101 |
102 |
103 | class ConvMlp(nn.Module):
104 | """ MLP using 1x1 convs that keeps spatial dims
105 | """
106 | def __init__(
107 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
108 | norm_layer=None, bias=True, drop=0.):
109 | super().__init__()
110 | out_features = out_features or in_features
111 | hidden_features = hidden_features or in_features
112 | bias = to_2tuple(bias)
113 |
114 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
115 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
116 | self.act = act_layer()
117 | self.drop = nn.Dropout(drop)
118 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])
119 |
120 | def forward(self, x):
121 | x = self.fc1(x)
122 | x = self.norm(x)
123 | x = self.act(x)
124 | x = self.drop(x)
125 | x = self.fc2(x)
126 | return x
127 |
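A minimal sketch contrasting the plain Mlp with the GLU-gated variant; feature sizes are example values, and note that GluMlp halves the hidden width at the gate:

    import torch
    from models.layers.maxvit.layers.mlp import Mlp, GluMlp

    x = torch.randn(2, 16, 64)                  # (B, N, C) tokens
    mlp = Mlp(64, hidden_features=128)
    glu = GluMlp(64, hidden_features=128)       # fc1 -> 128, split into 64 values + 64 gates
    print(mlp(x).shape, glu(x).shape)           # both (2, 16, 64)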
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/norm.py:
--------------------------------------------------------------------------------
1 | """ Normalization layers and wrappers
2 |
3 | Norm layer definitions that support fast norm and consistent channel arg order (always first arg).
4 |
5 | Hacked together by / Copyright 2022 Ross Wightman
6 | """
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
13 |
14 |
15 | class GroupNorm(nn.GroupNorm):
16 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True):
17 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
18 | super().__init__(num_groups, num_channels, eps=eps, affine=affine)
19 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
20 |
21 | def forward(self, x):
22 | if self.fast_norm:
23 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
24 | else:
25 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
26 |
27 |
28 | class GroupNorm1(nn.GroupNorm):
29 | """ Group Normalization with 1 group.
30 | Input: tensor in shape [B, C, *]
31 | """
32 |
33 | def __init__(self, num_channels, **kwargs):
34 | super().__init__(1, num_channels, **kwargs)
35 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
36 |
37 | def forward(self, x: torch.Tensor) -> torch.Tensor:
38 | if self.fast_norm:
39 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
40 | else:
41 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
42 |
43 |
44 | class LayerNorm(nn.LayerNorm):
45 | """ LayerNorm w/ fast norm option
46 | """
47 | def __init__(self, num_channels, eps=1e-6, affine=True):
48 | super().__init__(num_channels, eps=eps, elementwise_affine=affine)
49 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
50 |
51 | def forward(self, x: torch.Tensor) -> torch.Tensor:
52 | if self._fast_norm:
53 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
54 | else:
55 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
56 | return x
57 |
58 |
59 | class LayerNorm2d(nn.LayerNorm):
60 | """ LayerNorm for channels of '2D' spatial NCHW tensors """
61 | def __init__(self, num_channels, eps=1e-6, affine=True):
62 | super().__init__(num_channels, eps=eps, elementwise_affine=affine)
63 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals)
64 |
65 | def forward(self, x: torch.Tensor) -> torch.Tensor:
66 | x = x.permute(0, 2, 3, 1)
67 | if self._fast_norm:
68 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
69 | else:
70 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
71 | x = x.permute(0, 3, 1, 2)
72 | return x
73 |
74 |
75 | def _is_contiguous(tensor: torch.Tensor) -> bool:
76 | # jit is oh so lovely :/
77 | if torch.jit.is_scripting():
78 | return tensor.is_contiguous()
79 | else:
80 | return tensor.is_contiguous(memory_format=torch.contiguous_format)
81 |
82 |
83 | @torch.jit.script
84 | def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
85 | s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
86 | x = (x - u) * torch.rsqrt(s + eps)
87 | x = x * weight[:, None, None] + bias[:, None, None]
88 | return x
89 |
90 |
91 | def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float):
92 | u = x.mean(dim=1, keepdim=True)
93 | s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0)
94 | x = (x - u) * torch.rsqrt(s + eps)
95 | x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)
96 | return x
97 |
98 |
99 | class LayerNormExp2d(nn.LayerNorm):
100 | """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
101 |
102 |     Experimental implementation w/ manual norm for non-contiguous tensors.
103 |
104 | This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last
105 | layout. However, benefits are not always clear and can perform worse on other GPUs.
106 | """
107 |
108 | def __init__(self, num_channels, eps=1e-6):
109 | super().__init__(num_channels, eps=eps)
110 |
111 | def forward(self, x) -> torch.Tensor:
112 | if _is_contiguous(x):
113 | x = F.layer_norm(
114 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
115 | else:
116 | x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
117 | return x
118 |
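A brief sketch of LayerNorm2d, which normalizes over the channel dim of NCHW tensors by permuting to NHWC and back; the shape is illustrative:

    import torch
    from models.layers.maxvit.layers.norm import LayerNorm2d

    x = torch.randn(2, 64, 8, 8)
    ln2d = LayerNorm2d(64)                       # channels is the first arg, matching the BN-style layers
    y = ln2d(x)
    print(y.shape)                               # (2, 64, 8, 8)
    print(y.mean(dim=1).abs().max() < 1e-4)      # tensor(True): per-location channel stats are normalized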
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/padding.py:
--------------------------------------------------------------------------------
1 | """ Padding Helpers
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import math
6 | from typing import List, Tuple
7 |
8 | import torch.nn.functional as F
9 |
10 |
11 | # Calculate symmetric padding for a convolution
12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
14 | return padding
15 |
16 |
17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
18 | def get_same_padding(x: int, k: int, s: int, d: int):
19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
20 |
21 |
22 | # Can SAME padding for given args be done statically?
23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0
25 |
26 |
27 | # Dynamically pad input x with 'SAME' padding for conv with specified args
28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
29 | ih, iw = x.size()[-2:]
30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
31 | if pad_h > 0 or pad_w > 0:
32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
33 | return x
34 |
35 |
36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
37 | dynamic = False
38 | if isinstance(padding, str):
39 | # for any string padding, the padding will be calculated for you, one of three ways
40 | padding = padding.lower()
41 | if padding == 'same':
42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
43 | if is_static_pad(kernel_size, **kwargs):
44 | # static case, no extra overhead
45 | padding = get_padding(kernel_size, **kwargs)
46 | else:
47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead
48 | padding = 0
49 | dynamic = True
50 | elif padding == 'valid':
51 | # 'VALID' padding, same as padding=0
52 | padding = 0
53 | else:
54 | # Default to PyTorch style 'same'-ish symmetric padding
55 | padding = get_padding(kernel_size, **kwargs)
56 | return padding, dynamic
57 |
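A worked example of the two padding formulas; kernel, stride, and input sizes are chosen for illustration:

    from models.layers.maxvit.layers.padding import get_padding, get_same_padding, is_static_pad

    # symmetric padding: ((stride - 1) + dilation * (kernel_size - 1)) // 2
    print(get_padding(3, stride=1))              # (0 + 2) // 2 = 1
    print(get_padding(7, stride=2))              # (1 + 6) // 2 = 3

    # TF 'SAME': total padding so that out = ceil(in / stride); may be asymmetric and input-dependent
    print(get_same_padding(15, k=3, s=2, d=1))   # (ceil(15/2) - 1) * 2 + 3 - 15 = 2
    print(is_static_pad(3, stride=2))            # False -> needs dynamic padding at runtime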
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/patch_embed.py:
--------------------------------------------------------------------------------
1 | """ Image to Patch Embedding using Conv2d
2 |
3 | A convolution based approach to patchifying a 2D image w/ embedding projection.
4 |
5 | Based on the impl in https://github.com/google-research/vision_transformer
6 |
7 | Hacked together by / Copyright 2020 Ross Wightman
8 | """
9 | from torch import nn as nn
10 |
11 | from .helpers import to_2tuple
12 | from .trace_utils import _assert
13 |
14 |
15 | class PatchEmbed(nn.Module):
16 | """ 2D Image to Patch Embedding
17 | """
18 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
19 | super().__init__()
20 | img_size = to_2tuple(img_size)
21 | patch_size = to_2tuple(patch_size)
22 | self.img_size = img_size
23 | self.patch_size = patch_size
24 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
25 | self.num_patches = self.grid_size[0] * self.grid_size[1]
26 | self.flatten = flatten
27 |
28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
30 |
31 | def forward(self, x):
32 | B, C, H, W = x.shape
33 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
34 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
35 | x = self.proj(x)
36 | if self.flatten:
37 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
38 | x = self.norm(x)
39 | return x
40 |
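A minimal shape sketch using the default 224x224 image / 16x16 patch configuration:

    import torch
    from models.layers.maxvit.layers.patch_embed import PatchEmbed

    pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    x = torch.randn(2, 3, 224, 224)
    print(pe.num_patches)            # 196 = (224 // 16) ** 2
    print(pe(x).shape)               # (2, 196, 768): BCHW -> BNC because flatten=True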
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/pool2d_same.py:
--------------------------------------------------------------------------------
1 | """ AvgPool2d w/ Same Padding
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from typing import List, Tuple, Optional
9 |
10 | from .helpers import to_2tuple
11 | from .padding import pad_same, get_padding_value
12 |
13 |
14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
15 | ceil_mode: bool = False, count_include_pad: bool = True):
16 | # FIXME how to deal with count_include_pad vs not for external padding?
17 | x = pad_same(x, kernel_size, stride)
18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
19 |
20 |
21 | class AvgPool2dSame(nn.AvgPool2d):
22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling
23 | """
24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
25 | kernel_size = to_2tuple(kernel_size)
26 | stride = to_2tuple(stride)
27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
28 |
29 | def forward(self, x):
30 | x = pad_same(x, self.kernel_size, self.stride)
31 | return F.avg_pool2d(
32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
33 |
34 |
35 | def max_pool2d_same(
36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0),
37 | dilation: List[int] = (1, 1), ceil_mode: bool = False):
38 | x = pad_same(x, kernel_size, stride, value=-float('inf'))
39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode)
40 |
41 |
42 | class MaxPool2dSame(nn.MaxPool2d):
43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling
44 | """
45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
46 | kernel_size = to_2tuple(kernel_size)
47 | stride = to_2tuple(stride)
48 | dilation = to_2tuple(dilation)
49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
50 |
51 | def forward(self, x):
52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
54 |
55 |
56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):
57 | stride = stride or kernel_size
58 | padding = kwargs.pop('padding', '')
59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs)
60 | if is_dynamic:
61 | if pool_type == 'avg':
62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs)
63 | elif pool_type == 'max':
64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs)
65 | else:
66 | assert False, f'Unsupported pool type {pool_type}'
67 | else:
68 | if pool_type == 'avg':
69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
70 | elif pool_type == 'max':
71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs)
72 | else:
73 | assert False, f'Unsupported pool type {pool_type}'
74 |
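A short sketch of the two create_pool2d code paths (input-dependent 'same' padding vs. a fixed padding value); the import path is assumed from the file location:

    import torch

    from models.layers.maxvit.layers.pool2d_same import create_pool2d, MaxPool2dSame

    x = torch.randn(1, 8, 15, 15)

    # with stride > 1, 'same' padding depends on the input size -> dynamic MaxPool2dSame branch
    pool_same = create_pool2d('max', kernel_size=3, stride=2, padding='same')
    assert isinstance(pool_same, MaxPool2dSame)
    assert pool_same(x).shape == (1, 8, 8, 8)   # ceil(15 / 2) = 8

    # an explicit integer padding stays on the plain nn.MaxPool2d branch
    pool_static = create_pool2d('max', kernel_size=3, stride=2, padding=1)
    assert pool_static(x).shape == (1, 8, 8, 8)
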
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/separable_conv.py:
--------------------------------------------------------------------------------
1 | """ Depthwise Separable Conv Modules
2 |
3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the
4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception.
5 |
6 | Hacked together by / Copyright 2020 Ross Wightman
7 | """
8 | from torch import nn as nn
9 |
10 | from .create_conv2d import create_conv2d
11 | from .create_norm_act import get_norm_act_layer
12 |
13 |
14 | class SeparableConvNormAct(nn.Module):
15 | """ Separable Conv w/ trailing Norm and Activation
16 | """
17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
19 | apply_act=True, drop_layer=None):
20 | super(SeparableConvNormAct, self).__init__()
21 |
22 | self.conv_dw = create_conv2d(
23 | in_channels, int(in_channels * channel_multiplier), kernel_size,
24 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
25 |
26 | self.conv_pw = create_conv2d(
27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
28 |
29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
32 |
33 | @property
34 | def in_channels(self):
35 | return self.conv_dw.in_channels
36 |
37 | @property
38 | def out_channels(self):
39 | return self.conv_pw.out_channels
40 |
41 | def forward(self, x):
42 | x = self.conv_dw(x)
43 | x = self.conv_pw(x)
44 | x = self.bn(x)
45 | return x
46 |
47 |
48 | SeparableConvBnAct = SeparableConvNormAct
49 |
50 |
51 | class SeparableConv2d(nn.Module):
52 | """ Separable Conv
53 | """
54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
55 | channel_multiplier=1.0, pw_kernel_size=1):
56 | super(SeparableConv2d, self).__init__()
57 |
58 | self.conv_dw = create_conv2d(
59 | in_channels, int(in_channels * channel_multiplier), kernel_size,
60 | stride=stride, dilation=dilation, padding=padding, depthwise=True)
61 |
62 | self.conv_pw = create_conv2d(
63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
64 |
65 | @property
66 | def in_channels(self):
67 | return self.conv_dw.in_channels
68 |
69 | @property
70 | def out_channels(self):
71 | return self.conv_pw.out_channels
72 |
73 | def forward(self, x):
74 | x = self.conv_dw(x)
75 | x = self.conv_pw(x)
76 | return x
77 |
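A quick sketch contrasting a dense 3x3 convolution with the depthwise-separable version above (import path inferred from the file location):

    import torch
    import torch.nn as nn

    from models.layers.maxvit.layers.separable_conv import SeparableConv2d

    dense = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
    dws = SeparableConv2d(64, 128, kernel_size=3)   # depthwise 3x3 followed by pointwise 1x1

    n_dense = sum(p.numel() for p in dense.parameters())
    n_dws = sum(p.numel() for p in dws.parameters())
    assert n_dws < n_dense                          # roughly 8.8k vs 73.7k parameters here

    x = torch.randn(1, 64, 32, 32)
    assert dense(x).shape == dws(x).shape == (1, 128, 32, 32)
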
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/space_to_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SpaceToDepth(nn.Module):
6 | def __init__(self, block_size=4):
7 | super().__init__()
8 | assert block_size == 4
9 | self.bs = block_size
10 |
11 | def forward(self, x):
12 | N, C, H, W = x.size()
13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs)
14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs)
16 | return x
17 |
18 |
19 | @torch.jit.script
20 | class SpaceToDepthJit(object):
21 | def __call__(self, x: torch.Tensor):
22 | # assuming hard-coded that block_size==4 for acceleration
23 | N, C, H, W = x.size()
24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs)
25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
26 | x = x.view(N, C * 16, H // 4, W // 4) # (N, C*bs^2, H//bs, W//bs)
27 | return x
28 |
29 |
30 | class SpaceToDepthModule(nn.Module):
31 | def __init__(self, no_jit=False):
32 | super().__init__()
33 | if not no_jit:
34 | self.op = SpaceToDepthJit()
35 | else:
36 | self.op = SpaceToDepth()
37 |
38 | def forward(self, x):
39 | return self.op(x)
40 |
41 |
42 | class DepthToSpace(nn.Module):
43 |
44 | def __init__(self, block_size):
45 | super().__init__()
46 | self.bs = block_size
47 |
48 | def forward(self, x):
49 | N, C, H, W = x.size()
50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs)
53 | return x
54 |
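A round-trip sanity check: the channel layout produced by SpaceToDepth is exactly what DepthToSpace expects, so composing the two recovers the input (import path assumed):

    import torch

    from models.layers.maxvit.layers.space_to_depth import SpaceToDepth, DepthToSpace

    x = torch.randn(2, 3, 16, 16)
    y = SpaceToDepth(block_size=4)(x)
    assert y.shape == (2, 3 * 16, 4, 4)               # channels *= bs^2, spatial /= bs
    assert torch.equal(DepthToSpace(block_size=4)(y), x)
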
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/split_attn.py:
--------------------------------------------------------------------------------
1 | """ Split Attention Conv2d (for ResNeSt Models)
2 |
3 | Paper: `ResNeSt: Split-Attention Networks` - https://arxiv.org/abs/2004.08955
4 |
5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
6 |
7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
8 | """
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn
12 |
13 | from .helpers import make_divisible
14 |
15 |
16 | class RadixSoftmax(nn.Module):
17 | def __init__(self, radix, cardinality):
18 | super(RadixSoftmax, self).__init__()
19 | self.radix = radix
20 | self.cardinality = cardinality
21 |
22 | def forward(self, x):
23 | batch = x.size(0)
24 | if self.radix > 1:
25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
26 | x = F.softmax(x, dim=1)
27 | x = x.reshape(batch, -1)
28 | else:
29 | x = torch.sigmoid(x)
30 | return x
31 |
32 |
33 | class SplitAttn(nn.Module):
34 | """Split-Attention (aka Splat)
35 | """
36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
39 | super(SplitAttn, self).__init__()
40 | out_channels = out_channels or in_channels
41 | self.radix = radix
42 | mid_chs = out_channels * radix
43 | if rd_channels is None:
44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
45 | else:
46 | attn_chs = rd_channels * radix
47 |
48 | padding = kernel_size // 2 if padding is None else padding
49 | self.conv = nn.Conv2d(
50 | in_channels, mid_chs, kernel_size, stride, padding, dilation,
51 | groups=groups * radix, bias=bias, **kwargs)
52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity()
54 | self.act0 = act_layer(inplace=True)
55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
57 | self.act1 = act_layer(inplace=True)
58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
59 | self.rsoftmax = RadixSoftmax(radix, groups)
60 |
61 | def forward(self, x):
62 | x = self.conv(x)
63 | x = self.bn0(x)
64 | x = self.drop(x)
65 | x = self.act0(x)
66 |
67 | B, RC, H, W = x.shape
68 | if self.radix > 1:
69 | x = x.reshape((B, self.radix, RC // self.radix, H, W))
70 | x_gap = x.sum(dim=1)
71 | else:
72 | x_gap = x
73 | x_gap = x_gap.mean((2, 3), keepdim=True)
74 | x_gap = self.fc1(x_gap)
75 | x_gap = self.bn1(x_gap)
76 | x_gap = self.act1(x_gap)
77 | x_attn = self.fc2(x_gap)
78 |
79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
80 | if self.radix > 1:
81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
82 | else:
83 | out = x * x_attn
84 | return out.contiguous()
85 |
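A minimal forward sketch: with radix=2 the module splits the expanded features and re-weights them before summing, while radix=1 collapses to plain sigmoid channel gating (import path assumed):

    import torch
    import torch.nn as nn

    from models.layers.maxvit.layers.split_attn import SplitAttn

    x = torch.randn(2, 64, 16, 16)

    splat = SplitAttn(in_channels=64, radix=2, norm_layer=nn.BatchNorm2d)
    assert splat(x).shape == (2, 64, 16, 16)          # output keeps the input shape

    se_like = SplitAttn(in_channels=64, radix=1, norm_layer=nn.BatchNorm2d)
    assert se_like(x).shape == (2, 64, 16, 16)
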
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/split_batchnorm.py:
--------------------------------------------------------------------------------
1 | """ Split BatchNorm
2 |
3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias
5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
6 | namespace.
7 |
8 | This allows easily removing the auxiliary BN layers after training to efficiently
9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
10 | 'Disentangled Learning via An Auxiliary BN'
11 |
12 | Hacked together by / Copyright 2020 Ross Wightman
13 | """
14 | import torch
15 | import torch.nn as nn
16 |
17 |
18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d):
19 |
20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
21 | track_running_stats=True, num_splits=2):
22 | super().__init__(num_features, eps, momentum, affine, track_running_stats)
23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
24 | self.num_splits = num_splits
25 | self.aux_bn = nn.ModuleList([
26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
27 |
28 | def forward(self, input: torch.Tensor):
29 | if self.training: # aux BN only relevant while training
30 | split_size = input.shape[0] // self.num_splits
31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
32 | split_input = input.split(split_size)
33 | x = [super().forward(split_input[0])]
34 | for i, a in enumerate(self.aux_bn):
35 | x.append(a(split_input[i + 1]))
36 | return torch.cat(x, dim=0)
37 | else:
38 | return super().forward(input)
39 |
40 |
41 | def convert_splitbn_model(module, num_splits=2):
42 | """
43 | Recursively traverse module and its children to replace all instances of
44 | ``torch.nn.modules.batchnorm._BatchNorm`` with ``SplitBatchNorm2d``.
45 | Args:
46 | module (torch.nn.Module): input module
47 | num_splits: number of separate batchnorm layers to split input across
48 | Example::
49 | >>> # model is an instance of torch.nn.Module
50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
51 | """
52 | mod = module
53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
54 | return module
55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
56 | mod = SplitBatchNorm2d(
57 | module.num_features, module.eps, module.momentum, module.affine,
58 | module.track_running_stats, num_splits=num_splits)
59 | mod.running_mean = module.running_mean
60 | mod.running_var = module.running_var
61 | mod.num_batches_tracked = module.num_batches_tracked
62 | if module.affine:
63 | mod.weight.data = module.weight.data.clone().detach()
64 | mod.bias.data = module.bias.data.clone().detach()
65 | for aux in mod.aux_bn:
66 | aux.running_mean = module.running_mean.clone()
67 | aux.running_var = module.running_var.clone()
68 | aux.num_batches_tracked = module.num_batches_tracked.clone()
69 | if module.affine:
70 | aux.weight.data = module.weight.data.clone().detach()
71 | aux.bias.data = module.bias.data.clone().detach()
72 | for name, child in module.named_children():
73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
74 | del module
75 | return mod
76 |
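A sketch of convert_splitbn_model on a toy network; in training mode the batch must be divisible by num_splits, with the first chunk normalized by the parent BN and the rest by the aux_bn copies:

    import torch
    import torch.nn as nn

    from models.layers.maxvit.layers.split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    net = convert_splitbn_model(net, num_splits=2)
    assert isinstance(net[1], SplitBatchNorm2d)       # BN layer swapped in place

    net.train()
    out = net(torch.randn(4, 3, 16, 16))              # 4 = 2 splits x 2 samples each
    assert out.shape == (4, 8, 16, 16)
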
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/squeeze_excite.py:
--------------------------------------------------------------------------------
1 | """ Squeeze-and-Excitation Channel Attention
2 |
3 | An SE implementation originally based on PyTorch SE-Net impl.
4 | Has since evolved with additional functionality / configuration.
5 |
6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
7 |
8 | Also included is Effective Squeeze-Excitation (ESE).
9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
10 |
11 | Hacked together by / Copyright 2021 Ross Wightman
12 | """
13 | from torch import nn as nn
14 |
15 | from .create_act import create_act_layer
16 | from .helpers import make_divisible
17 |
18 |
19 | class SEModule(nn.Module):
20 | """ SE Module as defined in original SE-Nets with a few additions
21 | Additions include:
22 | * divisor can be specified to keep channels % div == 0 (default: 8)
23 | * reduction channels can be specified directly by arg (if rd_channels is set)
24 | * reduction channels can be specified by float rd_ratio (default: 1/16)
25 | * global max pooling can be added to the squeeze aggregation
26 | * customizable activation, normalization, and gate layer
27 | """
28 | def __init__(
29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
31 | super(SEModule, self).__init__()
32 | self.add_maxpool = add_maxpool
33 | if not rd_channels:
34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias)
36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
37 | self.act = create_act_layer(act_layer, inplace=True)
38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias)
39 | self.gate = create_act_layer(gate_layer)
40 |
41 | def forward(self, x):
42 | x_se = x.mean((2, 3), keepdim=True)
43 | if self.add_maxpool:
44 | # experimental codepath, may remove or change
45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
46 | x_se = self.fc1(x_se)
47 | x_se = self.act(self.bn(x_se))
48 | x_se = self.fc2(x_se)
49 | return x * self.gate(x_se)
50 |
51 |
52 | SqueezeExcite = SEModule # alias
53 |
54 |
55 | class EffectiveSEModule(nn.Module):
56 | """ 'Effective Squeeze-Excitation
57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
58 | """
59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
60 | super(EffectiveSEModule, self).__init__()
61 | self.add_maxpool = add_maxpool
62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
63 | self.gate = create_act_layer(gate_layer)
64 |
65 | def forward(self, x):
66 | x_se = x.mean((2, 3), keepdim=True)
67 | if self.add_maxpool:
68 | # experimental codepath, may remove or change
69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
70 | x_se = self.fc(x_se)
71 | return x * self.gate(x_se)
72 |
73 |
74 | EffectiveSqueezeExcite = EffectiveSEModule # alias
75 |
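A minimal forward check for both attention variants; with the defaults, SEModule reduces 64 channels to 8 (64/16 rounded up to the divisor of 8) before expanding back (import path assumed):

    import torch

    from models.layers.maxvit.layers.squeeze_excite import SEModule, EffectiveSEModule

    x = torch.randn(2, 64, 8, 8)

    se = SEModule(channels=64)             # 64 -> 8 -> 64 with a sigmoid gate
    assert se(x).shape == x.shape

    ese = EffectiveSEModule(channels=64)   # single 1x1 conv + hard-sigmoid gate
    assert ese(x).shape == x.shape
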
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/test_time_pool.py:
--------------------------------------------------------------------------------
1 | """ Test Time Pooling (Average-Max Pool)
2 |
3 | Hacked together by / Copyright 2020 Ross Wightman
4 | """
5 |
6 | import logging
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
11 |
12 |
13 | _logger = logging.getLogger(__name__)
14 |
15 |
16 | class TestTimePoolHead(nn.Module):
17 | def __init__(self, base, original_pool=7):
18 | super(TestTimePoolHead, self).__init__()
19 | self.base = base
20 | self.original_pool = original_pool
21 | base_fc = self.base.get_classifier()
22 | if isinstance(base_fc, nn.Conv2d):
23 | self.fc = base_fc
24 | else:
25 | self.fc = nn.Conv2d(
26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True)
27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size()))
28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size()))
29 | self.base.reset_classifier(0) # delete original fc layer
30 |
31 | def forward(self, x):
32 | x = self.base.forward_features(x)
33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1)
34 | x = self.fc(x)
35 | x = adaptive_avgmax_pool2d(x, 1)
36 | return x.view(x.size(0), -1)
37 |
38 |
39 | def apply_test_time_pool(model, config, use_test_size=False):
40 | test_time_pool = False
41 | if not hasattr(model, 'default_cfg') or not model.default_cfg:
42 | return model, False
43 | if use_test_size and 'test_input_size' in model.default_cfg:
44 | df_input_size = model.default_cfg['test_input_size']
45 | else:
46 | df_input_size = model.default_cfg['input_size']
47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]:
48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' %
49 | (str(config['input_size'][-2:]), str(df_input_size[-2:])))
50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
51 | test_time_pool = True
52 | return model, test_time_pool
53 |
--------------------------------------------------------------------------------
/models/layers/maxvit/layers/trace_utils.py:
--------------------------------------------------------------------------------
1 | try:
2 | from torch import _assert
3 | except ImportError:
4 | def _assert(condition: bool, message: str):
5 | assert condition, message
6 |
7 |
8 | def _float_to_int(x: float) -> int:
9 | """
10 | Symbolic tracing helper to substitute for inbuilt `int`.
11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy`
12 | """
13 | return int(x)
14 |
--------------------------------------------------------------------------------
/models/layers/rnn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch as th
4 | import torch.nn as nn
5 |
6 |
7 | class DWSConvLSTM2d(nn.Module):
8 | """LSTM with (depthwise-separable) Conv option in NCHW [channel-first] format.
9 | """
10 |
11 | def __init__(self,
12 | dim: int,
13 | dws_conv: bool = True, # RVT uses False
14 | dws_conv_only_hidden: bool = True,
15 | dws_conv_kernel_size: int = 3,
16 | cell_update_dropout: float = 0.):
17 | super().__init__()
18 | assert isinstance(dws_conv, bool)
19 | assert isinstance(dws_conv_only_hidden, bool)
20 | self.dim = dim
21 |
22 | xh_dim = dim * 2
23 | gates_dim = dim * 4
24 | conv3x3_dws_dim = dim if dws_conv_only_hidden else xh_dim
25 | # `self.conv3x3_dws` is actually just Identity mapping in RVT
26 | self.conv3x3_dws = nn.Conv2d(in_channels=conv3x3_dws_dim,
27 | out_channels=conv3x3_dws_dim,
28 | kernel_size=dws_conv_kernel_size,
29 | padding=dws_conv_kernel_size // 2,
30 | groups=conv3x3_dws_dim) if dws_conv else nn.Identity()
31 | self.conv1x1 = nn.Conv2d(in_channels=xh_dim,
32 | out_channels=gates_dim,
33 | kernel_size=1)
34 | self.conv_only_hidden = dws_conv_only_hidden
35 | self.cell_update_dropout = nn.Dropout(p=cell_update_dropout)
36 |
37 | def forward(self, x: th.Tensor, h_and_c_previous: Optional[Tuple[th.Tensor, th.Tensor]] = None) \
38 | -> Tuple[th.Tensor, th.Tensor]:
39 | """
40 | :param x: (N C H W), new features extracted at current time step
41 | :param h_and_c_previous: ((N C H W), (N C H W)), prev RNN states
42 | :return: ((N C H W), (N C H W))
43 | """
44 | if h_and_c_previous is None:
45 | # generate zero states
46 | hidden = th.zeros_like(x)
47 | cell = th.zeros_like(x)
48 | h_and_c_previous = (hidden, cell)
49 | h_tm1, c_tm1 = h_and_c_previous
50 |
51 | if self.conv_only_hidden:
52 | h_tm1 = self.conv3x3_dws(h_tm1)
53 | xh = th.cat((x, h_tm1), dim=1) # (N 2C H W)
54 | if not self.conv_only_hidden:
55 | xh = self.conv3x3_dws(xh)
56 | mix = self.conv1x1(xh)
57 |
58 | gates, cell_input = th.tensor_split(mix, [self.dim * 3], dim=1)
59 | assert gates.shape[1] == cell_input.shape[1] * 3
60 |
61 | gates = th.sigmoid(gates)
62 | forget_gate, input_gate, output_gate = th.tensor_split(gates, 3, dim=1)
63 | assert forget_gate.shape == input_gate.shape == output_gate.shape
64 |
65 | cell_input = self.cell_update_dropout(th.tanh(cell_input))
66 |
67 | c_t = forget_gate * c_tm1 + input_gate * cell_input
68 | h_t = output_gate * th.tanh(c_t)
69 |
70 | return h_t, c_t
71 |
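A minimal rollout sketch, threading the (hidden, cell) state across a few time steps of dense feature maps (dws_conv=False mirrors the RVT setting noted above):

    import torch as th

    from models.layers.rnn import DWSConvLSTM2d

    lstm = DWSConvLSTM2d(dim=32, dws_conv=False)
    state = None
    for _ in range(4):                        # 4 time steps of (N, C, H, W) features
        x = th.randn(2, 32, 12, 16)
        h, c = lstm(x, state)
        state = (h, c)
    assert h.shape == c.shape == (2, 32, 12, 16)
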
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/modules/__init__.py
--------------------------------------------------------------------------------
/modules/tracking/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear import LinearTracker
2 |
--------------------------------------------------------------------------------
/modules/tracking/tracker.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import numpy as np
4 |
5 |
6 | class Tracker(object):
7 | """Base class for multi-object video tracker."""
8 |
9 | def __init__(self, img_hw: Tuple[int, int], iou_threshold: float = 0.45):
10 | """
11 | Sets key parameters for the tracker.
12 | """
13 | self.img_hw = img_hw
14 | self.iou_threshold = iou_threshold
15 |
16 | self.trackers, self.prev_trackers, self.bbox_idx2tracker = [], [], {}
17 | self.track_count, self.bbox_count = 0, 0
18 | self.done = False
19 |
20 | def update(self,
21 | dets: np.ndarray = np.empty((0, 4)),
22 | is_gt: np.ndarray = np.empty((0, ))):
23 | raise NotImplementedError
24 |
25 | def _del_tracker(self, idx: int, done: bool = True):
26 | """Delete self.trackers[idx], move it to self.del_trackers."""
27 | tracker = self.trackers.pop(idx)
28 | tracker.finish(done=done)
29 | self.prev_trackers.append(tracker)
30 | for idx in tracker.bbox_idx:
31 | self.bbox_idx2tracker[idx] = tracker
32 |
33 | def finish(self):
34 | """Delete all remaining trackers."""
35 | for idx in reversed(range(len(self.trackers))):
36 | # don't filter out unfinished tracklets!
37 | self._del_tracker(idx, done=False)
38 | self.done = True
39 |
40 | def new(self):
41 | """Create a new tracker."""
42 | raise NotImplementedError
43 |
44 | def get_bbox_tracker(self, bbox_idx: int):
45 | """Get the bbox_tracker for the given bbox_idx."""
46 | assert self.done, 'Please call Tracker.finish() first.'
47 | return self.bbox_idx2tracker[bbox_idx]
48 |
--------------------------------------------------------------------------------
/modules/tracking/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Tuple
3 |
4 | import numpy as np
5 |
6 |
7 | def greedy_matching(cost_matrix: np.ndarray, idx_lst: np.ndarray,
8 | thresh: float = 0.0) -> np.ndarray:
9 | cost_matrix = copy.deepcopy(cost_matrix)
10 | matched_indices = []
11 | assert len(idx_lst) == cost_matrix.shape[0]
12 | for i in idx_lst:
13 | if cost_matrix[i].max() < thresh:
14 | continue
15 | j = np.argmax(cost_matrix[i])
16 | cost_matrix[:, j] = -np.inf
17 | matched_indices.append([i, j])
18 | return np.array(matched_indices) # (N, 2)
19 |
20 |
21 | def iou_batch_xywh(bb_test: np.ndarray, bb_gt: np.ndarray) -> np.ndarray:
22 | """
23 | Computes IOU between two bboxes in the form [x,y,w,h,(cls_id)]
24 | both bbox are in shape (N, 4/5) where N is the number of bboxes
25 | If class_id is provided, take it into account by ignoring the IOU between
26 | bboxes of different classes.
27 | """
28 | bb_gt = np.expand_dims(bb_gt, 0)
29 | bb_test = np.expand_dims(bb_test, 1)
30 |
31 | xx1 = np.maximum(bb_test[..., 0] - bb_test[..., 2] / 2.,
32 | bb_gt[..., 0] - bb_gt[..., 2] / 2.)
33 | yy1 = np.maximum(bb_test[..., 1] - bb_test[..., 3] / 2.,
34 | bb_gt[..., 1] - bb_gt[..., 3] / 2.)
35 | xx2 = np.minimum(bb_test[..., 0] + bb_test[..., 2] / 2.,
36 | bb_gt[..., 0] + bb_gt[..., 2] / 2.)
37 | yy2 = np.minimum(bb_test[..., 1] + bb_test[..., 3] / 2.,
38 | bb_gt[..., 1] + bb_gt[..., 3] / 2.)
39 | w = np.maximum(0., xx2 - xx1)
40 | h = np.maximum(0., yy2 - yy1)
41 | wh = w * h
42 | o = wh / (
43 | bb_test[..., 2] * bb_test[..., 3] + bb_gt[..., 2] * bb_gt[..., 3] - wh)
44 |
45 | # set IoU of between different class objects to 0
46 | if bb_test.shape[-1] == 5 and bb_gt.shape[-1] == 5:
47 | o[bb_gt[..., 4] != bb_test[..., 4]] = 0.
48 |
49 | return o
50 |
51 |
52 | def xyxy2xywh(bbox: np.ndarray) -> np.ndarray:
53 | """
54 | Takes a bbox in the form [x1,y1,x2,y2] and returns a new bbox in the form
55 | [x,y,w,h] where x,y is the center and w,h are the width and height
56 | """
57 | x1, y1, x2, y2 = bbox
58 | bbox = [(x1 + x2) / 2., (y1 + y2) / 2., x2 - x1, y2 - y1]
59 | return np.array(bbox)
60 |
61 |
62 | def xywh2xyxy(bbox: np.ndarray) -> np.ndarray:
63 | """
64 | Takes a bounding box in the form [x,y,w,h] and returns it in the form
65 | [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
66 | """
67 | x, y, w, h = bbox
68 | bbox = [x - w / 2., y - h / 2., x + w / 2., y + h / 2.]
69 | return np.array(bbox)
70 |
71 |
72 | def clamp_bbox(bbox: np.ndarray,
73 | img_hw: Tuple[int, int],
74 | format_: str = 'xyxy') -> np.ndarray:
75 | """
76 | Clamp bbox to image boundaries.
77 | """
78 | # bbox: (4,) or (1, 4) or (4, 1)
79 | bbox_shape = bbox.shape
80 | bbox = bbox.squeeze() # to (4,)
81 | H, W = img_hw
82 | assert format_ in ['xyxy', 'xywh']
83 | if format_ == 'xywh':
84 | bbox = xywh2xyxy(bbox)
85 | x1_, y1_, x2_, y2_ = bbox
86 | x1 = np.clip(x1_, 0., W - 1.)
87 | x2 = np.clip(x2_, 0., W - 1.)
88 | y1 = np.clip(y1_, 0., H - 1.)
89 | y2 = np.clip(y2_, 0., H - 1.)
90 | bbox = np.array([x1, y1, x2, y2])
91 | clamp_top, clamp_down = (y1 != y1_), (y2 != y2_)
92 | clamp_left, clamp_right = (x1 != x1_), (x2 != x2_)
93 | if format_ == 'xywh':
94 | bbox = xyxy2xywh(bbox)
95 | bbox = bbox.reshape(bbox_shape)
96 | return bbox, clamp_top, clamp_down, clamp_left, clamp_right
97 |
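A tiny worked example of the two matching helpers: pairwise IoU between center-format boxes, followed by greedy assignment above an IoU threshold (import path assumed):

    import numpy as np

    from modules.tracking.utils import iou_batch_xywh, greedy_matching

    dets = np.array([[10., 10., 8., 8.],
                     [40., 40., 8., 8.]])     # (cx, cy, w, h)
    tracks = np.array([[11., 10., 8., 8.],
                       [80., 80., 8., 8.]])

    iou = iou_batch_xywh(dets, tracks)        # (2, 2): rows are dets, cols are tracks
    matches = greedy_matching(iou, np.arange(len(dets)), thresh=0.3)
    assert matches.tolist() == [[0, 0]]       # only det 0 <-> track 0 clears the threshold
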
--------------------------------------------------------------------------------
/modules/utils/fetch.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from omegaconf import DictConfig
3 |
4 | from modules.data.genx import DataModule as genx_data_module
5 | from modules.detection import Module as rnn_det_module
6 | from modules.utils.tta import TTAModule
7 | from modules.pseudo_labeler import PseudoLabeler
8 |
9 |
10 | def fetch_model_module(config: DictConfig) -> pl.LightningModule:
11 | """Build model."""
12 | model_str = config.model.name
13 | if model_str == 'rnndet':
14 | if config.get('tta', {}).get('enable', False):
15 | return TTAModule(config)
16 | return rnn_det_module(config)
17 | elif model_str == 'pseudo_labeler':
18 | return PseudoLabeler(config)
19 | raise NotImplementedError
20 |
21 |
22 | def fetch_data_module(config: DictConfig) -> pl.LightningDataModule:
23 | """Build dataloaders."""
24 | batch_size_train = config.batch_size.train
25 | batch_size_eval = config.batch_size.eval
26 | num_workers_generic = config.hardware.get('num_workers', None)
27 | num_workers_train = config.hardware.num_workers.get('train', num_workers_generic)
28 | num_workers_eval = config.hardware.num_workers.get('eval', num_workers_generic)
29 | dataset_str = config.dataset.name
30 | if dataset_str in {'gen1', 'gen4'}:
31 | return genx_data_module(config.dataset,
32 | num_workers_train=num_workers_train,
33 | num_workers_eval=num_workers_eval,
34 | batch_size_train=batch_size_train,
35 | batch_size_eval=batch_size_eval)
36 | raise NotImplementedError
37 |
--------------------------------------------------------------------------------
/pretrained/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/pretrained/.gitkeep
--------------------------------------------------------------------------------
/utils/bbox.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Tuple
2 |
3 | import numpy as np
4 |
5 | import torch
6 | import torch as th
7 |
8 | from data.genx_utils.labels import ObjectLabels
9 |
10 |
11 | def np_th_stack(values: List[Union[np.ndarray, th.Tensor]], axis: int = 0):
12 | """Stack a list of numpy arrays or tensors."""
13 | if isinstance(values[0], np.ndarray):
14 | return np.stack(values, axis=axis)
15 | elif isinstance(values[0], th.Tensor):
16 | return torch.stack(values, dim=axis)
17 | else:
18 | raise ValueError(f'Unknown type {type(values[0])}')
19 |
20 |
21 | def np_th_concat(values: List[Union[np.ndarray, th.Tensor]], axis: int = 0):
22 | """Concat a list of numpy arrays or tensors."""
23 | if isinstance(values[0], np.ndarray):
24 | return np.concatenate(values, axis=axis)
25 | elif isinstance(values[0], th.Tensor):
26 | return torch.cat(values, dim=axis)
27 | else:
28 | raise ValueError(f'Unknown type {type(values[0])}')
29 |
30 |
31 | def get_bbox_coords(bbox: Union[np.ndarray, th.Tensor], last4: bool = None):
32 | """Get the 4 coords (xyxy/xywh) from a bbox array or tensor."""
33 | if isinstance(bbox, list):
34 | bbox = np_th_stack(bbox, axis=0)
35 | if last4 is None: # infer from shape, buggy when bbox.shape == (4, 4)
36 | if bbox.shape[0] == 4:
37 | last4 = False
38 | elif bbox.shape[-1] == 4:
39 | last4 = True
40 | else:
41 | raise ValueError(f'Unknown shape {bbox.shape}')
42 | if last4:
43 | a, b, c, d = bbox[..., 0], bbox[..., 1], bbox[..., 2], bbox[..., 3]
44 | else:
45 | a, b, c, d = bbox
46 | return (a, b, c, d), last4
47 |
48 |
49 | def construct_bbox(abcd: Tuple[Union[np.ndarray, th.Tensor]], last4: bool):
50 | """Construct a bbox from 4 coords (xyxy/xywh)."""
51 | if last4:
52 | return np_th_stack(abcd, axis=-1)
53 | return np_th_stack(abcd, axis=0)
54 |
55 |
56 | def xywh2xyxy(xywh: Union[np.ndarray, th.Tensor], format_: str = 'center', last4: bool = None):
57 | """Convert bounding box from xywh to xyxy format."""
58 | if isinstance(xywh, ObjectLabels):
59 | return xywh.get_xyxy()
60 |
61 | (x, y, w, h), last4 = get_bbox_coords(xywh, last4=last4)
62 |
63 | if format_ == 'center':
64 | x1, x2 = x - w / 2., x + w / 2.
65 | y1, y2 = y - h / 2., y + h / 2.
66 | elif format_ == 'corner':
67 | x1, x2 = x, x + w
68 | y1, y2 = y, y + h
69 | else:
70 | raise NotImplementedError(f'Unknown format {format_}')
71 |
72 | return construct_bbox((x1, y1, x2, y2), last4=last4)
73 |
74 |
75 | def xyxy2xywh(xyxy: Union[np.ndarray, th.Tensor], format_: str = 'center', last4: bool = None):
76 | """Convert bounding box from xyxy to xywh format."""
77 | if isinstance(xyxy, ObjectLabels):
78 | return xyxy.get_xywh(format_=format_)
79 |
80 | (x1, y1, x2, y2), last4 = get_bbox_coords(xyxy, last4=last4)
81 |
82 | w, h = x2 - x1, y2 - y1
83 | if format_ == 'center':
84 | x, y = (x1 + x2) / 2., (y1 + y2) / 2.
85 | elif format_ == 'corner':
86 | x, y = x1, y1
87 | else:
88 | raise NotImplementedError(f'Unknown format {format_}')
89 |
90 | return construct_bbox((x, y, w, h), last4=last4)
91 |
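A round-trip sanity check for the converters; because of the (4, 4) ambiguity noted in get_bbox_coords, last4 is passed explicitly here (assumes the repo root is importable):

    import numpy as np

    from utils.bbox import xywh2xyxy, xyxy2xywh

    xywh = np.array([[10., 10., 4., 6.],
                     [20., 30., 8., 2.]])     # (N, 4), center format
    xyxy = xywh2xyxy(xywh, format_='center', last4=True)
    assert np.allclose(xyxy, [[8., 7., 12., 13.],
                              [16., 29., 24., 31.]])
    assert np.allclose(xyxy2xywh(xyxy, format_='center', last4=True), xywh)
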
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/__init__.py
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/evaluation.py:
--------------------------------------------------------------------------------
1 | from .io.box_filtering import filter_boxes
2 | from .metrics.coco_eval import evaluate_detection
3 |
4 |
5 | def evaluate_list(result_boxes_list,
6 | gt_boxes_list,
7 | height: int,
8 | width: int,
9 | camera: str = 'gen1',
10 | apply_bbox_filters: bool = True,
11 | downsampled_by_2: bool = False,
12 | return_aps: bool = True):
13 | assert camera in {'gen1', 'gen4'}
14 |
15 | if camera == 'gen1':
16 | classes = ("car", "pedestrian")
17 | elif camera == 'gen4':
18 | classes = ("pedestrian", "two-wheeler", "car")
19 | else:
20 | raise NotImplementedError
21 |
22 | if apply_bbox_filters:
23 | # Default values taken from: https://github.com/prophesee-ai/prophesee-automotive-dataset-toolbox/blob/0393adea2bf22d833893c8cb1d986fcbe4e6f82d/src/psee_evaluator.py#L23-L24
24 | min_box_diag = 60 if camera == 'gen4' else 30
25 | # In the supplementary mat, they say that min_box_side is 20 for gen4.
26 | min_box_side = 20 if camera == 'gen4' else 10
27 | if downsampled_by_2:
28 | assert min_box_diag % 2 == 0
29 | min_box_diag //= 2
30 | assert min_box_side % 2 == 0
31 | min_box_side //= 2
32 |
33 | half_sec_us = int(5e5)
34 | filter_boxes_fn = lambda x: filter_boxes(x, half_sec_us, min_box_diag, min_box_side)
35 |
36 | gt_boxes_list = map(filter_boxes_fn, gt_boxes_list)
37 | # NOTE: We also filter the prediction to follow the prophesee protocol of evaluation.
38 | result_boxes_list = map(filter_boxes_fn, result_boxes_list)
39 |
40 | return evaluate_detection(gt_boxes_list, result_boxes_list,
41 | height=height, width=width,
42 | classes=classes, return_aps=return_aps)
43 |
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/evaluator.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple, List, Optional, Dict
2 | from warnings import warn
3 |
4 | import numpy as np
5 |
6 | from utils.evaluation.prophesee.evaluation import evaluate_list
7 |
8 | LABELMAP = {
9 | 'gen1': ('car', 'ped'),
10 | 'gen4': ('ped', 'cyc', 'car'),
11 | }
12 |
13 |
14 | def get_labelmap(dst_name: str = None, num_cls: int = None) -> Tuple[str]:
15 | assert dst_name is None or num_cls is None
16 | if dst_name is not None:
17 | return LABELMAP[dst_name.lower()]
18 | elif num_cls is not None:
19 | assert num_cls in (2, 3), f'Invalid number of classes: {num_cls}'
20 | return LABELMAP['gen1'] if num_cls == 2 else LABELMAP['gen4']
21 | else:
22 | raise NotImplementedError('Either dst_name or num_cls must be input')
23 |
24 |
25 | class PropheseeEvaluator:
26 | LABELS = 'labels'
27 | PREDICTIONS = 'predictions'
28 |
29 | def __init__(self, dataset: str, downsample_by_2: bool):
30 | super().__init__()
31 | assert dataset in {'gen1', 'gen4'}
32 | self.dataset = dataset
33 | self.label_map = get_labelmap(dataset)
34 | self.downsample_by_2 = downsample_by_2
35 |
36 | self._buffer = None
37 | self._buffer_empty = True
38 | self._reset_buffer()
39 |
40 | def _reset_buffer(self):
41 | self._buffer_empty = True
42 | self._buffer = {
43 | self.LABELS: list(),
44 | self.PREDICTIONS: list(),
45 | }
46 |
47 | def _add_to_buffer(self, key: str, value: List[np.ndarray]):
48 | assert isinstance(value, list)
49 | for entry in value:
50 | assert isinstance(entry, np.ndarray)
51 | self._buffer_empty = False
52 | assert self._buffer is not None
53 | self._buffer[key].extend(value)
54 |
55 | def _get_from_buffer(self, key: str) -> List[np.ndarray]:
56 | assert not self._buffer_empty
57 | assert self._buffer is not None
58 | return self._buffer[key]
59 |
60 | def add_predictions(self, predictions: List[np.ndarray]):
61 | self._add_to_buffer(self.PREDICTIONS, predictions)
62 |
63 | def add_labels(self, labels: List[np.ndarray]):
64 | self._add_to_buffer(self.LABELS, labels)
65 |
66 | def reset_buffer(self) -> None:
67 | # E.g. call in on_validation_epoch_start
68 | self._reset_buffer()
69 |
70 | def has_data(self):
71 | return not self._buffer_empty
72 |
73 | def evaluate_buffer(self, img_height: int, img_width: int, ret_pr_curve: bool = False) -> Optional[Dict[str, Any]]:
74 | # e.g. call in on_validation_epoch_end
75 | if self._buffer_empty:
76 | warn("Attempt to use prophesee evaluation buffer, but it is empty", UserWarning, stacklevel=2)
77 | return
78 |
79 | labels = self._get_from_buffer(self.LABELS)
80 | predictions = self._get_from_buffer(self.PREDICTIONS)
81 | assert len(labels) == len(predictions)
82 |
83 | # we perform both per-category and overall evaluation
84 | # overall
85 | metrics = evaluate_list(result_boxes_list=predictions,
86 | gt_boxes_list=labels,
87 | height=img_height,
88 | width=img_width,
89 | apply_bbox_filters=True,
90 | downsampled_by_2=self.downsample_by_2,
91 | camera=self.dataset)
92 | # per-category
93 | for cls_id, cls_name in enumerate(self.label_map):
94 | lbls = [lbl[lbl['class_id'] == cls_id] for lbl in labels]
95 | preds = [pred[pred['class_id'] == cls_id] for pred in predictions]
96 | cls_metric = evaluate_list(result_boxes_list=preds,
97 | gt_boxes_list=lbls,
98 | height=img_height,
99 | width=img_width,
100 | apply_bbox_filters=True,
101 | downsampled_by_2=self.downsample_by_2,
102 | camera=self.dataset)
103 | cls_metric = {f'{k}_{cls_name}': v for k, v in cls_metric.items()}
104 | metrics.update(cls_metric)
105 |
106 | if not ret_pr_curve:
107 | del_keys = [k for k in metrics.keys() if 'PR' in k]
108 | for k in del_keys:
109 | del metrics[k]
110 | return metrics
111 |
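A sketch of the intended buffer protocol during validation. The structured-array dtype below is an assumption for illustration (real entries carry at least the fields read by the box filter and the COCO evaluation); only the bookkeeping calls are exercised here:

    import numpy as np

    from utils.evaluation.prophesee.evaluator import PropheseeEvaluator

    # assumed minimal field layout, not necessarily the exact dtype used elsewhere in the repo
    BBOX_DTYPE = np.dtype([('t', '<i8'), ('x', '<f4'), ('y', '<f4'), ('w', '<f4'),
                           ('h', '<f4'), ('class_id', '<u4'), ('class_confidence', '<f4')])

    evaluator = PropheseeEvaluator(dataset='gen1', downsample_by_2=False)
    evaluator.add_labels([np.zeros((0,), dtype=BBOX_DTYPE)])        # per-frame GT boxes
    evaluator.add_predictions([np.zeros((0,), dtype=BBOX_DTYPE)])   # matching per-frame detections
    assert evaluator.has_data()
    # metrics = evaluator.evaluate_buffer(img_height=240, img_width=304)  # gen1 resolution
    evaluator.reset_buffer()
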
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/io/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/io/__init__.py
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/io/box_filtering.py:
--------------------------------------------------------------------------------
1 | """
2 | Define same filtering that we apply in:
3 | "Learning to detect objects on a 1 Megapixel Event Camera" by Etienne Perot et al.
4 |
5 | Namely we apply 2 different filters:
6 | 1. skip all boxes before 0.5s (before that, we assume it is unlikely to have sufficient history)
7 | 2. filter out all boxes whose diagonal is smaller than min_box_diag or whose side is smaller than min_box_side
8 |
9 |
10 |
11 | Copyright: (c) 2019-2020 Prophesee
12 | """
13 | from __future__ import print_function
14 |
15 | import numpy as np
16 |
17 |
18 | def filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=60, min_box_side=20):
19 | """Filters boxes according to the paper rule.
20 |
21 | To note: the default represents our threshold when evaluating GEN4 resolution (1280x720)
22 | To note: we assume the initial time of the video is always 0
23 |
24 | Args:
25 | boxes (np.ndarray): structured box array with fields ['t','x','y','w','h','class_id','track_id','class_confidence']
26 | (example BBOX_DTYPE is provided in src/box_loading.py)
27 |
28 | Returns:
29 | boxes: filtered boxes
30 | """
31 | ts = boxes['t']
32 | width = boxes['w']
33 | height = boxes['h']
34 | diag_square = width ** 2 + height ** 2
35 | mask = (ts > skip_ts) * (diag_square >= min_box_diag ** 2) * (width >= min_box_side) * (height >= min_box_side)
36 | return boxes[mask]
37 |
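A tiny demo of the filter on a hand-built structured array, using the gen1-style thresholds from evaluation.py; the minimal dtype here is only for illustration, since just 't', 'w' and 'h' are read:

    import numpy as np

    from utils.evaluation.prophesee.io.box_filtering import filter_boxes

    dtype = np.dtype([('t', '<i8'), ('x', '<f4'), ('y', '<f4'), ('w', '<f4'), ('h', '<f4')])
    boxes = np.array([
        (int(4e5), 10., 10., 50., 50.),   # dropped: t <= 0.5 s
        (int(6e5), 10., 10., 50., 50.),   # kept: diag ~70.7 >= 30, both sides >= 10
        (int(6e5), 10., 10., 8., 40.),    # dropped: one side < min_box_side
    ], dtype=dtype)

    kept = filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=30, min_box_side=10)
    assert len(kept) == 1 and kept['t'][0] == int(6e5)
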
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/io/npy_events_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | """
4 | Defines some tools to handle events, mimicking dat_events_tools.py.
5 | In particular :
6 | -> defines functions to read events from binary .npy files using numpy
7 | -> defines functions to write events to binary .dat files using numpy (TODO later)
8 |
9 | Copyright: (c) 2015-2019 Prophesee
10 | """
11 | from __future__ import print_function
12 |
13 | import numpy as np
14 |
15 |
16 | def stream_td_data(file_handle, buffer, dtype, ev_count=-1):
17 | """
18 | Streams data from opened file_handle
19 | args :
20 | - file_handle: file object
21 | - buffer: pre-allocated buffer to fill with events
22 | - dtype: expected fields
23 | - ev_count: number of events
24 | """
25 | dat = np.fromfile(file_handle, dtype=dtype, count=ev_count)
26 | count = len(dat['t'])
27 | for name, _ in dtype:
28 | buffer[name][:count] = dat[name]
29 |
30 |
31 | def parse_header(fhandle):
32 | """
33 | Parses the header of a .npy file
34 | Args:
35 | - f file handle to a .npy file
36 | return :
37 | - int position of the file cursor after the header
38 | - int type of event
39 | - int size of event in bytes
40 | - size (height, width) tuple of int or (None, None)
41 | """
42 | version = np.lib.format.read_magic(fhandle)
43 | shape, fortran, dtype = np.lib.format._read_array_header(fhandle, version)
44 | assert not fortran, "Fortran order arrays not supported"
45 | # Get the number of elements in one 'row' by taking
46 | # a product over all other dimensions.
47 | if len(shape) == 0:
48 | count = 1
49 | else:
50 | count = np.multiply.reduce(shape, dtype=np.int64)
51 | ev_size = dtype.itemsize
52 | assert ev_size != 0
53 | start = fhandle.tell()
54 | # turn numpy.dtype into an iterable list
55 | ev_type = [(x, str(dtype.fields[x][0])) for x in dtype.names]
56 | # filter name to have only t and not ts
57 | ev_type = [(name if name != "ts" else "t", desc) for name, desc in ev_type]
58 | ev_type = [(name if name != "confidence" else "class_confidence", desc) for name, desc in ev_type]
59 | size = (None, None)
60 | size = (None, None)
61 |
62 | return start, ev_type, ev_size, size
63 |
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/metrics/__init__.py
--------------------------------------------------------------------------------
/utils/evaluation/prophesee/visualize/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/visualize/__init__.py
--------------------------------------------------------------------------------
/utils/padding.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Tuple
2 |
3 | import torch as th
4 | import torch.nn.functional as F
5 |
6 |
7 | class InputPadderFromShape:
8 | """Pad input to desired height and width."""
9 |
10 | def __init__(self, desired_hw: Tuple[int, int], mode: str = 'constant',
11 | value: int = 0, type: str = 'corner'):
12 | """
13 | :param desired_hw: Desired height and width
14 | :param mode: See torch.nn.functional.pad
15 | :param value: See torch.nn.functional.pad
16 | :param type: "corner": add zero to bottom and right
17 | """
18 | assert isinstance(desired_hw, tuple)
19 | assert len(desired_hw) == 2
20 | assert desired_hw[0] % 4 == 0, 'Required for token mask padding'
21 | assert desired_hw[1] % 4 == 0, 'Required for token mask padding'
22 | assert type in {'corner'}
23 |
24 | self.desired_hw = desired_hw
25 | self.mode = mode
26 | self.value = value
27 | self.type = type
28 | self._pad_ev_repr = None
29 | self._pad_token_mask = None
30 |
31 | @staticmethod
32 | def _pad_tensor_impl(input_tensor: th.Tensor, desired_hw: Tuple[int, int],
33 | mode: str, value: Any) -> Tuple[th.Tensor, List[int]]:
34 | assert isinstance(input_tensor, th.Tensor)
35 |
36 | ht, wd = input_tensor.shape[-2:]
37 | ht_des, wd_des = desired_hw
38 | assert ht <= ht_des
39 | assert wd <= wd_des
40 |
41 | pad_left = 0
42 | pad_right = wd_des - wd
43 | pad_top = 0
44 | pad_bottom = ht_des - ht
45 |
46 | pad = [pad_left, pad_right, pad_top, pad_bottom]
47 | return F.pad(input_tensor, pad=pad, mode=mode, value=value if
48 | mode == 'constant' else None), pad
49 |
50 | def pad_tensor_ev_repr(self, ev_repr: th.Tensor) -> th.Tensor:
51 | padded_ev_repr, pad = self._pad_tensor_impl(
52 | input_tensor=ev_repr, desired_hw=self.desired_hw,
53 | mode=self.mode, value=self.value)
54 | if self._pad_ev_repr is None:
55 | self._pad_ev_repr = pad
56 | else:
57 | assert self._pad_ev_repr == pad
58 | return padded_ev_repr
59 |
60 | def pad_token_mask(self, token_mask: th.Tensor):
61 | assert isinstance(token_mask, th.Tensor)
62 |
63 | desired_hw = tuple(x // 4 for x in self.desired_hw)
64 | padded_token_mask, pad = self._pad_tensor_impl(
65 | input_tensor=token_mask, desired_hw=desired_hw,
66 | mode='constant', value=0)
67 | if self._pad_token_mask is None:
68 | self._pad_token_mask = pad
69 | else:
70 | assert self._pad_token_mask == pad
71 | return padded_token_mask
72 |
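A short sketch padding a gen1-sized event representation (304x240) up to a fixed multiple-of-4 target; zeros are added at the bottom and right, per the 'corner' mode:

    import torch as th

    from utils.padding import InputPadderFromShape

    padder = InputPadderFromShape(desired_hw=(256, 320))   # both must be multiples of 4
    ev_repr = th.randn(2, 20, 240, 304)                    # (N, C, H, W)
    padded = padder.pad_tensor_ev_repr(ev_repr)
    assert padded.shape == (2, 20, 256, 320)
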
--------------------------------------------------------------------------------
/utils/preprocessing.py:
--------------------------------------------------------------------------------
1 | from .helpers import subsample_list
2 |
3 |
4 | def _blosc_opts(complevel=1, complib='blosc:zstd', shuffle='byte'):
5 | shuffle = 2 if shuffle == 'bit' else 1 if shuffle == 'byte' else 0
6 | compressors = ['blosclz', 'lz4', 'lz4hc', 'snappy', 'zlib', 'zstd']
7 | complib = ['blosc:' + c for c in compressors].index(complib)
8 | args = {
9 | 'compression': 32001,
10 | 'compression_opts': (0, 0, 0, 0, complevel, shuffle, complib),
11 | }
12 | if shuffle > 0:
13 | # Do not use h5py shuffle if blosc shuffle is enabled.
14 | args['shuffle'] = False
15 | return args
16 |
17 |
18 | def subsample_sequence(split_path, ratio):
19 | """Subsample the sequence under a folder by a given ratio."""
20 | seq_dirs = sorted([p for p in split_path.iterdir()])
21 | print(f'Found {len(seq_dirs)} sequences in {str(split_path)}')
22 | # may need to sub-sample training seqs
23 | if 0. < ratio < 1.:
24 | num = round(len(seq_dirs) * ratio)
25 | seq_dirs = subsample_list(seq_dirs, num)
26 | assert len(seq_dirs) == num
27 | print(f'Using {ratio*100}% of data --> {len(seq_dirs)} sequences')
28 | return seq_dirs
29 |
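A sketch of how the options dict plugs into h5py's create_dataset; hdf5plugin must be importable so that the blosc filter (id 32001) is registered, and the output path here is just an example:

    import h5py
    import hdf5plugin  # noqa: F401  (registers the blosc filter with h5py)
    import numpy as np

    from utils.preprocessing import _blosc_opts

    data = np.random.rand(64, 240, 304).astype(np.float32)
    with h5py.File('/tmp/blosc_demo.h5', 'w') as f:
        f.create_dataset('ev_repr', data=data, chunks=(8, 240, 304),
                         **_blosc_opts(complevel=1, complib='blosc:zstd', shuffle='byte'))
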
--------------------------------------------------------------------------------
/utils/timers.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import time
3 | from functools import wraps
4 |
5 | import numpy as np
6 | import torch
7 |
8 | cuda_timers = {}
9 | timers = {}
10 |
11 |
12 | class CudaTimer:
13 | def __init__(self, device: torch.device, timer_name: str):
14 | assert isinstance(device, torch.device)
15 | assert isinstance(timer_name, str)
16 | self.timer_name = timer_name
17 | if self.timer_name not in cuda_timers:
18 | cuda_timers[self.timer_name] = []
19 |
20 | self.device = device
21 | self.start = None
22 | self.end = None
23 |
24 | def __enter__(self):
25 | torch.cuda.synchronize(device=self.device)
26 | self.start = time.time()
27 | return self
28 |
29 | def __exit__(self, *args):
30 | assert self.start is not None
31 | torch.cuda.synchronize(device=self.device)
32 | end = time.time()
33 | cuda_timers[self.timer_name].append(end - self.start)
34 |
35 |
36 | def cuda_timer_decorator(device: torch.device, timer_name: str):
37 | def decorator(func):
38 | @wraps(func)
39 | def wrapper(*args, **kwargs):
40 | with CudaTimer(device=device, timer_name=timer_name):
41 | out = func(*args, **kwargs)
42 | return out
43 |
44 | return wrapper
45 |
46 | return decorator
47 |
48 |
49 | class TimerDummy:
50 | def __init__(self, *args, **kwargs):
51 | pass
52 |
53 | def __enter__(self):
54 | pass
55 |
56 | def __exit__(self, *args):
57 | pass
58 |
59 |
60 | class Timer:
61 | def __init__(self, timer_name=''):
62 | self.timer_name = timer_name
63 | if self.timer_name not in timers:
64 | timers[self.timer_name] = []
65 |
66 | def __enter__(self):
67 | self.start = time.time()
68 | return self
69 |
70 | def __exit__(self, *args):
71 | end = time.time()
72 | time_diff_s = end - self.start # measured in seconds
73 | timers[self.timer_name].append(time_diff_s)
74 |
75 |
76 | def print_timing_info():
77 | print('== Timing statistics ==')
78 | skip_warmup = 10
79 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]:
80 | if len(timing_values) <= skip_warmup:
81 | continue
82 | values = timing_values[skip_warmup:]
83 | timing_value_s_mean = np.mean(np.array(values))
84 | timing_value_s_median = np.median(np.array(values))
85 | timing_value_ms_mean = timing_value_s_mean * 1000
86 | timing_value_ms_median = timing_value_s_median * 1000
87 | if timing_value_ms_mean > 1000:
88 | print('{}: mean={:.2f} s, median={:.2f} s'.format(timer_name, timing_value_s_mean, timing_value_s_median))
89 | else:
90 | print(
91 | '{}: mean={:.2f} ms, median={:.2f} ms'.format(timer_name, timing_value_ms_mean, timing_value_ms_median))
92 |
93 |
94 | # this will print all the timer values upon termination of any program that imported this file
95 | atexit.register(print_timing_info)
96 |
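A minimal usage sketch for the CPU-side Timer; measurements accumulate in the module-level timers dict and are summarized by print_timing_info() at interpreter exit (the first 10 entries are skipped as warm-up):

    import time

    from utils.timers import Timer, timers

    for _ in range(20):                    # > 10 iterations so the warm-up skip leaves data
        with Timer(timer_name='demo/sleep'):
            time.sleep(0.001)
    assert len(timers['demo/sleep']) == 20
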
--------------------------------------------------------------------------------
/val.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
4 | os.environ["OMP_NUM_THREADS"] = "1"
5 | os.environ["OPENBLAS_NUM_THREADS"] = "1"
6 | os.environ["MKL_NUM_THREADS"] = "1"
7 | os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
8 | os.environ["NUMEXPR_NUM_THREADS"] = "1"
9 | import hdf5plugin # resolve a weird h5py error
10 | import numpy as np
11 | import matplotlib.pyplot as plt
12 |
13 | import torch
14 | from torch.backends import cuda, cudnn
15 |
16 | cuda.matmul.allow_tf32 = True
17 | cudnn.allow_tf32 = True
18 | torch.multiprocessing.set_sharing_strategy('file_system')
19 |
20 | import hydra
21 | from omegaconf import DictConfig, OmegaConf
22 | import pytorch_lightning as pl
23 | from pytorch_lightning.loggers import CSVLogger
24 | from pytorch_lightning.callbacks import ModelSummary
25 |
26 | from config.modifier import dynamically_modify_train_config
27 | from modules.utils.fetch import fetch_data_module, fetch_model_module
28 |
29 |
30 | @hydra.main(config_path='config', config_name='val', version_base='1.2')
31 | def main(config: DictConfig):
32 | dynamically_modify_train_config(config)
33 | # Just to check whether config can be resolved
34 | OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
35 |
36 | print('------ Configuration ------')
37 | # print(OmegaConf.to_yaml(config))
38 | _ = OmegaConf.to_yaml(config)
39 | print('---------------------------')
40 |
41 | # ---------------------
42 | # GPU options
43 | # ---------------------
44 | gpus = config.hardware.gpus
45 | assert isinstance(gpus, int), 'no more than 1 GPU supported'
46 | gpus = [gpus]
47 |
48 | # ---------------------
49 | # Data
50 | # ---------------------
51 | if 'T4' in torch.cuda.get_device_name() and \
52 | config.tta.enable and config.tta.hflip:
53 | if config.dataset.name == 'gen1':
54 | config.batch_size.eval = 12 # to avoid OOM on T4 GPU
55 | else:
56 | config.batch_size.eval = 6
57 | if config.reverse:
58 | config.dataset.reverse_event_order = True
59 | print('Testing on event sequences with reversed temporal order.')
60 | data_module = fetch_data_module(config=config)
61 |
62 | # ---------------------
63 | # Logging
64 | # ---------------------
65 | logger = CSVLogger(save_dir='./validation_logs')
66 |
67 | # ---------------------
68 | # Model
69 | # ---------------------
70 | module = fetch_model_module(config=config).eval()
71 | module.load_weight(config.checkpoint)
72 |
73 | # ---------------------
74 | # Callbacks and Misc
75 | # ---------------------
76 | callbacks = [ModelSummary(max_depth=2)]
77 |
78 | # ---------------------
79 | # Validation
80 | # ---------------------
81 |
82 | trainer = pl.Trainer(
83 | accelerator='gpu',
84 | callbacks=callbacks,
85 | default_root_dir=None,
86 | devices=gpus,
87 | logger=logger,
88 | log_every_n_steps=100,
89 | precision=config.training.precision,
90 | move_metrics_to_cpu=False,
91 | )
92 | with torch.inference_mode():
93 | trainer.test(model=module, datamodule=data_module)
94 | print(f'Evaluating {config.checkpoint=} finished.')
95 | print(f'Conf_thresh: {config.model.postprocess.confidence_threshold}')
96 |
97 |
98 | if __name__ == '__main__':
99 | main()
100 |
--------------------------------------------------------------------------------