├── .gitignore ├── LICENSE ├── README.md ├── assets ├── LEOD.png ├── det-results.png └── gen1-video.mp4 ├── callbacks ├── custom.py ├── detection.py ├── gradflow.py ├── utils │ └── visualization.py └── viz_base.py ├── config ├── dataset │ ├── base.yaml │ ├── gen1-tflip.yaml │ ├── gen1.yaml │ ├── gen1x0.01_seq.yaml │ ├── gen1x0.01_ss-1round.yaml │ ├── gen1x0.01_ss.yaml │ ├── gen1x0.02_seq.yaml │ ├── gen1x0.02_ss.yaml │ ├── gen1x0.05_seq.yaml │ ├── gen1x0.05_ss.yaml │ ├── gen1x0.1_seq.yaml │ ├── gen1x0.1_ss.yaml │ ├── gen4-tflip.yaml │ ├── gen4.yaml │ ├── gen4x0.01_seq.yaml │ ├── gen4x0.01_ss.yaml │ ├── gen4x0.02_seq.yaml │ ├── gen4x0.02_ss.yaml │ ├── gen4x0.05_seq.yaml │ ├── gen4x0.05_ss.yaml │ ├── gen4x0.1_seq.yaml │ └── gen4x0.1_ss.yaml ├── experiment │ ├── gen1 │ │ ├── base.yaml │ │ ├── default.yaml │ │ ├── small.yaml │ │ └── tiny.yaml │ └── gen4 │ │ ├── base.yaml │ │ ├── default.yaml │ │ ├── small.yaml │ │ └── tiny.yaml ├── general.yaml ├── model │ ├── base.yaml │ ├── maxvit_yolox │ │ └── default.yaml │ ├── pseudo_labeler-gen4-wsod.yaml │ ├── pseudo_labeler.yaml │ ├── rnndet-soft-gen4-wsod.yaml │ ├── rnndet-soft.yaml │ └── rnndet.yaml ├── modifier.py ├── predict.yaml ├── train.yaml ├── val.yaml └── vis.yaml ├── data ├── genx_utils │ ├── collate.py │ ├── collate_from_pytorch.py │ ├── dataset_rnd.py │ ├── dataset_streaming.py │ ├── labels.py │ ├── sequence_base.py │ ├── sequence_rnd.py │ ├── sequence_streaming.py │ └── splits │ │ ├── gen1 │ │ ├── ssod_0.010-off0.pkl │ │ ├── ssod_0.020-off0.pkl │ │ ├── ssod_0.050-off0.pkl │ │ └── ssod_0.100-off0.pkl │ │ └── gen4 │ │ ├── ssod_0.010-off0.pkl │ │ ├── ssod_0.020-off0.pkl │ │ ├── ssod_0.050-off0.pkl │ │ └── ssod_0.100-off0.pkl └── utils │ ├── augmentor.py │ ├── misc.py │ ├── representations.py │ ├── spatial.py │ ├── ssod_augmentor.py │ ├── stream_concat_datapipe.py │ ├── stream_sharded_datapipe.py │ └── types.py ├── datasets └── .gitkeep ├── docs ├── benchmark.md └── install.md ├── environment.yml ├── loggers └── utils.py ├── models ├── detection │ ├── __init__.py │ ├── recurrent_backbone │ │ ├── __init__.py │ │ ├── base.py │ │ └── maxvit_rnn.py │ ├── yolox │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── losses.py │ │ │ ├── network_blocks.py │ │ │ └── yolo_head.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── boxes.py │ │ │ └── compat.py │ └── yolox_extension │ │ └── models │ │ ├── __init__.py │ │ ├── build.py │ │ ├── detector.py │ │ └── yolo_pafpn.py └── layers │ ├── maxvit │ ├── __init__.py │ ├── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── activations_jit.py │ │ ├── activations_me.py │ │ ├── adaptive_avgmax_pool.py │ │ ├── attention_pool2d.py │ │ ├── blur_pool.py │ │ ├── bottleneck_attn.py │ │ ├── cbam.py │ │ ├── classifier.py │ │ ├── cond_conv2d.py │ │ ├── config.py │ │ ├── conv2d_same.py │ │ ├── conv_bn_act.py │ │ ├── create_act.py │ │ ├── create_attn.py │ │ ├── create_conv2d.py │ │ ├── create_norm.py │ │ ├── create_norm_act.py │ │ ├── drop.py │ │ ├── eca.py │ │ ├── evo_norm.py │ │ ├── fast_norm.py │ │ ├── filter_response_norm.py │ │ ├── gather_excite.py │ │ ├── global_context.py │ │ ├── halo_attn.py │ │ ├── helpers.py │ │ ├── inplace_abn.py │ │ ├── lambda_layer.py │ │ ├── linear.py │ │ ├── median_pool.py │ │ ├── mixed_conv2d.py │ │ ├── ml_decoder.py │ │ ├── mlp.py │ │ ├── non_local_attn.py │ │ ├── norm.py │ │ ├── norm_act.py │ │ ├── padding.py │ │ ├── patch_embed.py │ │ ├── pool2d_same.py │ │ ├── pos_embed.py │ │ ├── selective_kernel.py │ │ ├── separable_conv.py │ │ ├── space_to_depth.py │ │ ├── split_attn.py │ │ ├── 
split_batchnorm.py │ │ ├── squeeze_excite.py │ │ ├── std_conv.py │ │ ├── test_time_pool.py │ │ ├── trace_utils.py │ │ └── weight_init.py │ └── maxvit.py │ └── rnn.py ├── modules ├── __init__.py ├── data │ └── genx.py ├── detection.py ├── pseudo_labeler.py ├── tracking │ ├── __init__.py │ ├── linear.py │ ├── tracker.py │ └── utils.py └── utils │ ├── detection.py │ ├── fetch.py │ ├── ssod.py │ └── tta.py ├── predict.py ├── pretrained └── .gitkeep ├── train.py ├── utils ├── bbox.py ├── evaluation │ └── prophesee │ │ ├── __init__.py │ │ ├── evaluation.py │ │ ├── evaluator.py │ │ ├── io │ │ ├── __init__.py │ │ ├── box_filtering.py │ │ ├── box_loading.py │ │ ├── dat_events_tools.py │ │ ├── npy_events_tools.py │ │ └── psee_loader.py │ │ ├── metrics │ │ ├── __init__.py │ │ └── coco_eval.py │ │ └── visualize │ │ ├── __init__.py │ │ └── vis_utils.py ├── helpers.py ├── padding.py ├── preprocessing.py └── timers.py ├── val.py ├── val_dst.py └── vis_pred.py /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained 2 | pretrained/ 3 | checkpoint/ 4 | checkpoints/ 5 | datasets/ 6 | vis/ 7 | old_vis/ 8 | outputs/ 9 | validation_logs/ 10 | sbatch/ 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mathias Gehrig & Ziyi Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LEOD 2 | 3 | This is the official Pytorch implementation for our CVPR 2024 paper: 4 | 5 | [**LEOD: Label-Efficient Object Detection for Event Cameras**](https://arxiv.org/abs/2311.17286)
6 | [Ziyi Wu](https://wuziyi616.github.io/), 7 | [Mathias Gehrig](https://magehrig.github.io/), 8 | Qing Lyu, 9 | Xudong Liu, 10 | [Igor Gilitschenski](https://tisl.cs.utoronto.ca/author/igor-gilitschenski/)
11 | _CVPR'24 | 12 | [GitHub](https://github.com/Wuziyi616/LEOD?tab=readme-ov-file#leod) | 13 | [arXiv](https://arxiv.org/abs/2311.17286)_ 14 | 15 |

16 | 17 |

18 | 19 | ## TL;DR 20 | 21 | [Event cameras](https://tub-rip.github.io/eventvision2023/#null) are bio-inspired low-latency sensors, which hold great potential for safety-critical applications such as object detection in self-driving. 22 | Due to the high temporal resolution (>1000 FPS) of event data, existing datasets are annotated at a low frame rate (e.g., 4 FPS). 23 | As a result, models are only trained on these annotated frames, leading to sub-optimal performance and slow convergence. 24 | In this paper, we tackle this problem from the perspective of weakly-/semi-supervised learning. 25 | We design a novel self-training framework that pseudo-labels unannotated events with reliable model predictions, achieving SOTA performance on the two largest event camera detection benchmarks. 26 | 27 |

28 | 29 |

30 | 31 | ## Install 32 | 33 | This codebase builds upon [RVT](https://github.com/uzh-rpg/RVT). 34 | Please refer to [install.md](./docs/install.md) for detailed instructions. 35 | 36 | ## Experiments 37 | 38 | **This codebase is tailored to [Slurm](https://slurm.schedmd.com/documentation.html) GPU clusters with a preemption mechanism.** 39 | There are some functions in the code (e.g., auto-detecting and loading previous checkpoints) that you might not need. 40 | Please go through all fields marked with `TODO` in `train.py` in case there is any conflict with your environment. 41 | To reproduce the results in the paper, please refer to [benchmark.md](docs/benchmark.md). 42 | 43 | ## Citation 44 | 45 | Please cite our paper if you find it useful in your research: 46 | ```bibtex 47 | @inproceedings{wu2024leod, 48 | title={LEOD: Label-Efficient Object Detection for Event Cameras}, 49 | author={Wu, Ziyi and Gehrig, Mathias and Lyu, Qing and Liu, Xudong and Gilitschenski, Igor}, 50 | booktitle={CVPR}, 51 | year={2024} 52 | } 53 | ``` 54 | 55 | ## Acknowledgement 56 | 57 | We thank the authors of [RVT](https://github.com/uzh-rpg/RVT), [SORT](https://github.com/abewley/sort), [Soft Teacher](https://github.com/microsoft/SoftTeacher), [Unbiased Teacher](https://github.com/facebookresearch/unbiased-teacher) and all the packages we use in this repo for open-sourcing their wonderful work. 58 | 59 | ## License 60 | 61 | LEOD is released under the MIT License. See the LICENSE file for more details. 62 | 63 | ## Contact 64 | 65 | If you have any questions about the code, please contact Ziyi Wu at dazitu616@gmail.com 66 | -------------------------------------------------------------------------------- /assets/LEOD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/assets/LEOD.png -------------------------------------------------------------------------------- /assets/det-results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/assets/det-results.png -------------------------------------------------------------------------------- /assets/gen1-video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/assets/gen1-video.mp4 -------------------------------------------------------------------------------- /callbacks/custom.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from omegaconf import DictConfig 3 | from pytorch_lightning.callbacks import Callback 4 | from pytorch_lightning.callbacks import ModelCheckpoint 5 | 6 | from callbacks.detection import DetectionVizCallback 7 | 8 | 9 | def get_ckpt_callback(config: DictConfig, ckpt_dir: str = None) -> ModelCheckpoint: 10 | prefix = 'val' 11 | metric = 'AP' 12 | mode = 'max' 13 | ckpt_callback_monitor = prefix + '/' + metric 14 | filename_monitor_str = prefix + '_' + metric 15 | 16 | ckpt_filename = 'epoch_{epoch:03d}-step_{step}-' + filename_monitor_str + '_{' + ckpt_callback_monitor + ':.4f}' 17 | every_n_min = config.logging.ckpt_every_min 18 | cktp_callback = ModelCheckpoint( 19 | dirpath=ckpt_dir, 20 | monitor=ckpt_callback_monitor, 21 | filename=ckpt_filename, 22 | auto_insert_metric_name=False, # because backslash would
create a directory 23 | save_top_k=2, # in case the best one is broken 24 | mode=mode, 25 | train_time_interval=timedelta(minutes=every_n_min), 26 | save_last=True, 27 | verbose=True) 28 | cktp_callback.CHECKPOINT_NAME_LAST = 'last_epoch_{epoch:03d}-step_{step}' 29 | return cktp_callback 30 | 31 | 32 | def get_viz_callback(config: DictConfig) -> Callback: 33 | if hasattr(config.model, 'pseudo_label'): 34 | prefixs = ['', 'pseudo_'] 35 | else: 36 | prefixs = [''] 37 | return DetectionVizCallback(config=config, prefixs=prefixs) 38 | -------------------------------------------------------------------------------- /callbacks/detection.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | from typing import Any, List 3 | 4 | import torch 5 | from einops import rearrange 6 | from omegaconf import DictConfig 7 | from pytorch_lightning.loggers.wandb import WandbLogger 8 | 9 | from data.utils.types import ObjDetOutput 10 | from utils.evaluation.prophesee.visualize.vis_utils import get_labelmap, draw_bboxes 11 | from .viz_base import VizCallbackBase 12 | 13 | 14 | class DetectionVizEnum(Enum): 15 | EV_IMG = auto() 16 | LABEL_IMG_PROPH = auto() 17 | PRED_IMG_PROPH = auto() 18 | 19 | 20 | class DetectionVizCallback(VizCallbackBase): 21 | """Visualize predicted and GT bbox on event-converted RGB frames.""" 22 | 23 | def __init__(self, config: DictConfig, prefixs: List[str] = ['']): 24 | super().__init__(config=config, buffer_entries=DetectionVizEnum) 25 | 26 | self.label_map = get_labelmap(dst_name=config.dataset.name) 27 | self.prefixs = prefixs 28 | 29 | def on_train_batch_end_custom(self, *args, **kwargs) -> None: 30 | for prefix in self.prefixs: 31 | self._on_train_batch_end_custom(*args, **kwargs, prefix=prefix) 32 | 33 | def _on_train_batch_end_custom(self, 34 | logger: WandbLogger, 35 | outputs: Any, 36 | batch: Any, 37 | log_n_samples: int, 38 | global_step: int, 39 | prefix: str) -> None: 40 | """May need to load images from different labeled data.""" 41 | if outputs is None: 42 | # If we tried to skip the training step (not supported in DDP in PL, atm) 43 | return 44 | if f'{prefix}{ObjDetOutput.EV_REPR}' not in outputs: 45 | return 46 | ev_tensors = outputs[f'{prefix}{ObjDetOutput.EV_REPR}'] 47 | num_samples = len(ev_tensors) 48 | assert num_samples > 0 49 | log_n_samples = min(num_samples, log_n_samples) 50 | 51 | merged_img = [] 52 | captions = [] 53 | start_idx = num_samples - 1 54 | end_idx = start_idx - log_n_samples 55 | # for sample_idx in range(log_n_samples): 56 | for sample_idx in range(start_idx, end_idx, -1): 57 | ev_img = self.ev_repr_to_img(ev_tensors[sample_idx].cpu().numpy()) 58 | 59 | predictions_proph = outputs[f'{prefix}{ObjDetOutput.PRED_PROPH}'][sample_idx] 60 | prediction_img = ev_img.copy() 61 | draw_bboxes(prediction_img, predictions_proph, labelmap=self.label_map) 62 | 63 | labels_proph = outputs[f'{prefix}{ObjDetOutput.LABELS_PROPH}'][sample_idx] 64 | label_img = ev_img.copy() 65 | draw_bboxes(label_img, labels_proph, labelmap=self.label_map) 66 | 67 | merged_img.append(rearrange([prediction_img, label_img], 'pl H W C -> (pl H) W C', pl=2, C=3)) 68 | captions.append(f'sample_{sample_idx}') 69 | 70 | logger.log_image(key=f'train/{prefix}predictions', # PL's native wandb 71 | images=merged_img, 72 | caption=captions, 73 | step=global_step) 74 | 75 | def on_validation_batch_end_custom(self, batch: Any, outputs: Any) -> None: 76 | """Val is not affected by pseudo-label training.""" 77 | if 
outputs[ObjDetOutput.SKIP_VIZ]: 78 | return 79 | ev_tensor = outputs[ObjDetOutput.EV_REPR] 80 | assert isinstance(ev_tensor, torch.Tensor) 81 | 82 | ev_img = self.ev_repr_to_img(ev_tensor.cpu().numpy()) 83 | 84 | predictions_proph = outputs[ObjDetOutput.PRED_PROPH] 85 | prediction_img = ev_img.copy() 86 | draw_bboxes(prediction_img, predictions_proph, labelmap=self.label_map) 87 | self.add_to_buffer(DetectionVizEnum.PRED_IMG_PROPH, prediction_img) 88 | 89 | labels_proph = outputs[ObjDetOutput.LABELS_PROPH] 90 | label_img = ev_img.copy() 91 | draw_bboxes(label_img, labels_proph, labelmap=self.label_map) 92 | self.add_to_buffer(DetectionVizEnum.LABEL_IMG_PROPH, label_img) 93 | 94 | def on_validation_epoch_end_custom(self, logger: WandbLogger): 95 | pred_imgs = self.get_from_buffer(DetectionVizEnum.PRED_IMG_PROPH) 96 | label_imgs = self.get_from_buffer(DetectionVizEnum.LABEL_IMG_PROPH) 97 | assert len(pred_imgs) == len(label_imgs) 98 | merged_img = [] 99 | captions = [] 100 | for idx, (pred_img, label_img) in enumerate(zip(pred_imgs, label_imgs)): 101 | merged_img.append(rearrange([pred_img, label_img], 'pl H W C -> (pl H) W C', pl=2, C=3)) 102 | captions.append(f'sample_{idx}') 103 | 104 | logger.log_image(key='val/predictions', 105 | images=merged_img, 106 | caption=captions) 107 | -------------------------------------------------------------------------------- /callbacks/gradflow.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.callbacks import Callback 5 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 6 | 7 | from callbacks.utils.visualization import get_grad_flow_figure 8 | 9 | 10 | class GradFlowLogCallback(Callback): 11 | def __init__(self, log_every_n_train_steps: int): 12 | super().__init__() 13 | assert log_every_n_train_steps > 0 14 | self.log_every_n_train_steps = log_every_n_train_steps 15 | 16 | @rank_zero_only 17 | def on_before_zero_grad(self, trainer: pl.Trainer, pl_module: pl.LightningModule, optimizer: Any) -> None: 18 | # NOTE: before we had this in the on_after_backward callback. 19 | # This was fine for fp32 but showed unscaled gradients for fp16. 20 | # That is why we move it to on_before_zero_grad where gradients are scaled. 21 | global_step = trainer.global_step 22 | if global_step % self.log_every_n_train_steps != 0: 23 | return 24 | named_parameters = pl_module.named_parameters() 25 | figure = get_grad_flow_figure(named_parameters) 26 | trainer.logger.log_metrics({'train/gradients': figure}, step=global_step) 27 | -------------------------------------------------------------------------------- /callbacks/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import plotly.express as px 3 | 4 | 5 | def get_grad_flow_figure(named_params): 6 | """Creates figure to visualize gradients flowing through different layers in the net during training. 7 | Can be used for checking for possible gradient vanishing / exploding problems. 
8 | Usage: Use this function after loss.backwards() 9 | """ 10 | data_dict = { 11 | 'name': list(), 12 | 'grad_abs': list(), 13 | } 14 | for name, param in named_params: 15 | if param.requires_grad and param.grad is not None: 16 | grad_abs = param.grad.abs() 17 | data_dict['name'].append(name) 18 | data_dict['grad_abs'].append(grad_abs.mean().cpu().item()) 19 | 20 | data_frame = pd.DataFrame.from_dict(data_dict) 21 | 22 | fig = px.bar(data_frame, x='name', y='grad_abs') 23 | return fig 24 | -------------------------------------------------------------------------------- /config/dataset/base.yaml: -------------------------------------------------------------------------------- 1 | name: ??? 2 | path: ??? 3 | ssod: False # for semi-supervised object detection, see `modifier.py` 4 | ratio: -1 # sub-sample the labeling frequency of each event sequence 5 | train_ratio: -1 # sub-sample the training set 6 | val_ratio: -1 # to accelerate the val process, currently takes ~20min 7 | test_ratio: -1 # to accelerate the test process, currently takes ~20min 8 | only_load_labels: False # only load the label, not the events 9 | reverse_event_order: False # reverse the temporal order of events 10 | train: 11 | sampling: 'mixed' # ('random', 'stream', 'mixed') 12 | random: 13 | weighted_sampling: False 14 | mixed: 15 | w_stream: 1 16 | w_random: 1 17 | eval: 18 | sampling: 'stream' 19 | data_augmentation: 20 | tflip_offset: -1 21 | random: 22 | prob_hflip: 0.5 23 | prob_tflip: 0 24 | rotate: 25 | prob: 0 26 | min_angle_deg: 2 27 | max_angle_deg: 6 28 | zoom: 29 | prob: 0.8 30 | zoom_in: 31 | weight: 8 32 | factor: 33 | min: 1 34 | max: 1.5 35 | zoom_out: 36 | weight: 2 37 | factor: 38 | min: 1 39 | max: 1.2 40 | stream: 41 | start_from_zero: False 42 | prob_hflip: 0.5 43 | prob_tflip: 0 44 | rotate: 45 | prob: 0 46 | min_angle_deg: 2 47 | max_angle_deg: 6 48 | zoom: 49 | prob: 0.5 50 | zoom_out: 51 | factor: 52 | min: 1 53 | max: 1.2 54 | -------------------------------------------------------------------------------- /config/dataset/gen1-tflip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | # time-flip data augmentation 5 | data_augmentation: 6 | random: 7 | prob_tflip: 0.5 8 | stream: 9 | prob_tflip: 0.5 10 | -------------------------------------------------------------------------------- /config/dataset/gen1.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: gen1 5 | path: ./datasets/gen1/ 6 | ev_repr_name: 'stacked_histogram_dt=50_nbins=10' 7 | sequence_length: 21 8 | resolution_hw: [240, 304] 9 | downsample_by_factor_2: False 10 | only_load_end_labels: False 11 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.01_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | train_ratio: 0.01 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.01_ss-1round.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | ratio: -1 # we now use all the pesudo labels 5 | val_ratio: 0.5 6 | 7 | path: ./datasets/pseudo_gen1/gen1x0.01_ss-1round 8 | 9 | # time-flip data augmentation 10 | 
data_augmentation: 11 | random: 12 | prob_tflip: 0.5 13 | stream: 14 | prob_tflip: 0.5 15 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.01_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | ratio: 0.01 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.02_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | train_ratio: 0.02 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.02_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | ratio: 0.02 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.05_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | train_ratio: 0.05 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.05_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | ratio: 0.05 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.1_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | train_ratio: 0.1 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen1x0.1_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen1 3 | 4 | ratio: 0.1 5 | val_ratio: 0.5 6 | 7 | # time-flip data augmentation 8 | data_augmentation: 9 | random: 10 | prob_tflip: 0.5 11 | stream: 12 | prob_tflip: 0.5 13 | -------------------------------------------------------------------------------- /config/dataset/gen4-tflip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | sequence_length: 5 5 | 6 | # time-flip data augmentation 7 | data_augmentation: 8 | random: 9 | prob_tflip: 0.5 10 | stream: 11 | prob_tflip: 0.5 12 | -------------------------------------------------------------------------------- /config/dataset/gen4.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: gen4 5 | path: ./datasets/gen4/ 6 | ev_repr_name: 
'stacked_histogram_dt=50_nbins=10' 7 | sequence_length: 5 8 | resolution_hw: [720, 1280] 9 | downsample_by_factor_2: True 10 | only_load_end_labels: False 11 | 12 | data_augmentation: 13 | tflip_offset: -2 # the GT labels are not well-aligned with the event frames 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.01_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | train_ratio: 0.01 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.01_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | ratio: 0.01 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.02_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | train_ratio: 0.02 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.02_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | ratio: 0.02 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.05_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | train_ratio: 0.05 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.05_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | ratio: 0.05 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.1_seq.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | train_ratio: 0.1 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | 
random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/dataset/gen4x0.1_ss.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - gen4 3 | 4 | ratio: 0.1 5 | val_ratio: 0.5 6 | sequence_length: 10 # longer is better for the pre-training (burn-in) stage 7 | 8 | # time-flip data augmentation 9 | data_augmentation: 10 | random: 11 | prob_tflip: 0.5 12 | stream: 13 | prob_tflip: 0.5 14 | -------------------------------------------------------------------------------- /config/experiment/gen1/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default 4 | 5 | model: 6 | backbone: 7 | embed_dim: 64 8 | fpn: 9 | depth: 0.67 10 | -------------------------------------------------------------------------------- /config/experiment/gen1/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model/maxvit_yolox: default 4 | 5 | training: 6 | precision: 16 7 | max_epochs: 10000 8 | max_steps: 400000 9 | learning_rate: 0.0002 10 | lr_scheduler: 11 | use: True 12 | total_steps: ${..max_steps} 13 | pct_start: 0.005 14 | div_factor: 20 15 | final_div_factor: 10000 16 | batch_size: 17 | train: 8 18 | eval: 8 19 | hardware: 20 | num_workers: 21 | train: 8 22 | eval: 8 23 | dataset: 24 | train: 25 | sampling: 'mixed' 26 | random: 27 | weighted_sampling: False 28 | mixed: 29 | w_stream: 1 30 | w_random: 1 31 | eval: 32 | sampling: 'stream' 33 | model: 34 | backbone: 35 | partition_split_32: 1 36 | -------------------------------------------------------------------------------- /config/experiment/gen1/small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default 4 | 5 | model: 6 | backbone: 7 | embed_dim: 48 8 | stage: 9 | attention: 10 | dim_head: 24 11 | fpn: 12 | depth: 0.33 13 | -------------------------------------------------------------------------------- /config/experiment/gen1/tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default 4 | 5 | model: 6 | backbone: 7 | embed_dim: 32 8 | fpn: 9 | depth: 0.33 10 | -------------------------------------------------------------------------------- /config/experiment/gen4/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default 4 | 5 | model: 6 | backbone: 7 | embed_dim: 64 8 | fpn: 9 | depth: 0.67 10 | -------------------------------------------------------------------------------- /config/experiment/gen4/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /model/maxvit_yolox: default 4 | 5 | training: 6 | precision: 16 7 | max_epochs: 10000 8 | max_steps: 400000 9 | learning_rate: 0.000346 10 | lr_scheduler: 11 | use: True 12 | total_steps: ${..max_steps} 13 | pct_start: 0.005 14 | div_factor: 20 15 | final_div_factor: 10000 16 | batch_size: 17 | train: 12 18 | eval: 12 19 | hardware: 20 | num_workers: 21 | train: 8 22 | eval: 4 23 | dataset: 24 | train: 25 | sampling: 'mixed' 26 | random: 27 | weighted_sampling: False 28 | mixed: 29 | w_stream: 1 30 | w_random: 1 31 | eval: 32 | sampling: 'stream' 33 | 
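The `config/experiment/gen1` and `config/experiment/gen4` files above are Hydra config groups: `train.py` composes `config/general.yaml` with a `dataset` group, a `model` group, and one of these experiment presets. The snippet below is a minimal sketch (not a file from this repo) of how such a composition could be reproduced with Hydra's compose API to inspect the merged settings; the chosen overrides (the `gen1x0.01_ss` split, the `small` preset, the dataset path) are illustrative assumptions, and the exact training commands are documented in `docs/benchmark.md`.

```python
# Illustrative sketch only -- not part of the repo.
# Compose the Hydra config groups under ./config the way train.py would,
# so the merged training configuration can be inspected without launching a run.
from hydra import compose, initialize

with initialize(version_base=None, config_path="config"):
    cfg = compose(
        config_name="train",
        overrides=[
            "dataset=gen1x0.01_ss",           # assumed choice: 1% labeled Gen1 split (config/dataset/)
            "model=rnndet",                   # pre-training model config (config/model/rnndet.yaml)
            "+experiment/gen1=small",         # assumed choice: small backbone preset (config/experiment/gen1/small.yaml)
            "dataset.path=./datasets/gen1/",  # assumed dataset location
            "hardware.gpus=0",
        ],
    )
    # Values such as batch size and learning rate come from the experiment preset
    # layered on top of config/general.yaml.
    print(cfg.batch_size.train, cfg.training.learning_rate)
```

The same pattern would apply to the gen4 presets above (e.g. `dataset=gen4x0.01_ss +experiment/gen4=small`).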
-------------------------------------------------------------------------------- /config/experiment/gen4/small.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default 4 | 5 | model: 6 | backbone: 7 | embed_dim: 48 8 | stage: 9 | attention: 10 | dim_head: 24 11 | fpn: 12 | depth: 0.33 13 | -------------------------------------------------------------------------------- /config/experiment/gen4/tiny.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - default 4 | 5 | model: 6 | backbone: 7 | embed_dim: 32 8 | fpn: 9 | depth: 0.33 10 | -------------------------------------------------------------------------------- /config/general.yaml: -------------------------------------------------------------------------------- 1 | reproduce: 2 | seed_everything: null # Union[int, null] 3 | deterministic_flag: False # Must be true for fully deterministic behaviour (slows down training) 4 | benchmark: True # Should be set to false for fully deterministic behaviour. Could potentially speed up training. 5 | training: 6 | precision: 16 7 | max_epochs: 10000 8 | max_steps: 400000 9 | learning_rate: 0.0002 10 | weight_decay: 0 11 | gradient_clip_val: 1.0 12 | limit_train_batches: 1.0 13 | lr_scheduler: 14 | use: True 15 | total_steps: ${..max_steps} 16 | pct_start: 0.005 17 | div_factor: 25 # init_lr = max_lr / div_factor 18 | final_div_factor: 10000 # final_lr = max_lr / final_div_factor (this is different from Pytorch' OneCycleLR param) 19 | validation: 20 | limit_val_batches: 1.0 21 | val_check_interval: 20000 # Optional[int] 22 | check_val_every_n_epoch: null # Optional[int] 23 | batch_size: 24 | train: 8 25 | eval: 8 26 | hardware: 27 | num_workers: 28 | train: 8 29 | eval: 8 30 | gpus: 0 # Either a single integer (e.g. 3) or a list of integers (e.g. [3,5,6]) 31 | dist_backend: "nccl" 32 | logging: 33 | ckpt_every_min: 18 # checkpoint every x minutes 34 | train: 35 | metrics: 36 | compute: false 37 | detection_metrics_every_n_steps: null # Optional[int] -> null: every train epoch, int: every N steps 38 | log_model_every_n_steps: 5000 39 | log_every_n_steps: 100 40 | high_dim: 41 | enable: True 42 | every_n_steps: 5000 43 | n_samples: 4 44 | validation: 45 | high_dim: 46 | enable: True 47 | every_n_epochs: 1 48 | n_samples: 8 49 | wandb: 50 | wandb_name: null # name of the run 51 | wandb_id: null # name of the run 52 | wandb_runpath: null # WandB run path. E.g. USERNAME/PROJECTNAME/1grv5kg6 53 | group_name: null # Specify group name of the run 54 | project_name: RVT 55 | suffix: "" # full dup run name suffix 56 | weight: "" # only resume weight 57 | checkpoint: "" # resume weight + training state 58 | pretrain_teacher_checkpoint: "" # pre-trained weight of the teacher model 59 | pretrain_student_checkpoint: "" # pre-trained weight of the student model 60 | -------------------------------------------------------------------------------- /config/model/base.yaml: -------------------------------------------------------------------------------- 1 | name: ??? 
2 | -------------------------------------------------------------------------------- /config/model/maxvit_yolox/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - override /model: rnndet 4 | 5 | model: 6 | backbone: 7 | name: MaxViTRNN 8 | compile: 9 | enable: False 10 | args: 11 | mode: reduce-overhead 12 | input_channels: 20 13 | enable_masking: False 14 | partition_split_32: 2 15 | embed_dim: 64 16 | dim_multiplier: [1, 2, 4, 8] 17 | num_blocks: [1, 1, 1, 1] 18 | T_max_chrono_init: [4, 8, 16, 32] 19 | stem: 20 | patch_size: 4 21 | stage: 22 | downsample: 23 | type: patch 24 | overlap: True 25 | norm_affine: True 26 | attention: 27 | use_torch_mha: False 28 | partition_size: ??? 29 | dim_head: 32 30 | attention_bias: True 31 | mlp_activation: gelu 32 | mlp_gated: False 33 | mlp_bias: True 34 | mlp_ratio: 4 35 | drop_mlp: 0 36 | drop_path: 0 37 | ls_init_value: 1e-5 38 | lstm: 39 | dws_conv: False 40 | dws_conv_only_hidden: True 41 | dws_conv_kernel_size: 3 42 | drop_cell_update: 0 43 | fpn: 44 | name: PAFPN 45 | compile: 46 | enable: False 47 | args: 48 | mode: reduce-overhead 49 | depth: 0.67 # round(depth * 3) == num bottleneck blocks 50 | # stage 1 is the first and len(num_layers) is the last 51 | in_stages: [2, 3, 4] 52 | depthwise: False 53 | act: "silu" 54 | head: 55 | name: YoloX 56 | compile: 57 | enable: False 58 | args: 59 | mode: reduce-overhead 60 | depthwise: False 61 | act: "silu" 62 | postprocess: 63 | confidence_threshold: 0.1 64 | nms_threshold: 0.45 65 | -------------------------------------------------------------------------------- /config/model/pseudo_labeler-gen4-wsod.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: pseudo_labeler 5 | backbone: 6 | name: ??? 7 | fpn: 8 | name: ??? 9 | head: 10 | name: ??? 11 | postprocess: 12 | confidence_threshold: 0.1 13 | nms_threshold: 0.45 14 | # SSOD-related configs 15 | pseudo_label: 16 | # if a sub-seq `is_first_sample`, we skip predictions at the first `skip_first_t` timesteps 17 | # because they don't have enough history information --> not accurate 18 | skip_first_t: 0 19 | # thresholds for filtering pseudo labels 20 | obj_thresh: [0.6, 0.5] # thresholds for each category 21 | cls_thresh: [0.6, 0.5] # gen1: ('car', 'ped'); gen4: ('ped', 'cyc', 'car') 22 | # by default we will use the same thresholds for ped and cyc 23 | # post-process using offline tracker, ignore bbox with short track length 24 | min_track_len: 6 25 | track_method: 'forward or backward' 26 | # 'forward or backward' (short in both directions --> ignore) 27 | inpaint: True # hallucinate bbox at missing frames 28 | ignore_label: 1024 29 | -------------------------------------------------------------------------------- /config/model/pseudo_labeler.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - base 3 | 4 | name: pseudo_labeler 5 | backbone: 6 | name: ??? 7 | fpn: 8 | name: ??? 9 | head: 10 | name: ??? 
11 | postprocess: 12 | confidence_threshold: 0.1 13 | nms_threshold: 0.45 14 | # SSOD-related configs 15 | pseudo_label: 16 | # if a sub-seq `is_first_sample`, we skip predictions at the first `skip_first_t` timesteps 17 | # because they don't have enough history information --> not accurate 18 | skip_first_t: 0 19 | # thresholds for filtering pseudo labels 20 | obj_thresh: [0.6, 0.3] # thresholds for each category 21 | cls_thresh: [0.6, 0.3] # gen1: ('car', 'ped'); gen4: ('ped', 'cyc', 'car') 22 | # by default we will use the same thresholds for ped and cyc 23 | # post-process using offline tracker, ignore bbox with short track length 24 | min_track_len: 6 25 | track_method: 'forward or backward' 26 | # 'forward or backward' (short in both directions --> ignore) 27 | inpaint: True # hallucinate bbox at missing frames 28 | ignore_label: 1024 29 | -------------------------------------------------------------------------------- /config/model/rnndet-soft-gen4-wsod.yaml: -------------------------------------------------------------------------------- 1 | # This config is for self-training RVT in 1Mpx WSOD 2 | 3 | defaults: 4 | - base 5 | 6 | name: rnndet 7 | backbone: 8 | name: ??? 9 | fpn: 10 | name: ??? 11 | head: 12 | name: ??? 13 | obj_focal_loss: False # FocalLoss or BCE on objectness score 14 | bbox_loss_weighting: '' # 'obj' or 'cls' or 'objxcls' or '' 15 | # support further transformations, e.g. 'cls-w**2' is using cls**2 as weights 16 | ignore_bbox_thresh: [0.7, 0.55] # ignore bbox with obj/cls lower than this 17 | ignore_label: 1024 # ignore during training 18 | # don't suppress BG pixels with the highest k% confidence scores 19 | ignore_bg_k: 0 20 | postprocess: 21 | confidence_threshold: 0.1 22 | nms_threshold: 0.45 23 | use_label_every: 1 24 | ignore_image: False # ignore images where all bbox are with `ignore_label` 25 | -------------------------------------------------------------------------------- /config/model/rnndet-soft.yaml: -------------------------------------------------------------------------------- 1 | # This config is for self-training RVT in all settings except 1Mpx WSOD 2 | 3 | defaults: 4 | - base 5 | 6 | name: rnndet 7 | backbone: 8 | name: ??? 9 | fpn: 10 | name: ??? 11 | head: 12 | name: ??? 13 | obj_focal_loss: False # FocalLoss or BCE on objectness score 14 | bbox_loss_weighting: '' # 'obj' or 'cls' or 'objxcls' or '' 15 | # support further transformations, e.g. 'cls-w**2' is using cls**2 as weights 16 | ignore_bbox_thresh: [0.7, 0.35] # ignore bbox with obj/cls lower than this 17 | ignore_label: 1024 # ignore during training 18 | # don't suppress BG pixels with the highest k% confidence scores 19 | ignore_bg_k: 0 20 | postprocess: 21 | confidence_threshold: 0.1 22 | nms_threshold: 0.45 23 | use_label_every: 1 24 | ignore_image: False # ignore images where all bbox are with `ignore_label` 25 | -------------------------------------------------------------------------------- /config/model/rnndet.yaml: -------------------------------------------------------------------------------- 1 | # This config is for pre-training RVT on limited annotated data 2 | 3 | defaults: 4 | - base 5 | 6 | name: rnndet 7 | backbone: 8 | name: ??? 9 | fpn: 10 | name: ??? 11 | head: 12 | name: ??? 13 | obj_focal_loss: False # FocalLoss or BCE on objectness score 14 | bbox_loss_weighting: '' # 'obj' or 'cls' or 'objxcls' or '' 15 | # support further transformations, e.g. 
'cls-w**2' is using cls**2 as weights 16 | ignore_bbox_thresh: null # ignore bbox with obj/cls lower than this 17 | ignore_label: 1024 # ignore during training 18 | # don't suppress BG pixels with the highest k% confidence scores 19 | ignore_bg_k: 0 20 | postprocess: 21 | confidence_threshold: 0.1 22 | nms_threshold: 0.45 23 | use_label_every: 1 24 | ignore_image: False # ignore images where all bbox are with `ignore_label` 25 | -------------------------------------------------------------------------------- /config/predict.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: ??? 3 | - model: pseudo_labeler 4 | - _self_ 5 | 6 | is_train: False 7 | checkpoint: ??? 8 | save_dir: "" # dir to save the generated dataset 9 | hardware: 10 | num_workers: 11 | eval: 8 12 | gpus: 0 # GPU idx (multi-gpu not supported for validation) 13 | batch_size: 14 | eval: 16 15 | training: 16 | precision: 16 17 | tta: 18 | enable: False 19 | hflip: True 20 | tflip: True 21 | use_gt: True # take GT labels on labeled frames, or still use pseudo labels 22 | -------------------------------------------------------------------------------- /config/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - general 3 | - dataset: ??? 4 | - model: rnndet 5 | - optional model/dataset: ${model}_${dataset} 6 | 7 | is_train: True 8 | -------------------------------------------------------------------------------- /config/val.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: ??? 3 | - model: rnndet 4 | - _self_ 5 | 6 | is_train: False 7 | checkpoint: ??? 8 | use_test_set: False 9 | reverse: False 10 | hardware: 11 | num_workers: 12 | eval: 8 13 | gpus: 0 # GPU idx (multi-gpu not supported for validation) 14 | batch_size: 15 | eval: 16 16 | training: 17 | precision: 16 18 | tta: 19 | enable: False 20 | hflip: True 21 | tflip: True 22 | -------------------------------------------------------------------------------- /config/vis.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: ??? 3 | - model: rnndet 4 | - _self_ 5 | 6 | is_train: False 7 | checkpoint: ??? 
8 | num_video: 5 9 | reverse: False 10 | hardware: 11 | num_workers: 12 | eval: 0 13 | gpus: 0 # GPU idx (multi-gpu not supported for validation) 14 | batch_size: 15 | eval: 1 16 | training: 17 | precision: 16 18 | tta: 19 | enable: False 20 | hflip: False 21 | tflip: False 22 | -------------------------------------------------------------------------------- /data/genx_utils/collate.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Any, Callable, Dict, Optional, Type, Tuple, Union 3 | 4 | import torch 5 | 6 | from data.genx_utils.collate_from_pytorch import collate, default_collate_fn_map 7 | from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels 8 | from data.utils.augmentor import AugmentationState 9 | from modules.utils.detection import WORKER_ID_KEY, DATA_KEY 10 | 11 | 12 | def collate_object_labels(batch, *, collate_fn_map: Optional[ 13 | Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 14 | return batch 15 | 16 | 17 | def collate_sparsely_batched_object_labels(batch, *, collate_fn_map: Optional[ 18 | Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 19 | return SparselyBatchedObjectLabels.transpose_list(batch) 20 | 21 | 22 | def collate_augm_state(batch, *, collate_fn_map: Optional[ 23 | Dict[Union[Type, Tuple[Type, ...]], Callable]] = None): 24 | """Collates a batch of AugmentationState objects into a dictionary where 25 | each field of AugmentationState maps to a list of values. 26 | 27 | Returns: { 28 | 'h_flip': {'active': `B`-len list of bool}, 29 | 'zoom_out': { 30 | 'active': [`B`-len list of bool], 31 | 'x0': [`B`-len list of int], 32 | 'y0': [`B`-len list of int], 33 | 'factor': [`B`-len list of float], 34 | }, 35 | 'zoom_in': { 36 | 'active': [`B`-len list of bool], 37 | 'x0': [`B`-len list of int], 38 | 'y0': [`B`-len list of int], 39 | 'factor': [`B`-len list of float], 40 | }, 41 | 'rotation': { 42 | 'active': [`B`-len list of bool], 43 | 'angle_deg': [`B`-len list of float], 44 | }, 45 | } 46 | """ 47 | return AugmentationState.collate_augm_state(batch) 48 | 49 | 50 | custom_collate_fn_map = copy.deepcopy(default_collate_fn_map) 51 | custom_collate_fn_map[ObjectLabels] = collate_object_labels 52 | custom_collate_fn_map[SparselyBatchedObjectLabels] = collate_sparsely_batched_object_labels 53 | custom_collate_fn_map[AugmentationState] = collate_augm_state 54 | 55 | 56 | def custom_collate(batch: Any): 57 | return collate(batch, collate_fn_map=custom_collate_fn_map) 58 | 59 | 60 | def custom_collate_rnd(batch: Any): 61 | samples = batch 62 | # NOTE: We do not really need the worker id for map style datasets (rnd) but we still provide the id for consistency 63 | worker_info = torch.utils.data.get_worker_info() 64 | local_worker_id = 0 if worker_info is None else worker_info.id 65 | return { 66 | DATA_KEY: custom_collate(samples), 67 | WORKER_ID_KEY: local_worker_id, 68 | } 69 | 70 | 71 | def custom_collate_streaming(batch: Any): 72 | """We assume that we receive a batch collected by a worker of our streaming datapipe 73 | """ 74 | samples = batch[0] 75 | worker_id = batch[1] 76 | assert isinstance(worker_id, int) 77 | return { 78 | DATA_KEY: custom_collate(samples), 79 | WORKER_ID_KEY: worker_id, 80 | } 81 | -------------------------------------------------------------------------------- /data/genx_utils/splits/gen1/ssod_0.010-off0.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.010-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen1/ssod_0.020-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.020-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen1/ssod_0.050-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.050-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen1/ssod_0.100-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen1/ssod_0.100-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen4/ssod_0.010-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.010-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen4/ssod_0.020-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.020-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen4/ssod_0.050-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.050-off0.pkl -------------------------------------------------------------------------------- /data/genx_utils/splits/gen4/ssod_0.100-off0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/data/genx_utils/splits/gen4/ssod_0.100-off0.pkl -------------------------------------------------------------------------------- /data/utils/misc.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | 3 | import os 4 | 5 | import h5py 6 | import numpy as np 7 | 8 | from data.genx_utils.labels import ObjectLabelFactory, ObjectLabels 9 | 10 | 11 | def get_labels_npz_fn(seq_dir: str) -> str: 12 | """Get labels npz file name.""" 13 | # seq_dir: path/to/dataset/train/18-03-29_13-15-02_5_605 14 | # labels_npz_fn: .../18-03-29_13-15-02_5_605/labels_v2/labels.npz 15 | labels_npz_fn = os.path.join(seq_dir, 'labels_v2/labels.npz') 16 | return labels_npz_fn 17 | 18 | 19 | def read_npz_labels(label_fn: str) -> Tuple[np.ndarray, np.ndarray]: 20 | """Read labels from npz file.""" 21 | # if it's just .../18-03-29_13-15-02_5_605, we get the true npz first 22 | if 'labels_v2' not in label_fn: 23 | label_fn = get_labels_npz_fn(label_fn) 24 | labels 
= np.load(label_fn) 25 | return labels['labels'], labels['objframe_idx_2_label_idx'] 26 | 27 | 28 | def read_labels_as_list(seq_dir: str, dst_cfg, 29 | L: int, start_idx: int = 0) -> List[ObjectLabels]: 30 | """Read the original full-frequency GT labels.""" 31 | # seq_dir: .../18-03-29_13-15-02_5_605 32 | labels, objframe_idx_2_label_idx = read_npz_labels(seq_dir) 33 | hw = tuple(dst_cfg.ev_repr_hw) 34 | if dst_cfg.downsample_by_factor_2: 35 | hw = tuple(s * 2 for s in hw) 36 | label_factory = ObjectLabelFactory.from_structured_array( 37 | labels, objframe_idx_2_label_idx, hw, 38 | 2 if dst_cfg.downsample_by_factor_2 else None) 39 | ev_dir = get_ev_dir(seq_dir) 40 | objframe_idx_2_repr_idx = np.load( 41 | os.path.join(ev_dir, 'objframe_idx_2_repr_idx.npy')) 42 | obj_labels = [None] * L 43 | for objframe_idx, repr_idx in enumerate(objframe_idx_2_repr_idx): 44 | if start_idx <= repr_idx < start_idx + L: 45 | obj_labels[repr_idx - start_idx] = label_factory[objframe_idx] 46 | return obj_labels 47 | 48 | 49 | def get_ev_dir(seq_dir: str) -> str: 50 | """Get event representation directory.""" 51 | # seq_dir: path/to/dataset/train/18-03-29_13-15-02_5_605 52 | # ev_dir: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/ 53 | ev_dir = os.path.join(seq_dir, 'event_representations_v2', 54 | 'stacked_histogram_dt=50_nbins=10') 55 | return ev_dir 56 | 57 | 58 | def get_objframe_idx_2_repr_idx_fn(ev_dir: str) -> str: 59 | """Get xxx/objframe_idx_2_repr_idx.npy file name.""" 60 | if 'event_representations_v2' not in ev_dir: 61 | ev_dir = get_ev_dir(ev_dir) 62 | # ev_dir: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/ 63 | # fn: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/objframe_idx_2_repr_idx.npy 64 | fn = os.path.join(ev_dir, 'objframe_idx_2_repr_idx.npy') 65 | return fn 66 | 67 | 68 | def get_ev_h5_fn(ev_dir: str, dst_name: str = None) -> str: 69 | """Get event representation h5 file name.""" 70 | # if it's just .../18-03-29_13-15-02_5_605, we get the true ev_dir first 71 | if 'event_representations_v2' not in ev_dir: 72 | ev_dir = get_ev_dir(ev_dir) 73 | # ev_dir: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/ 74 | # ev_h5_fn: .../18-03-29_13-15-02_5_605/event_representations_v2/stacked_histogram_dt=50_nbins=10/event_representations.h5 75 | if dst_name is None: 76 | dst_name = 'gen1' if 'gen1' in ev_dir else 'gen4' 77 | ev_name = 'event_representations.h5' if dst_name == 'gen1' else \ 78 | 'event_representations_ds2_nearest.h5' 79 | ev_h5_fn = os.path.join(ev_dir, ev_name) 80 | return ev_h5_fn 81 | 82 | 83 | def read_ev_repr(h5f: str) -> np.ndarray: 84 | if 'event_representations_v2' not in h5f: 85 | h5f = get_ev_h5_fn(h5f) 86 | with h5py.File(str(h5f), 'r') as h5f: 87 | ev_repr = h5f['data'][:] 88 | return ev_repr 89 | 90 | 91 | def read_objframe_idx_2_repr_idx(npy_fn: str) -> np.ndarray: 92 | if 'event_representations_v2' not in npy_fn: 93 | npy_fn = get_objframe_idx_2_repr_idx_fn(npy_fn) 94 | return np.load(npy_fn) 95 | -------------------------------------------------------------------------------- /data/utils/spatial.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | 3 | from data.utils.types import DatasetType 4 | 5 | _type_2_hw = { 6 | DatasetType.GEN1: (240, 304), 7 | DatasetType.GEN4: (720, 1280), 8 | } 9 | 10 | _str_2_type = { 11 | 'gen1': DatasetType.GEN1, 12 
| 'gen4': DatasetType.GEN4, 13 | } 14 | 15 | 16 | def get_original_hw(dataset_type: DatasetType): 17 | return _type_2_hw[dataset_type] 18 | 19 | 20 | def get_dataloading_hw(dataset_config: DictConfig): 21 | dataset_name = dataset_config.name 22 | hw = get_original_hw(dataset_type=_str_2_type[dataset_name]) 23 | downsample_by_factor_2 = dataset_config.downsample_by_factor_2 24 | if downsample_by_factor_2: 25 | hw = tuple(x // 2 for x in hw) 26 | return hw 27 | -------------------------------------------------------------------------------- /data/utils/stream_concat_datapipe.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterator, List, Optional, Type 2 | 3 | import torch as th 4 | import torch.distributed as dist 5 | from torch.utils.data import DataLoader 6 | from torchdata.datapipes.iter import ( 7 | Concater, 8 | IterableWrapper, 9 | IterDataPipe, 10 | Zipper, 11 | ) 12 | from torchdata.datapipes.map import MapDataPipe 13 | 14 | 15 | class DummyIterDataPipe(IterDataPipe): 16 | def __init__(self, source_dp: IterDataPipe): 17 | super().__init__() 18 | assert isinstance(source_dp, IterDataPipe) 19 | self.source_dp = source_dp 20 | 21 | def __iter__(self): 22 | yield from self.source_dp 23 | 24 | 25 | class ConcatStreamingDataPipe(IterDataPipe): 26 | """This Dataset avoids the sharding problem by instantiating randomized stream concatenation at the batch and 27 | worker level. 28 | Pros: 29 | - Every single batch has valid samples. Consequently, the batch size is always constant. 30 | Cons: 31 | - There might be repeated samples in a batch. Although they should be different because of data augmentation. 32 | - Cannot be used for validation or testing because we repeat the dataset multiple times in an epoch. 33 | 34 | TLDR: preferred approach for training but not useful for validation or testing. 35 | """ 36 | 37 | def __init__(self, 38 | datapipe_list: List[MapDataPipe], 39 | batch_size: int, 40 | num_workers: int, 41 | augmentation_pipeline: Optional[Type[IterDataPipe]] = None, 42 | print_seed_debug: bool = False): 43 | super().__init__() 44 | assert batch_size > 0 45 | 46 | if augmentation_pipeline is not None: 47 | self.augmentation_dp = augmentation_pipeline 48 | else: 49 | self.augmentation_dp = DummyIterDataPipe 50 | 51 | # We require MapDataPipes instead of IterDataPipes because IterDataPipes must be deepcopied in each worker. 52 | # Instead, MapDataPipes can be converted to IterDataPipes in each worker without requiring a deepcopy. 53 | self.datapipe_list = datapipe_list 54 | self.batch_size = batch_size 55 | 56 | self.print_seed_debug = print_seed_debug 57 | 58 | @staticmethod 59 | def random_torch_shuffle_list(data: List[Any]) -> Iterator[Any]: 60 | assert isinstance(data, List) 61 | return (data[idx] for idx in th.randperm(len(data)).tolist()) 62 | 63 | def _get_zipped_streams(self, datapipe_list: List[MapDataPipe], batch_size: int): 64 | """Use it only in the iter function of this class! 65 | Reason: randomized shuffling must happen within each worker. Otherwise, the same random order will be used 66 | for all workers. 
67 | """ 68 | assert isinstance(datapipe_list, List) 69 | assert batch_size > 0 70 | # randomly shuffle datapipes and concat them as one datapipe 71 | # repeat it `BS` times so that we load `BS` samples per `next()` call 72 | streams = Zipper(*(Concater(*(self.augmentation_dp(x.to_iter_datapipe()) 73 | for x in self.random_torch_shuffle_list(datapipe_list))) 74 | for _ in range(batch_size))) 75 | return streams 76 | 77 | def _print_seed_debug_info(self): 78 | """Debug purpose only.""" 79 | worker_info = th.utils.data.get_worker_info() 80 | local_worker_id = 0 if worker_info is None else worker_info.id 81 | 82 | worker_torch_seed = worker_info.seed 83 | local_num_workers = 1 if worker_info is None else worker_info.num_workers 84 | if dist.is_available() and dist.is_initialized(): 85 | global_rank = dist.get_rank() 86 | else: 87 | global_rank = 0 88 | global_worker_id = global_rank * local_num_workers + local_worker_id 89 | 90 | rnd_number = th.randn(1) 91 | print(f'{worker_torch_seed=},\t{global_worker_id=},\t{global_rank=},\t{local_worker_id=},\t{rnd_number=}', 92 | flush=True) 93 | 94 | def _get_zipped_streams_with_worker_id(self): 95 | """Use it only in the iter function of this class! 96 | """ 97 | worker_info = th.utils.data.get_worker_info() 98 | local_worker_id = 0 if worker_info is None else worker_info.id 99 | # always return self's worker_id 100 | worker_id_stream = IterableWrapper([local_worker_id]).cycle(count=None) 101 | # get ONE datapipe that loads a batch of data at every `next()` call 102 | zipped_stream = self._get_zipped_streams(datapipe_list=self.datapipe_list, batch_size=self.batch_size) 103 | return zipped_stream.zip(worker_id_stream) 104 | 105 | def __iter__(self): 106 | if self.print_seed_debug: 107 | self._print_seed_debug_info() 108 | return iter(self._get_zipped_streams_with_worker_id()) 109 | -------------------------------------------------------------------------------- /data/utils/types.py: -------------------------------------------------------------------------------- 1 | from enum import auto, Enum 2 | 3 | try: 4 | from enum import StrEnum 5 | except ImportError: 6 | from strenum import StrEnum 7 | from typing import Dict, List, Optional, Tuple, Union, Any 8 | 9 | import torch as th 10 | 11 | from data.genx_utils.labels import ObjectLabels, SparselyBatchedObjectLabels 12 | # from data.utils.augmentor import AugmentationState # avoid circular import 13 | 14 | 15 | class DataType(Enum): 16 | PATH = auto() # 'path/to/dst/test/17-06-97_12-14-33_244500000_304500000' 17 | EV_IDX = auto() # index of the loaded ev_repr in the entire sequence 18 | EV_REPR = auto() 19 | FLOW = auto() 20 | IMAGE = auto() 21 | OBJLABELS = auto() 22 | OBJLABELS_SEQ = auto() 23 | SKIPPED_OBJLABELS_SEQ = auto() # GT labels that are skipped in SSOD 24 | IS_PADDED_MASK = auto() 25 | IS_FIRST_SAMPLE = auto() 26 | IS_LAST_SAMPLE = auto() 27 | IS_REVERSED = auto() # whether the sequence is in reverse order 28 | TOKEN_MASK = auto() 29 | PRED_MASK = auto() # if the teacher do prediction at this frame 30 | GT_MASK = auto() # which labels are from GT, others are pseudo labels 31 | PRED_PROBS = auto() # obj/cls_probs predicted by the teacher model 32 | AUGM_STATE = auto() # augmentation state 33 | 34 | 35 | class DatasetType(Enum): 36 | GEN1 = auto() 37 | GEN4 = auto() 38 | 39 | 40 | class DatasetMode(Enum): 41 | TRAIN = auto() 42 | VALIDATION = auto() 43 | TESTING = auto() 44 | 45 | 46 | class DatasetSamplingMode(StrEnum): 47 | RANDOM = 'random' 48 | STREAM = 'stream' 49 | MIXED = 'mixed' 50 | 51 
| 52 | class ObjDetOutput(Enum): 53 | LABELS_PROPH = auto() 54 | PRED_PROPH = auto() 55 | EV_REPR = auto() 56 | SKIP_VIZ = auto() 57 | 58 | 59 | LoaderDataDictGenX = Dict[DataType, Union[List[th.Tensor], ObjectLabels, 60 | SparselyBatchedObjectLabels, 61 | # AugmentationState, 62 | List[bool], ]] 63 | 64 | LstmState = Optional[Tuple[th.Tensor, th.Tensor]] 65 | LstmStates = List[LstmState] 66 | 67 | FeatureMap = th.Tensor 68 | BackboneFeatures = Dict[int, th.Tensor] 69 | BatchAugmState = Dict[str, Dict[str, List[Any]]] 70 | -------------------------------------------------------------------------------- /datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/datasets/.gitkeep -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Install 2 | 3 | Most of this section is the same as [RVT](https://github.com/uzh-rpg/RVT). 4 | We also heavily rely on a personal package [nerv](https://github.com/Wuziyi616/nerv) for utility functions. 5 | 6 | ## Environment Setup 7 | 8 | Please use [Anaconda](https://www.anaconda.com/) for package management. 9 | ```Bash 10 | conda create -y -n leod python=3.9 11 | conda activate leod 12 | 13 | conda install -y pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.8 -c pytorch -c nvidia 14 | 15 | python -m pip install tqdm numba hdf5plugin h5py==3.8.0 \ 16 | pandas==1.5.3 plotly==5.13.1 opencv-python==4.6.0.66 tabulate==0.9.0 \ 17 | pycocotools==2.0.6 bbox-visualizer==0.1.0 StrEnum==0.4.10 \ 18 | opencv-python hydra-core==1.3.2 einops==0.6.0 \ 19 | pytorch-lightning==1.8.6 wandb==0.14.0 torchdata==0.6.0 20 | 21 | conda install -y blosc-hdf5-plugin -c conda-forge 22 | 23 | # install nerv: https://github.com/Wuziyi616/nerv 24 | git clone git@github.com:Wuziyi616/nerv.git 25 | cd nerv 26 | git checkout v0.4.0 # tested with v0.4.0 release 27 | pip install -e . 28 | cd .. # go back to the root directory of the project 29 | 30 | # (Optional) compile Detectron2 for faster evaluation 31 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' 32 | ``` 33 | 34 | We also provide a `environment.yml` file which lists all the packages in our final environment. 35 | 36 | I sometimes encounter a weird bug where `Detectron2` cannot run on types of GPUs different from the one I compile it on (e.g., if I compile it on RTX6000 GPUs, I cannot use it on A40 GPUs). 37 | To avoid this issue, go to [coco_eval.py](../utils/evaluation/prophesee/metrics/coco_eval.py#L17) and set the `compile_gpu` to the GPU you compile it (the program will not import `Detectron2` when detecting a different GPUs in use). 38 | 39 | ## Dataset 40 | 41 | In this project, we use two datasets: Gen1 and 1Mpx. 42 | Following the convention of RVT, we name Gen1 as `gen1` and 1Mpx as `gen4` (because of the camera used to capture them). 43 | Please download the pre-processed datasets from RVT: 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 |
|                       | 1 Mpx    | Gen1     |
| --------------------- | -------- | -------- |
| pre-processed dataset | download | download |
| crc32                 | c5ec7c38 | 5acab6f3 |
58 | 59 | After downloading and unzipping the datasets, soft-link Gen1 to `./datasets/gen1` and 1Mpx to `./datasets/gen4`. 60 | 61 | ### Data Splits 62 | 63 | To simulate the weakly-/semi-supervised learning settings, we need to sub-sample labels from the original dataset. 64 | Importantly, the data split must stay the same across experiments. 65 | - For the semi-supervised setting, where we keep the labels for some sequences while making other sequences completely unlabeled, this is relatively easy. 66 | We simply sort the event sequences by name so that their order is deterministic across runs, and then select the unlabeled sequences from that list. 67 | - For the weakly-supervised setting, where we sub-sample the labels for all sequences, it is a bit trickier because there are two modes of data sampling in the codebase, and they pre-process events in different ways. 68 | To keep the data split consistent, we create a split file for each setting; they are stored [here](../data/genx_utils/splits/). 69 | If you want to explore new experimental settings, remember to create your own split files and read from them [here](../data/genx_utils/dataset_streaming.py#L62). 70 | 71 | All results in the paper are averaged over three different splits (we offset the index when sub-sampling the data). 72 | Overall, the performance variations across different splits are very small. 73 | Therefore, we only release the split files, config files, and pre-trained weights for the first variant we experimented with. 74 | 75 | ## Pre-trained Weights 76 | 77 | We provide checkpoints for all the models used to produce the final results in the paper. 78 | In addition, we provide models pre-trained on the limited annotated data (the `Supervised Baseline` method in the paper) to ease your experiments. 79 | 80 | Please download the pre-trained weights from [Google Drive](https://drive.google.com/file/d/1xBzFovvNbrtBt0YwYcvvrjbV8ozAdCUK/view?usp=sharing) and unzip them to `./pretrained/`. 81 | The weights are grouped by the section of the paper in which they are presented. 82 | Their naming follows the pattern `rvt-{$MODEL_SIZE}-{$DATASET}x{$RATIO_OF_DATA}_{$SETTING}.ckpt`. 83 | 84 | For example, `rvt-s-gen1x0.02_ss.ckpt` is the RVT-S pre-trained on 2% of Gen1 data under the weakly-supervised setting. 85 | `rvt-s-gen4x0.05_ss-final.ckpt` is the RVT-S trained on 5% of 1Mpx data under the weakly-supervised setting, and `-final` means it is the LEOD self-trained model (used to produce the results in the paper). 86 | 87 | **Note:** it might be a bit confusing, but `ss` means weakly-supervised (all event sequences are sparsely labeled) and `seq` means semi-supervised (some event sequences are densely labeled, while others are completely unlabeled).
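To make the naming convention concrete, here is a minimal illustrative sketch that parses a checkpoint file name into its fields. It is a hypothetical helper (not a utility shipped with this repo) and assumes exactly the pattern and the optional `-final` suffix described above:

```python
import re

# Hypothetical helper for illustration only; not part of this repository.
CKPT_RE = re.compile(
    r'rvt-(?P<size>[a-z]+)-(?P<dataset>gen[14])x(?P<ratio>[0-9.]+)'
    r'_(?P<setting>ss|seq)(?P<final>-final)?\.ckpt'
)

def parse_ckpt_name(name: str) -> dict:
    """Split a pre-trained weight file name into its components."""
    m = CKPT_RE.fullmatch(name)
    assert m is not None, f'unexpected checkpoint name: {name}'
    return {
        'model_size': m.group('size'),           # e.g. 's' -> RVT-S
        'dataset': m.group('dataset'),           # 'gen1' or 'gen4' (1Mpx)
        'label_ratio': float(m.group('ratio')),  # fraction of labeled data
        'weakly_supervised': m.group('setting') == 'ss',  # 'seq' -> semi-supervised
        'self_trained': m.group('final') is not None,     # LEOD self-trained model
    }

print(parse_ckpt_name('rvt-s-gen4x0.05_ss-final.ckpt'))
# {'model_size': 's', 'dataset': 'gen4', 'label_ratio': 0.05,
#  'weakly_supervised': True, 'self_trained': True}
```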
88 | -------------------------------------------------------------------------------- /loggers/utils.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | from pytorch_lightning.loggers.wandb import WandbLogger 3 | 4 | 5 | def get_wandb_logger(full_config: DictConfig) -> WandbLogger: 6 | """Build the native PyTorch Lightning WandB logger.""" 7 | wandb_config = full_config.wandb 8 | 9 | logger = WandbLogger( 10 | project=wandb_config.project_name, 11 | name=wandb_config.wandb_name, 12 | id=wandb_config.wandb_id, 13 | # specify both to make sure it's saved in the right place 14 | save_dir=wandb_config.wandb_runpath, 15 | dir=wandb_config.wandb_runpath, 16 | # group=wandb_config.group_name, # not very useful 17 | # log_model=True, # don't log model weights 18 | # save_last_only_final=False, 19 | # save_code=True, 20 | # config_args=full_config_dict, 21 | ) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /models/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/detection/__init__.py -------------------------------------------------------------------------------- /models/detection/recurrent_backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from omegaconf import DictConfig 2 | 3 | from .maxvit_rnn import RNNDetector as MaxViTRNNDetector 4 | 5 | 6 | def build_recurrent_backbone(backbone_cfg: DictConfig): 7 | name = backbone_cfg.name 8 | if name == 'MaxViTRNN': 9 | return MaxViTRNNDetector(backbone_cfg) 10 | else: 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /models/detection/recurrent_backbone/base.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class BaseDetector(nn.Module): 7 | def get_stage_dims(self, stages: Tuple[int, ...]) -> Tuple[int, ...]: 8 | raise NotImplementedError 9 | 10 | def get_strides(self, stages: Tuple[int, ...]) -> Tuple[int, ...]: 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /models/detection/yolox/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/detection/yolox/models/__init__.py -------------------------------------------------------------------------------- /models/detection/yolox/models/losses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # Copyright (c) Megvii Inc. All rights reserved. 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.ops as ops 8 | 9 | 10 | class IOUloss(nn.Module): 11 | """IoU loss for bbox regression.""" 12 | 13 | def __init__(self, reduction="none", loss_type="iou"): 14 | super(IOUloss, self).__init__() 15 | self.reduction = reduction 16 | self.loss_type = loss_type 17 | 18 | def forward(self, pred, target, weights=None): 19 | assert pred.shape[0] == target.shape[0] 20 | if pred.shape[0] == 0: 21 | return 0. 
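        # --- Added annotation; not in the original Megvii code. ---
        # Boxes are assumed to be in (cx, cy, w, h) format, so the intersection
        # corners computed below come from center +/- half size:
        #   tl = max(pred_tl, target_tl),  br = min(pred_br, target_br)
        # Tiny worked example: pred = (0, 0, 2, 2), target = (1, 1, 2, 2)
        #   -> intersection area 1, union area 7, IoU = 1/7 ~= 0.143,
        #      "iou" loss = 1 - IoU**2 ~= 0.98.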
22 | 23 | pred = pred.view(-1, 4) 24 | target = target.view(-1, 4) 25 | tl = torch.max( 26 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) 27 | ) 28 | br = torch.min( 29 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) 30 | ) 31 | 32 | area_p = torch.prod(pred[:, 2:], 1) 33 | area_g = torch.prod(target[:, 2:], 1) 34 | 35 | en = (tl < br).type(tl.type()).prod(dim=1) 36 | area_i = torch.prod(br - tl, 1) * en 37 | area_u = area_p + area_g - area_i 38 | iou = (area_i) / (area_u + 1e-16) 39 | 40 | if self.loss_type == "iou": 41 | loss = 1 - iou ** 2 42 | elif self.loss_type == "giou": 43 | c_tl = torch.min( 44 | (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2) 45 | ) 46 | c_br = torch.max( 47 | (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2) 48 | ) 49 | area_c = torch.prod(c_br - c_tl, 1) 50 | giou = iou - (area_c - area_u) / area_c.clamp(1e-16) 51 | loss = 1 - giou.clamp(min=-1.0, max=1.0) 52 | else: 53 | raise NotImplementedError 54 | 55 | # weighted IoU loss 56 | if weights is not None and isinstance(weights, torch.Tensor) and \ 57 | (weights != 1.).any(): 58 | assert weights.shape[0] == loss.shape[0] 59 | loss = loss * weights 60 | 61 | if self.reduction == "mean": 62 | loss = loss.mean() 63 | elif self.reduction == "sum": 64 | loss = loss.sum() 65 | 66 | return loss 67 | 68 | 69 | class FocalLoss(nn.Module): 70 | """Focal loss for foreground-background classification (objectness).""" 71 | 72 | def __init__(self, alpha=0.25, gamma=2, reduction='none'): 73 | super().__init__() 74 | 75 | self.alpha = alpha 76 | self.gamma = gamma 77 | self.reduction = reduction 78 | 79 | def forward(self, inputs, targets): 80 | return ops.sigmoid_focal_loss( 81 | inputs=inputs, 82 | targets=targets, 83 | alpha=self.alpha, 84 | gamma=self.gamma, 85 | reduction=self.reduction) 86 | -------------------------------------------------------------------------------- /models/detection/yolox/models/network_blocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | # Copyright (c) Megvii Inc. All rights reserved. 
4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SiLU(nn.Module): 10 | """export-friendly version of nn.SiLU()""" 11 | 12 | @staticmethod 13 | def forward(x): 14 | return x * torch.sigmoid(x) 15 | 16 | 17 | def get_activation(name="silu", inplace=True): 18 | if name == "silu": 19 | module = nn.SiLU(inplace=inplace) 20 | elif name == "relu": 21 | module = nn.ReLU(inplace=inplace) 22 | elif name == "lrelu": 23 | module = nn.LeakyReLU(0.1, inplace=inplace) 24 | else: 25 | raise AttributeError("Unsupported act type: {}".format(name)) 26 | return module 27 | 28 | 29 | class BaseConv(nn.Module): 30 | """A Conv2d -> Batchnorm -> silu/leaky relu block""" 31 | 32 | def __init__( 33 | self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu" 34 | ): 35 | super().__init__() 36 | # same padding 37 | pad = (ksize - 1) // 2 38 | self.conv = nn.Conv2d( 39 | in_channels, 40 | out_channels, 41 | kernel_size=ksize, 42 | stride=stride, 43 | padding=pad, 44 | groups=groups, 45 | bias=bias, 46 | ) 47 | self.bn = nn.BatchNorm2d(out_channels) 48 | self.act = get_activation(act, inplace=True) 49 | 50 | def forward(self, x): 51 | return self.act(self.bn(self.conv(x))) 52 | 53 | def fuseforward(self, x): 54 | return self.act(self.conv(x)) 55 | 56 | 57 | class DWConv(nn.Module): 58 | """Depthwise Conv + Conv""" 59 | 60 | def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"): 61 | super().__init__() 62 | self.dconv = BaseConv( 63 | in_channels, 64 | in_channels, 65 | ksize=ksize, 66 | stride=stride, 67 | groups=in_channels, 68 | act=act, 69 | ) 70 | self.pconv = BaseConv( 71 | in_channels, out_channels, ksize=1, stride=1, groups=1, act=act 72 | ) 73 | 74 | def forward(self, x): 75 | x = self.dconv(x) 76 | return self.pconv(x) 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | # Standard bottleneck 81 | def __init__( 82 | self, 83 | in_channels, 84 | out_channels, 85 | shortcut=True, 86 | expansion=0.5, 87 | depthwise=False, 88 | act="silu", 89 | ): 90 | super().__init__() 91 | hidden_channels = int(out_channels * expansion) 92 | Conv = DWConv if depthwise else BaseConv 93 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 94 | self.conv2 = Conv(hidden_channels, out_channels, 3, stride=1, act=act) 95 | self.use_add = shortcut and in_channels == out_channels 96 | 97 | def forward(self, x): 98 | y = self.conv2(self.conv1(x)) 99 | if self.use_add: 100 | y = y + x 101 | return y 102 | 103 | 104 | class CSPLayer(nn.Module): 105 | """C3 in yolov5, CSP Bottleneck with 3 convolutions""" 106 | 107 | def __init__( 108 | self, 109 | in_channels, 110 | out_channels, 111 | n=1, 112 | shortcut=True, 113 | expansion=0.5, 114 | depthwise=False, 115 | act="silu", 116 | ): 117 | """ 118 | Args: 119 | in_channels (int): input channels. 120 | out_channels (int): output channels. 121 | n (int): number of Bottlenecks. Default value: 1. 
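            shortcut (bool): use residual connections inside the Bottlenecks. Default value: True.
            expansion (float): ratio of hidden channels to output channels. Default value: 0.5.
            depthwise (bool): use depthwise separable convs in the Bottlenecks. Default value: False.
            act (str): activation type. Default value: "silu".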
122 | """ 123 | # ch_in, ch_out, number, shortcut, groups, expansion 124 | super().__init__() 125 | hidden_channels = int(out_channels * expansion) # hidden channels 126 | self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 127 | self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) 128 | self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) 129 | module_list = [ 130 | Bottleneck( 131 | hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act 132 | ) 133 | for _ in range(n) 134 | ] 135 | self.m = nn.Sequential(*module_list) 136 | 137 | def forward(self, x): 138 | x_1 = self.conv1(x) 139 | x_2 = self.conv2(x) 140 | x_1 = self.m(x_1) 141 | x = torch.cat((x_1, x_2), dim=1) 142 | return self.conv3(x) 143 | -------------------------------------------------------------------------------- /models/detection/yolox/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | # Copyright (c) Megvii Inc. All rights reserved. 4 | 5 | from .boxes import * 6 | from .compat import meshgrid 7 | -------------------------------------------------------------------------------- /models/detection/yolox/utils/compat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding:utf-8 -*- 3 | 4 | import torch 5 | 6 | _TORCH_VER = [int(x) for x in torch.__version__.split(".")[:2]] 7 | 8 | __all__ = ["meshgrid"] 9 | 10 | 11 | def meshgrid(*tensors): 12 | if _TORCH_VER >= [1, 10]: 13 | return torch.meshgrid(*tensors, indexing="ij") 14 | else: 15 | return torch.meshgrid(*tensors) 16 | -------------------------------------------------------------------------------- /models/detection/yolox_extension/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/detection/yolox_extension/models/__init__.py -------------------------------------------------------------------------------- /models/detection/yolox_extension/models/build.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from omegaconf import OmegaConf, DictConfig 4 | 5 | from .yolo_pafpn import YOLOPAFPN 6 | from ...yolox.models.yolo_head import YOLOXHead 7 | 8 | 9 | def build_yolox_head(head_cfg: DictConfig, in_channels: Tuple[int, ...], strides: Tuple[int, ...], ssod: bool = False): 10 | assert not ssod # legacy code already removed, the flag is useless now 11 | head_cfg_dict = OmegaConf.to_container(head_cfg, resolve=True, throw_on_missing=True) 12 | head_cfg_dict.pop('name') 13 | head_cfg_dict.pop('version', None) 14 | head_cfg_dict.update({"in_channels": in_channels}) 15 | head_cfg_dict.update({"strides": strides}) 16 | compile_cfg = head_cfg_dict.pop('compile', None) 17 | head_cfg_dict.update({"compile_cfg": compile_cfg}) 18 | module = YOLOXHead 19 | return module(**head_cfg_dict) 20 | 21 | 22 | def build_yolox_fpn(fpn_cfg: DictConfig, in_channels: Tuple[int, ...]): 23 | fpn_cfg_dict = OmegaConf.to_container(fpn_cfg, resolve=True, throw_on_missing=True) 24 | fpn_name = fpn_cfg_dict.pop('name') 25 | fpn_cfg_dict.update({"in_channels": in_channels}) 26 | if fpn_name in {'PAFPN', 'pafpn'}: 27 | compile_cfg = fpn_cfg_dict.pop('compile', None) 28 | fpn_cfg_dict.update({"compile_cfg": compile_cfg}) 29 | return 
YOLOPAFPN(**fpn_cfg_dict) 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /models/detection/yolox_extension/models/detector.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple, Union 2 | 3 | import torch as th 4 | from omegaconf import DictConfig 5 | 6 | try: 7 | from torch import compile as th_compile 8 | except ImportError: 9 | th_compile = None 10 | 11 | from ...recurrent_backbone import build_recurrent_backbone 12 | from .build import build_yolox_fpn, build_yolox_head 13 | from utils.timers import TimerDummy as CudaTimer 14 | # from utils.timers import CudaTimer 15 | from data.utils.types import BackboneFeatures, LstmStates 16 | 17 | 18 | class YoloXDetector(th.nn.Module): 19 | """RNN-based MaxViT backbone + YOLOX detection head.""" 20 | 21 | def __init__(self, model_cfg: DictConfig, ssod: bool = False): 22 | super().__init__() 23 | backbone_cfg = model_cfg.backbone 24 | fpn_cfg = model_cfg.fpn 25 | head_cfg = model_cfg.head 26 | 27 | self.backbone = build_recurrent_backbone(backbone_cfg) # maxvit_rnn 28 | 29 | in_channels = self.backbone.get_stage_dims(fpn_cfg.in_stages) 30 | self.fpn = build_yolox_fpn(fpn_cfg, in_channels=in_channels) 31 | 32 | strides = self.backbone.get_strides(fpn_cfg.in_stages) 33 | self.yolox_head = build_yolox_head(head_cfg, in_channels=in_channels, strides=strides, ssod=ssod) 34 | 35 | def forward_backbone(self, 36 | x: th.Tensor, 37 | previous_states: Optional[LstmStates] = None, 38 | token_mask: Optional[th.Tensor] = None) -> \ 39 | Tuple[BackboneFeatures, LstmStates]: 40 | """Extract multi-stage features from the backbone. 41 | 42 | Input: 43 | x: (B, C, H, W), image 44 | previous_states: List[(lstm_h, lstm_c)], RNN states from prev timestep 45 | token_mask: (B, H, W) or None, pixel padding mask 46 | 47 | Returns: 48 | backbone_features: Dict{stage_id: feats, [B, C, h, w]}, multi-stage 49 | states: List[(lstm_h, lstm_c), same shape], RNN state of each stage 50 | """ 51 | with CudaTimer(device=x.device, timer_name="Backbone"): 52 | backbone_features, states = self.backbone(x, previous_states, token_mask) 53 | return backbone_features, states 54 | 55 | def forward_detect(self, 56 | backbone_features: BackboneFeatures, 57 | targets: Optional[th.Tensor] = None, 58 | soft_targets: Optional[th.Tensor] = None) -> \ 59 | Tuple[th.Tensor, Union[Dict[str, th.Tensor], None]]: 60 | """Predict object bbox from multi-stage features. 
61 | 62 | Returns: 63 | outputs: (B, N, 4 + 1 + num_cls), [(x, y, w, h), obj_conf, cls] 64 | losses: Dict{loss_name: loss, torch.scalar tensor} 65 | """ 66 | device = next(iter(backbone_features.values())).device 67 | with CudaTimer(device=device, timer_name="FPN"): 68 | fpn_features = self.fpn(backbone_features) # Tuple(feats, [B, C, h, w]) 69 | if self.training: 70 | assert targets is not None 71 | with CudaTimer(device=device, timer_name="HEAD + Loss"): 72 | outputs, losses = self.yolox_head(fpn_features, targets, soft_targets) 73 | return outputs, losses 74 | with CudaTimer(device=device, timer_name="HEAD"): 75 | outputs, losses = self.yolox_head(fpn_features) 76 | assert losses is None 77 | return outputs, losses 78 | 79 | def forward(self, 80 | x: th.Tensor, 81 | previous_states: Optional[LstmStates] = None, 82 | retrieve_detections: bool = True, 83 | targets: Optional[th.Tensor] = None) -> \ 84 | Tuple[Union[th.Tensor, None], Union[Dict[str, th.Tensor], None], LstmStates]: 85 | backbone_features, states = self.forward_backbone(x, previous_states) 86 | outputs, losses = None, None 87 | if not retrieve_detections: 88 | assert targets is None 89 | return outputs, losses, states 90 | outputs, losses = self.forward_detect(backbone_features=backbone_features, targets=targets) 91 | return outputs, losses, states 92 | -------------------------------------------------------------------------------- /models/detection/yolox_extension/models/yolo_pafpn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Original Yolox PAFPN code with slight modifications 3 | """ 4 | from typing import Dict, Optional, Tuple 5 | 6 | import torch as th 7 | import torch.nn as nn 8 | 9 | try: 10 | from torch import compile as th_compile 11 | except ImportError: 12 | th_compile = None 13 | 14 | from ...yolox.models.network_blocks import BaseConv, CSPLayer, DWConv 15 | from data.utils.types import BackboneFeatures 16 | 17 | 18 | class YOLOPAFPN(nn.Module): 19 | """ 20 | Removed the direct dependency on the backbone. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | depth: float = 1.0, 26 | in_stages: Tuple[int, ...] = (2, 3, 4), 27 | in_channels: Tuple[int, ...] 
= (256, 512, 1024), 28 | depthwise: bool = False, 29 | act: str = "silu", 30 | compile_cfg: Optional[Dict] = None, 31 | ): 32 | super().__init__() 33 | assert len(in_stages) == len(in_channels) 34 | assert len(in_channels) == 3, 'Current implementation only for 3 feature maps' 35 | self.in_features = in_stages 36 | self.in_channels = in_channels 37 | Conv = DWConv if depthwise else BaseConv 38 | 39 | ###### Compile if requested ###### 40 | if compile_cfg is not None: 41 | compile_mdl = compile_cfg['enable'] 42 | if compile_mdl and th_compile is not None: 43 | self.forward = th_compile(self.forward, **compile_cfg['args']) 44 | elif compile_mdl: 45 | print('Could not compile PAFPN because torch.compile is not available') 46 | 47 | ################################## 48 | 49 | self.upsample = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest-exact') 50 | self.lateral_conv0 = BaseConv( 51 | in_channels[2], in_channels[1], 1, 1, act=act 52 | ) 53 | self.C3_p4 = CSPLayer( 54 | 2 * in_channels[1], 55 | in_channels[1], 56 | round(3 * depth), 57 | False, 58 | depthwise=depthwise, 59 | act=act, 60 | ) # cat 61 | 62 | self.reduce_conv1 = BaseConv( 63 | in_channels[1], in_channels[0], 1, 1, act=act 64 | ) 65 | self.C3_p3 = CSPLayer( 66 | 2 * in_channels[0], 67 | in_channels[0], 68 | round(3 * depth), 69 | False, 70 | depthwise=depthwise, 71 | act=act, 72 | ) 73 | 74 | # bottom-up conv 75 | self.bu_conv2 = Conv( 76 | in_channels[0], in_channels[0], 3, 2, act=act 77 | ) 78 | self.C3_n3 = CSPLayer( 79 | 2 * in_channels[0], 80 | in_channels[1], 81 | round(3 * depth), 82 | False, 83 | depthwise=depthwise, 84 | act=act, 85 | ) 86 | 87 | # bottom-up conv 88 | self.bu_conv1 = Conv( 89 | in_channels[1], in_channels[1], 3, 2, act=act 90 | ) 91 | self.C3_n4 = CSPLayer( 92 | 2 * in_channels[1], 93 | in_channels[2], 94 | round(3 * depth), 95 | False, 96 | depthwise=depthwise, 97 | act=act, 98 | ) 99 | 100 | ###### Compile if requested ###### 101 | if compile_cfg is not None: 102 | compile_mdl = compile_cfg['enable'] 103 | if compile_mdl and th_compile is not None: 104 | self.forward = th_compile(self.forward, **compile_cfg['args']) 105 | elif compile_mdl: 106 | print('Could not compile PAFPN because torch.compile is not available') 107 | ################################## 108 | 109 | def forward(self, input: BackboneFeatures): 110 | """ 111 | Args: 112 | inputs: Feature maps from backbone 113 | 114 | Returns: 115 | Tuple[Tensor]: FPN feature. 
116 | """ 117 | features = [input[f] for f in self.in_features] 118 | x2, x1, x0 = features 119 | 120 | # channel/dowmsample_stride 121 | fpn_out0 = self.lateral_conv0(x0) # 1024->512/32 122 | f_out0 = self.upsample(fpn_out0) # 512/16 123 | f_out0 = th.cat([f_out0, x1], 1) # 512->1024/16 124 | f_out0 = self.C3_p4(f_out0) # 1024->512/16 125 | 126 | fpn_out1 = self.reduce_conv1(f_out0) # 512->256/16 127 | f_out1 = self.upsample(fpn_out1) # 256/8 128 | f_out1 = th.cat([f_out1, x2], 1) # 256->512/8 129 | pan_out2 = self.C3_p3(f_out1) # 512->256/8 130 | 131 | p_out1 = self.bu_conv2(pan_out2) # 256->256/16 132 | p_out1 = th.cat([p_out1, fpn_out1], 1) # 256->512/16 133 | pan_out1 = self.C3_n3(p_out1) # 512->512/16 134 | 135 | p_out0 = self.bu_conv1(pan_out1) # 512->512/32 136 | p_out0 = th.cat([p_out0, fpn_out0], 1) # 512->1024/32 137 | pan_out0 = self.C3_n4(p_out0) # 1024->1024/32 138 | 139 | outputs = (pan_out2, pan_out1, pan_out0) 140 | return outputs 141 | -------------------------------------------------------------------------------- /models/layers/maxvit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/models/layers/maxvit/__init__.py -------------------------------------------------------------------------------- /models/layers/maxvit/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .adaptive_avgmax_pool import \ 3 | adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d 4 | from .blur_pool import BlurPool2d 5 | from .classifier import ClassifierHead, create_classifier 6 | from .cond_conv2d import CondConv2d, get_condconv_initializer 7 | from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ 8 | set_layer_config 9 | from .conv2d_same import Conv2dSame, conv2d_same 10 | from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct 11 | from .create_act import create_act_layer, get_act_layer, get_act_fn 12 | from .create_attn import get_attn, create_attn 13 | from .create_conv2d import create_conv2d 14 | from .create_norm import get_norm_layer, create_norm_layer 15 | from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer 16 | from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path 17 | from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn 18 | from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ 19 | EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a 20 | from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm 21 | from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d 22 | from .gather_excite import GatherExcite 23 | from .global_context import GlobalContext 24 | from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple 25 | from .inplace_abn import InplaceAbn 26 | from .linear import Linear 27 | from .mixed_conv2d import MixedConv2d 28 | from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp 29 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 30 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 31 | from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm 32 | from .padding import get_padding, get_same_padding, pad_same 33 | from 
.patch_embed import PatchEmbed 34 | from .pool2d_same import AvgPool2dSame, create_pool2d 35 | from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite 36 | from .selective_kernel import SelectiveKernel 37 | from .separable_conv import SeparableConv2d, SeparableConvNormAct 38 | from .space_to_depth import SpaceToDepthModule 39 | from .split_attn import SplitAttn 40 | from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model 41 | from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame 42 | from .test_time_pool import TestTimePoolHead, apply_test_time_pool 43 | from .trace_utils import _assert, _float_to_int 44 | from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_ 45 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/activations.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | 9 | import torch 10 | from torch import nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | def swish(x, inplace: bool = False): 15 | """Swish - Described in: https://arxiv.org/abs/1710.05941 16 | """ 17 | return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) 18 | 19 | 20 | class Swish(nn.Module): 21 | def __init__(self, inplace: bool = False): 22 | super(Swish, self).__init__() 23 | self.inplace = inplace 24 | 25 | def forward(self, x): 26 | return swish(x, self.inplace) 27 | 28 | 29 | def mish(x, inplace: bool = False): 30 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 31 | NOTE: I don't have a working inplace variant 32 | """ 33 | return x.mul(F.softplus(x).tanh()) 34 | 35 | 36 | class Mish(nn.Module): 37 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 38 | """ 39 | def __init__(self, inplace: bool = False): 40 | super(Mish, self).__init__() 41 | 42 | def forward(self, x): 43 | return mish(x) 44 | 45 | 46 | def sigmoid(x, inplace: bool = False): 47 | return x.sigmoid_() if inplace else x.sigmoid() 48 | 49 | 50 | # PyTorch has this, but not with a consistent inplace argmument interface 51 | class Sigmoid(nn.Module): 52 | def __init__(self, inplace: bool = False): 53 | super(Sigmoid, self).__init__() 54 | self.inplace = inplace 55 | 56 | def forward(self, x): 57 | return x.sigmoid_() if self.inplace else x.sigmoid() 58 | 59 | 60 | def tanh(x, inplace: bool = False): 61 | return x.tanh_() if inplace else x.tanh() 62 | 63 | 64 | # PyTorch has this, but not with a consistent inplace argmument interface 65 | class Tanh(nn.Module): 66 | def __init__(self, inplace: bool = False): 67 | super(Tanh, self).__init__() 68 | self.inplace = inplace 69 | 70 | def forward(self, x): 71 | return x.tanh_() if self.inplace else x.tanh() 72 | 73 | 74 | def hard_swish(x, inplace: bool = False): 75 | inner = F.relu6(x + 3.).div_(6.) 
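    # Added note: hard_swish(x) = x * relu6(x + 3) / 6, i.e. swish with the sigmoid
    # replaced by the cheaper piecewise-linear "hard" sigmoid.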
76 | return x.mul_(inner) if inplace else x.mul(inner) 77 | 78 | 79 | class HardSwish(nn.Module): 80 | def __init__(self, inplace: bool = False): 81 | super(HardSwish, self).__init__() 82 | self.inplace = inplace 83 | 84 | def forward(self, x): 85 | return hard_swish(x, self.inplace) 86 | 87 | 88 | def hard_sigmoid(x, inplace: bool = False): 89 | if inplace: 90 | return x.add_(3.).clamp_(0., 6.).div_(6.) 91 | else: 92 | return F.relu6(x + 3.) / 6. 93 | 94 | 95 | class HardSigmoid(nn.Module): 96 | def __init__(self, inplace: bool = False): 97 | super(HardSigmoid, self).__init__() 98 | self.inplace = inplace 99 | 100 | def forward(self, x): 101 | return hard_sigmoid(x, self.inplace) 102 | 103 | 104 | def hard_mish(x, inplace: bool = False): 105 | """ Hard Mish 106 | Experimental, based on notes by Mish author Diganta Misra at 107 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 108 | """ 109 | if inplace: 110 | return x.mul_(0.5 * (x + 2).clamp(min=0, max=2)) 111 | else: 112 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 113 | 114 | 115 | class HardMish(nn.Module): 116 | def __init__(self, inplace: bool = False): 117 | super(HardMish, self).__init__() 118 | self.inplace = inplace 119 | 120 | def forward(self, x): 121 | return hard_mish(x, self.inplace) 122 | 123 | 124 | class PReLU(nn.PReLU): 125 | """Applies PReLU (w/ dummy inplace arg) 126 | """ 127 | def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None: 128 | super(PReLU, self).__init__(num_parameters=num_parameters, init=init) 129 | 130 | def forward(self, input: torch.Tensor) -> torch.Tensor: 131 | return F.prelu(input, self.weight) 132 | 133 | 134 | def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: 135 | return F.gelu(x) 136 | 137 | 138 | class GELU(nn.Module): 139 | """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) 140 | """ 141 | def __init__(self, inplace: bool = False): 142 | super(GELU, self).__init__() 143 | 144 | def forward(self, input: torch.Tensor) -> torch.Tensor: 145 | return F.gelu(input) 146 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/activations_jit.py: -------------------------------------------------------------------------------- 1 | """ Activations 2 | 3 | A collection of jit-scripted activations fn and modules with a common interface so that they can 4 | easily be swapped. All have an `inplace` arg even if not used. 5 | 6 | All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not 7 | currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted 8 | versions if they contain in-place ops. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | 13 | import torch 14 | from torch import nn as nn 15 | from torch.nn import functional as F 16 | 17 | 18 | @torch.jit.script 19 | def swish_jit(x, inplace: bool = False): 20 | """Swish - Described in: https://arxiv.org/abs/1710.05941 21 | """ 22 | return x.mul(x.sigmoid()) 23 | 24 | 25 | @torch.jit.script 26 | def mish_jit(x, _inplace: bool = False): 27 | """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 28 | """ 29 | return x.mul(F.softplus(x).tanh()) 30 | 31 | 32 | class SwishJit(nn.Module): 33 | def __init__(self, inplace: bool = False): 34 | super(SwishJit, self).__init__() 35 | 36 | def forward(self, x): 37 | return swish_jit(x) 38 | 39 | 40 | class MishJit(nn.Module): 41 | def __init__(self, inplace: bool = False): 42 | super(MishJit, self).__init__() 43 | 44 | def forward(self, x): 45 | return mish_jit(x) 46 | 47 | 48 | @torch.jit.script 49 | def hard_sigmoid_jit(x, inplace: bool = False): 50 | # return F.relu6(x + 3.) / 6. 51 | return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 52 | 53 | 54 | class HardSigmoidJit(nn.Module): 55 | def __init__(self, inplace: bool = False): 56 | super(HardSigmoidJit, self).__init__() 57 | 58 | def forward(self, x): 59 | return hard_sigmoid_jit(x) 60 | 61 | 62 | @torch.jit.script 63 | def hard_swish_jit(x, inplace: bool = False): 64 | # return x * (F.relu6(x + 3.) / 6) 65 | return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? 66 | 67 | 68 | class HardSwishJit(nn.Module): 69 | def __init__(self, inplace: bool = False): 70 | super(HardSwishJit, self).__init__() 71 | 72 | def forward(self, x): 73 | return hard_swish_jit(x) 74 | 75 | 76 | @torch.jit.script 77 | def hard_mish_jit(x, inplace: bool = False): 78 | """ Hard Mish 79 | Experimental, based on notes by Mish author Diganta Misra at 80 | https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md 81 | """ 82 | return 0.5 * x * (x + 2).clamp(min=0, max=2) 83 | 84 | 85 | class HardMishJit(nn.Module): 86 | def __init__(self, inplace: bool = False): 87 | super(HardMishJit, self).__init__() 88 | 89 | def forward(self, x): 90 | return hard_mish_jit(x) 91 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/adaptive_avgmax_pool.py: -------------------------------------------------------------------------------- 1 | """ PyTorch selectable adaptive pooling 2 | Adaptive pooling with the ability to select the type of pooling from: 3 | * 'avg' - Average pooling 4 | * 'max' - Max pooling 5 | * 'avgmax' - Sum of average and max pooling re-scaled by 0.5 6 | * 'avgmaxc' - Concatenation of average and max pooling along feature dim, doubles feature dim 7 | 8 | Both a functional and a nn.Module version of the pooling is provided. 
9 | 10 | Hacked together by / Copyright 2020 Ross Wightman 11 | """ 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | def adaptive_pool_feat_mult(pool_type='avg'): 18 | if pool_type == 'catavgmax': 19 | return 2 20 | else: 21 | return 1 22 | 23 | 24 | def adaptive_avgmax_pool2d(x, output_size=1): 25 | x_avg = F.adaptive_avg_pool2d(x, output_size) 26 | x_max = F.adaptive_max_pool2d(x, output_size) 27 | return 0.5 * (x_avg + x_max) 28 | 29 | 30 | def adaptive_catavgmax_pool2d(x, output_size=1): 31 | x_avg = F.adaptive_avg_pool2d(x, output_size) 32 | x_max = F.adaptive_max_pool2d(x, output_size) 33 | return torch.cat((x_avg, x_max), 1) 34 | 35 | 36 | def select_adaptive_pool2d(x, pool_type='avg', output_size=1): 37 | """Selectable global pooling function with dynamic input kernel size 38 | """ 39 | if pool_type == 'avg': 40 | x = F.adaptive_avg_pool2d(x, output_size) 41 | elif pool_type == 'avgmax': 42 | x = adaptive_avgmax_pool2d(x, output_size) 43 | elif pool_type == 'catavgmax': 44 | x = adaptive_catavgmax_pool2d(x, output_size) 45 | elif pool_type == 'max': 46 | x = F.adaptive_max_pool2d(x, output_size) 47 | else: 48 | assert False, 'Invalid pool type: %s' % pool_type 49 | return x 50 | 51 | 52 | class FastAdaptiveAvgPool2d(nn.Module): 53 | def __init__(self, flatten=False): 54 | super(FastAdaptiveAvgPool2d, self).__init__() 55 | self.flatten = flatten 56 | 57 | def forward(self, x): 58 | return x.mean((2, 3), keepdim=not self.flatten) 59 | 60 | 61 | class AdaptiveAvgMaxPool2d(nn.Module): 62 | def __init__(self, output_size=1): 63 | super(AdaptiveAvgMaxPool2d, self).__init__() 64 | self.output_size = output_size 65 | 66 | def forward(self, x): 67 | return adaptive_avgmax_pool2d(x, self.output_size) 68 | 69 | 70 | class AdaptiveCatAvgMaxPool2d(nn.Module): 71 | def __init__(self, output_size=1): 72 | super(AdaptiveCatAvgMaxPool2d, self).__init__() 73 | self.output_size = output_size 74 | 75 | def forward(self, x): 76 | return adaptive_catavgmax_pool2d(x, self.output_size) 77 | 78 | 79 | class SelectAdaptivePool2d(nn.Module): 80 | """Selectable global pooling layer with dynamic input kernel size 81 | """ 82 | def __init__(self, output_size=1, pool_type='fast', flatten=False): 83 | super(SelectAdaptivePool2d, self).__init__() 84 | self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing 85 | self.flatten = nn.Flatten(1) if flatten else nn.Identity() 86 | if pool_type == '': 87 | self.pool = nn.Identity() # pass through 88 | elif pool_type == 'fast': 89 | assert output_size == 1 90 | self.pool = FastAdaptiveAvgPool2d(flatten) 91 | self.flatten = nn.Identity() 92 | elif pool_type == 'avg': 93 | self.pool = nn.AdaptiveAvgPool2d(output_size) 94 | elif pool_type == 'avgmax': 95 | self.pool = AdaptiveAvgMaxPool2d(output_size) 96 | elif pool_type == 'catavgmax': 97 | self.pool = AdaptiveCatAvgMaxPool2d(output_size) 98 | elif pool_type == 'max': 99 | self.pool = nn.AdaptiveMaxPool2d(output_size) 100 | else: 101 | assert False, 'Invalid pool type: %s' % pool_type 102 | 103 | def is_identity(self): 104 | return not self.pool_type 105 | 106 | def forward(self, x): 107 | x = self.pool(x) 108 | x = self.flatten(x) 109 | return x 110 | 111 | def feat_mult(self): 112 | return adaptive_pool_feat_mult(self.pool_type) 113 | 114 | def __repr__(self): 115 | return self.__class__.__name__ + ' (' \ 116 | + 'pool_type=' + self.pool_type \ 117 | + ', flatten=' + str(self.flatten) + ')' 118 | 119 | 
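# --- Illustrative usage sketch (added; not part of the original timm file) ---
if __name__ == '__main__':
    # 'catavgmax' concatenates avg- and max-pooled features, doubling the channel
    # dim, which is why feat_mult() reports 2 for this pool type.
    _x = torch.randn(2, 64, 7, 7)
    _pool = SelectAdaptivePool2d(pool_type='catavgmax', flatten=True)
    _y = _pool(_x)
    assert _y.shape == (2, 64 * _pool.feat_mult())  # -> (2, 128)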
-------------------------------------------------------------------------------- /models/layers/maxvit/layers/blur_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | BlurPool layer inspired by 3 | - Kornia's Max_BlurPool2d 4 | - Making Convolutional Networks Shift-Invariant Again :cite:`zhang2019shiftinvar` 5 | 6 | Hacked together by Chris Ha and Ross Wightman 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from .padding import get_padding 14 | 15 | 16 | class BlurPool2d(nn.Module): 17 | r"""Creates a module that computes blurs and downsample a given feature map. 18 | See :cite:`zhang2019shiftinvar` for more details. 19 | Corresponds to the Downsample class, which does blurring and subsampling 20 | 21 | Args: 22 | channels = Number of input channels 23 | filt_size (int): binomial filter size for blurring. currently supports 3 (default) and 5. 24 | stride (int): downsampling filter stride 25 | 26 | Returns: 27 | torch.Tensor: the transformed tensor. 28 | """ 29 | def __init__(self, channels, filt_size=3, stride=2) -> None: 30 | super(BlurPool2d, self).__init__() 31 | assert filt_size > 1 32 | self.channels = channels 33 | self.filt_size = filt_size 34 | self.stride = stride 35 | self.padding = [get_padding(filt_size, stride, dilation=1)] * 4 36 | coeffs = torch.tensor((np.poly1d((0.5, 0.5)) ** (self.filt_size - 1)).coeffs.astype(np.float32)) 37 | blur_filter = (coeffs[:, None] * coeffs[None, :])[None, None, :, :].repeat(self.channels, 1, 1, 1) 38 | self.register_buffer('filt', blur_filter, persistent=False) 39 | 40 | def forward(self, x: torch.Tensor) -> torch.Tensor: 41 | x = F.pad(x, self.padding, 'reflect') 42 | return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) 43 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/cbam.py: -------------------------------------------------------------------------------- 1 | """ CBAM (sort-of) Attention 2 | 3 | Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521 4 | 5 | WARNING: Results with these attention layers have been mixed. They can significantly reduce performance on 6 | some tasks, especially fine-grained it seems. I may end up removing this impl. 7 | 8 | Hacked together by / Copyright 2020 Ross Wightman 9 | """ 10 | import torch 11 | from torch import nn as nn 12 | import torch.nn.functional as F 13 | 14 | from .conv_bn_act import ConvNormAct 15 | from .create_act import create_act_layer, get_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class ChannelAttn(nn.Module): 20 | """ Original CBAM channel attention module, currently avg + max pool variant only. 21 | """ 22 | def __init__( 23 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 24 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 25 | super(ChannelAttn, self).__init__() 26 | if not rd_channels: 27 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
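        # Added note: fc1/fc2 form a channel-bottleneck MLP (e.g. 256 -> 16 -> 256 for
        # rd_ratio=1/16) that is shared by the avg- and max-pooled descriptors in forward().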
28 | self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) 29 | self.act = act_layer(inplace=True) 30 | self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) 31 | self.gate = create_act_layer(gate_layer) 32 | 33 | def forward(self, x): 34 | x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) 35 | x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) 36 | return x * self.gate(x_avg + x_max) 37 | 38 | 39 | class LightChannelAttn(ChannelAttn): 40 | """An experimental 'lightweight' that sums avg + max pool first 41 | """ 42 | def __init__( 43 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 44 | act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 45 | super(LightChannelAttn, self).__init__( 46 | channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) 47 | 48 | def forward(self, x): 49 | x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) 50 | x_attn = self.fc2(self.act(self.fc1(x_pool))) 51 | return x * F.sigmoid(x_attn) 52 | 53 | 54 | class SpatialAttn(nn.Module): 55 | """ Original CBAM spatial attention module 56 | """ 57 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 58 | super(SpatialAttn, self).__init__() 59 | self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) 60 | self.gate = create_act_layer(gate_layer) 61 | 62 | def forward(self, x): 63 | x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) 64 | x_attn = self.conv(x_attn) 65 | return x * self.gate(x_attn) 66 | 67 | 68 | class LightSpatialAttn(nn.Module): 69 | """An experimental 'lightweight' variant that sums avg_pool and max_pool results. 70 | """ 71 | def __init__(self, kernel_size=7, gate_layer='sigmoid'): 72 | super(LightSpatialAttn, self).__init__() 73 | self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) 74 | self.gate = create_act_layer(gate_layer) 75 | 76 | def forward(self, x): 77 | x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) 78 | x_attn = self.conv(x_attn) 79 | return x * self.gate(x_attn) 80 | 81 | 82 | class CbamModule(nn.Module): 83 | def __init__( 84 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 85 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 86 | super(CbamModule, self).__init__() 87 | self.channel = ChannelAttn( 88 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 89 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 90 | self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) 91 | 92 | def forward(self, x): 93 | x = self.channel(x) 94 | x = self.spatial(x) 95 | return x 96 | 97 | 98 | class LightCbamModule(nn.Module): 99 | def __init__( 100 | self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, 101 | spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): 102 | super(LightCbamModule, self).__init__() 103 | self.channel = LightChannelAttn( 104 | channels, rd_ratio=rd_ratio, rd_channels=rd_channels, 105 | rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) 106 | self.spatial = LightSpatialAttn(spatial_kernel_size) 107 | 108 | def forward(self, x): 109 | x = self.channel(x) 110 | x = self.spatial(x) 111 | return x 112 | 113 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/classifier.py: -------------------------------------------------------------------------------- 
1 | """ Classifier head and layer factory 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | 8 | from .adaptive_avgmax_pool import SelectAdaptivePool2d 9 | 10 | 11 | def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): 12 | flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling 13 | if not pool_type: 14 | assert num_classes == 0 or use_conv,\ 15 | 'Pooling can only be disabled if classifier is also removed or conv classifier is used' 16 | flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling) 17 | global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool) 18 | num_pooled_features = num_features * global_pool.feat_mult() 19 | return global_pool, num_pooled_features 20 | 21 | 22 | def _create_fc(num_features, num_classes, use_conv=False): 23 | if num_classes <= 0: 24 | fc = nn.Identity() # pass-through (no classifier) 25 | elif use_conv: 26 | fc = nn.Conv2d(num_features, num_classes, 1, bias=True) 27 | else: 28 | fc = nn.Linear(num_features, num_classes, bias=True) 29 | return fc 30 | 31 | 32 | def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False): 33 | global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv) 34 | fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 35 | return global_pool, fc 36 | 37 | 38 | class ClassifierHead(nn.Module): 39 | """Classifier head w/ configurable global pooling and dropout.""" 40 | 41 | def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): 42 | super(ClassifierHead, self).__init__() 43 | self.drop_rate = drop_rate 44 | self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) 45 | self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) 46 | self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() 47 | 48 | def forward(self, x, pre_logits: bool = False): 49 | x = self.global_pool(x) 50 | if self.drop_rate: 51 | x = F.dropout(x, p=float(self.drop_rate), training=self.training) 52 | if pre_logits: 53 | return x.flatten(1) 54 | else: 55 | x = self.fc(x) 56 | return self.flatten(x) 57 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/config.py: -------------------------------------------------------------------------------- 1 | """ Model / Layer Config singleton state 2 | """ 3 | from typing import Any, Optional 4 | 5 | __all__ = [ 6 | 'is_exportable', 'is_scriptable', 'is_no_jit', 7 | 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' 8 | ] 9 | 10 | # Set to True if prefer to have layers with no jit optimization (includes activations) 11 | _NO_JIT = False 12 | 13 | # Set to True if prefer to have activation layers with no jit optimization 14 | # NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying 15 | # the jit flags so far are activations. This will change as more layers are updated and/or added. 
16 | _NO_ACTIVATION_JIT = False 17 | 18 | # Set to True if exporting a model with Same padding via ONNX 19 | _EXPORTABLE = False 20 | 21 | # Set to True if wanting to use torch.jit.script on a model 22 | _SCRIPTABLE = False 23 | 24 | 25 | def is_no_jit(): 26 | return _NO_JIT 27 | 28 | 29 | class set_no_jit: 30 | def __init__(self, mode: bool) -> None: 31 | global _NO_JIT 32 | self.prev = _NO_JIT 33 | _NO_JIT = mode 34 | 35 | def __enter__(self) -> None: 36 | pass 37 | 38 | def __exit__(self, *args: Any) -> bool: 39 | global _NO_JIT 40 | _NO_JIT = self.prev 41 | return False 42 | 43 | 44 | def is_exportable(): 45 | return _EXPORTABLE 46 | 47 | 48 | class set_exportable: 49 | def __init__(self, mode: bool) -> None: 50 | global _EXPORTABLE 51 | self.prev = _EXPORTABLE 52 | _EXPORTABLE = mode 53 | 54 | def __enter__(self) -> None: 55 | pass 56 | 57 | def __exit__(self, *args: Any) -> bool: 58 | global _EXPORTABLE 59 | _EXPORTABLE = self.prev 60 | return False 61 | 62 | 63 | def is_scriptable(): 64 | return _SCRIPTABLE 65 | 66 | 67 | class set_scriptable: 68 | def __init__(self, mode: bool) -> None: 69 | global _SCRIPTABLE 70 | self.prev = _SCRIPTABLE 71 | _SCRIPTABLE = mode 72 | 73 | def __enter__(self) -> None: 74 | pass 75 | 76 | def __exit__(self, *args: Any) -> bool: 77 | global _SCRIPTABLE 78 | _SCRIPTABLE = self.prev 79 | return False 80 | 81 | 82 | class set_layer_config: 83 | """ Layer config context manager that allows setting all layer config flags at once. 84 | If a flag arg is None, it will not change the current value. 85 | """ 86 | def __init__( 87 | self, 88 | scriptable: Optional[bool] = None, 89 | exportable: Optional[bool] = None, 90 | no_jit: Optional[bool] = None, 91 | no_activation_jit: Optional[bool] = None): 92 | global _SCRIPTABLE 93 | global _EXPORTABLE 94 | global _NO_JIT 95 | global _NO_ACTIVATION_JIT 96 | self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT 97 | if scriptable is not None: 98 | _SCRIPTABLE = scriptable 99 | if exportable is not None: 100 | _EXPORTABLE = exportable 101 | if no_jit is not None: 102 | _NO_JIT = no_jit 103 | if no_activation_jit is not None: 104 | _NO_ACTIVATION_JIT = no_activation_jit 105 | 106 | def __enter__(self) -> None: 107 | pass 108 | 109 | def __exit__(self, *args: Any) -> bool: 110 | global _SCRIPTABLE 111 | global _EXPORTABLE 112 | global _NO_JIT 113 | global _NO_ACTIVATION_JIT 114 | _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev 115 | return False 116 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/conv2d_same.py: -------------------------------------------------------------------------------- 1 | """ Conv2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import Tuple, Optional 9 | 10 | from .padding import pad_same, get_padding_value 11 | 12 | 13 | def conv2d_same( 14 | x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), 15 | padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): 16 | x = pad_same(x, weight.shape[-2:], stride, dilation) 17 | return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) 18 | 19 | 20 | class Conv2dSame(nn.Conv2d): 21 | """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions 22 | """ 23 | 24 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 25 | padding=0, 
dilation=1, groups=1, bias=True): 26 | super(Conv2dSame, self).__init__( 27 | in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) 28 | 29 | def forward(self, x): 30 | return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 31 | 32 | 33 | def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): 34 | padding = kwargs.pop('padding', '') 35 | kwargs.setdefault('bias', False) 36 | padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) 37 | if is_dynamic: 38 | return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) 39 | else: 40 | return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) 41 | 42 | 43 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/conv_bn_act.py: -------------------------------------------------------------------------------- 1 | """ Conv2d + BN + Act 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import functools 6 | from torch import nn as nn 7 | 8 | from .create_conv2d import create_conv2d 9 | from .create_norm_act import get_norm_act_layer 10 | 11 | 12 | class ConvNormAct(nn.Module): 13 | def __init__( 14 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 15 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): 16 | super(ConvNormAct, self).__init__() 17 | self.conv = create_conv2d( 18 | in_channels, out_channels, kernel_size, stride=stride, 19 | padding=padding, dilation=dilation, groups=groups, bias=bias) 20 | 21 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 22 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 23 | # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` 24 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 25 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 26 | 27 | @property 28 | def in_channels(self): 29 | return self.conv.in_channels 30 | 31 | @property 32 | def out_channels(self): 33 | return self.conv.out_channels 34 | 35 | def forward(self, x): 36 | x = self.conv(x) 37 | x = self.bn(x) 38 | return x 39 | 40 | 41 | ConvBnAct = ConvNormAct 42 | 43 | 44 | def create_aa(aa_layer, channels, stride=2, enable=True): 45 | if not aa_layer or not enable: 46 | return nn.Identity() 47 | if isinstance(aa_layer, functools.partial): 48 | if issubclass(aa_layer.func, nn.AvgPool2d): 49 | return aa_layer() 50 | else: 51 | return aa_layer(channels) 52 | elif issubclass(aa_layer, nn.AvgPool2d): 53 | return aa_layer(stride) 54 | else: 55 | return aa_layer(channels=channels, stride=stride) 56 | 57 | 58 | class ConvNormActAa(nn.Module): 59 | def __init__( 60 | self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, 61 | bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): 62 | super(ConvNormActAa, self).__init__() 63 | use_aa = aa_layer is not None and stride == 2 64 | 65 | self.conv = create_conv2d( 66 | in_channels, out_channels, kernel_size, stride=1 if use_aa else stride, 67 | padding=padding, dilation=dilation, groups=groups, bias=bias) 68 | 69 | # NOTE for backwards compatibility with models that use separate norm and act layer definitions 70 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 71 | # NOTE for backwards (weight) compatibility, norm layer name 
remains `.bn` 72 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 73 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 74 | self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa) 75 | 76 | @property 77 | def in_channels(self): 78 | return self.conv.in_channels 79 | 80 | @property 81 | def out_channels(self): 82 | return self.conv.out_channels 83 | 84 | def forward(self, x): 85 | x = self.conv(x) 86 | x = self.bn(x) 87 | x = self.aa(x) 88 | return x 89 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/create_attn.py: -------------------------------------------------------------------------------- 1 | """ Attention Factory 2 | 3 | Hacked together by / Copyright 2021 Ross Wightman 4 | """ 5 | import torch 6 | from functools import partial 7 | 8 | from .bottleneck_attn import BottleneckAttn 9 | from .cbam import CbamModule, LightCbamModule 10 | from .eca import EcaModule, CecaModule 11 | from .gather_excite import GatherExcite 12 | from .global_context import GlobalContext 13 | from .halo_attn import HaloAttn 14 | from .lambda_layer import LambdaLayer 15 | from .non_local_attn import NonLocalAttn, BatNonLocalAttn 16 | from .selective_kernel import SelectiveKernel 17 | from .split_attn import SplitAttn 18 | from .squeeze_excite import SEModule, EffectiveSEModule 19 | 20 | 21 | def get_attn(attn_type): 22 | if isinstance(attn_type, torch.nn.Module): 23 | return attn_type 24 | module_cls = None 25 | if attn_type: 26 | if isinstance(attn_type, str): 27 | attn_type = attn_type.lower() 28 | # Lightweight attention modules (channel and/or coarse spatial). 29 | # Typically added to existing network architecture blocks in addition to existing convolutions. 30 | if attn_type == 'se': 31 | module_cls = SEModule 32 | elif attn_type == 'ese': 33 | module_cls = EffectiveSEModule 34 | elif attn_type == 'eca': 35 | module_cls = EcaModule 36 | elif attn_type == 'ecam': 37 | module_cls = partial(EcaModule, use_mlp=True) 38 | elif attn_type == 'ceca': 39 | module_cls = CecaModule 40 | elif attn_type == 'ge': 41 | module_cls = GatherExcite 42 | elif attn_type == 'gc': 43 | module_cls = GlobalContext 44 | elif attn_type == 'gca': 45 | module_cls = partial(GlobalContext, fuse_add=True, fuse_scale=False) 46 | elif attn_type == 'cbam': 47 | module_cls = CbamModule 48 | elif attn_type == 'lcbam': 49 | module_cls = LightCbamModule 50 | 51 | # Attention / attention-like modules w/ significant params 52 | # Typically replace some of the existing workhorse convs in a network architecture. 53 | # All of these accept a stride argument and can spatially downsample the input. 54 | elif attn_type == 'sk': 55 | module_cls = SelectiveKernel 56 | elif attn_type == 'splat': 57 | module_cls = SplitAttn 58 | 59 | # Self-attention / attention-like modules w/ significant compute and/or params 60 | # Typically replace some of the existing workhorse convs in a network architecture. 61 | # All of these accept a stride argument and can spatially downsample the input. 62 | elif attn_type == 'lambda': 63 | return LambdaLayer 64 | elif attn_type == 'bottleneck': 65 | return BottleneckAttn 66 | elif attn_type == 'halo': 67 | return HaloAttn 68 | elif attn_type == 'nl': 69 | module_cls = NonLocalAttn 70 | elif attn_type == 'bat': 71 | module_cls = BatNonLocalAttn 72 | 73 | # Woops! 
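# (editorial note) Any attn_type string not matched above falls through to the assert in the
# else branch below, so a typo in the attention name fails loudly instead of silently
# disabling attention.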
74 | else: 75 | assert False, "Invalid attn module (%s)" % attn_type 76 | elif isinstance(attn_type, bool): 77 | if attn_type: 78 | module_cls = SEModule 79 | else: 80 | module_cls = attn_type 81 | return module_cls 82 | 83 | 84 | def create_attn(attn_type, channels, **kwargs): 85 | module_cls = get_attn(attn_type) 86 | if module_cls is not None: 87 | # NOTE: it's expected the first (positional) argument of all attention layers is the # input channels 88 | return module_cls(channels, **kwargs) 89 | return None 90 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/create_conv2d.py: -------------------------------------------------------------------------------- 1 | """ Create Conv2d Factory Method 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | from .mixed_conv2d import MixedConv2d 7 | from .cond_conv2d import CondConv2d 8 | from .conv2d_same import create_conv2d_pad 9 | 10 | 11 | def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): 12 | """ Select a 2d convolution implementation based on arguments 13 | Creates and returns one of torch.nn.Conv2d, Conv2dSame, MixedConv2d, or CondConv2d. 14 | 15 | Used extensively by EfficientNet, MobileNetv3 and related networks. 16 | """ 17 | if isinstance(kernel_size, list): 18 | assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently 19 | if 'groups' in kwargs: 20 | groups = kwargs.pop('groups') 21 | if groups == in_channels: 22 | kwargs['depthwise'] = True 23 | else: 24 | assert groups == 1 25 | # We're going to use only lists for defining the MixedConv2d kernel groups, 26 | # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 27 | m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) 28 | else: 29 | depthwise = kwargs.pop('depthwise', False) 30 | # for DW out_channels must be multiple of in_channels as must have out_channels % groups == 0 31 | groups = in_channels if depthwise else kwargs.pop('groups', 1) 32 | if 'num_experts' in kwargs and kwargs['num_experts'] > 0: 33 | m = CondConv2d(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 34 | else: 35 | m = create_conv2d_pad(in_channels, out_channels, kernel_size, groups=groups, **kwargs) 36 | return m 37 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/create_norm.py: -------------------------------------------------------------------------------- 1 | """ Norm Layer Factory 2 | 3 | Create norm modules by string (to mirror create_act and creat_norm-act fns) 4 | 5 | Copyright 2022 Ross Wightman 6 | """ 7 | import types 8 | import functools 9 | 10 | import torch.nn as nn 11 | 12 | from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d 13 | 14 | _NORM_MAP = dict( 15 | batchnorm=nn.BatchNorm2d, 16 | batchnorm2d=nn.BatchNorm2d, 17 | batchnorm1d=nn.BatchNorm1d, 18 | groupnorm=GroupNorm, 19 | groupnorm1=GroupNorm1, 20 | layernorm=LayerNorm, 21 | layernorm2d=LayerNorm2d, 22 | ) 23 | _NORM_TYPES = {m for n, m in _NORM_MAP.items()} 24 | 25 | 26 | def create_norm_layer(layer_name, num_features, act_layer=None, apply_act=True, **kwargs): 27 | layer = get_norm_layer(layer_name, act_layer=act_layer) 28 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 29 | return layer_instance 30 | 31 | 32 | def get_norm_layer(norm_layer): 33 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 34 | norm_kwargs = {} 35 | 36 
| # unbind partial fn, so args can be rebound later 37 | if isinstance(norm_layer, functools.partial): 38 | norm_kwargs.update(norm_layer.keywords) 39 | norm_layer = norm_layer.func 40 | 41 | if isinstance(norm_layer, str): 42 | layer_name = norm_layer.replace('_', '') 43 | norm_layer = _NORM_MAP.get(layer_name, None) 44 | elif norm_layer in _NORM_TYPES: 45 | norm_layer = norm_layer 46 | elif isinstance(norm_layer, types.FunctionType): 47 | # if function type, assume it is a lambda/fn that creates a norm layer 48 | norm_layer = norm_layer 49 | else: 50 | type_name = norm_layer.__name__.lower().replace('_', '') 51 | norm_layer = _NORM_MAP.get(type_name, None) 52 | assert norm_layer is not None, f"No equivalent norm layer for {type_name}" 53 | 54 | if norm_kwargs: 55 | norm_layer = functools.partial(norm_layer, **norm_kwargs) # bind/rebind args 56 | return norm_layer 57 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/create_norm_act.py: -------------------------------------------------------------------------------- 1 | """ NormAct (Normalizaiton + Activation Layer) Factory 2 | 3 | Create norm + act combo modules that attempt to be backwards compatible with separate norm + act 4 | isntances in models. Where these are used it will be possible to swap separate BN + act layers with 5 | combined modules like IABN or EvoNorms. 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | import types 10 | import functools 11 | 12 | from .evo_norm import * 13 | from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d 14 | from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d 15 | from .inplace_abn import InplaceAbn 16 | 17 | _NORM_ACT_MAP = dict( 18 | batchnorm=BatchNormAct2d, 19 | batchnorm2d=BatchNormAct2d, 20 | groupnorm=GroupNormAct, 21 | groupnorm1=functools.partial(GroupNormAct, num_groups=1), 22 | layernorm=LayerNormAct, 23 | layernorm2d=LayerNormAct2d, 24 | evonormb0=EvoNorm2dB0, 25 | evonormb1=EvoNorm2dB1, 26 | evonormb2=EvoNorm2dB2, 27 | evonorms0=EvoNorm2dS0, 28 | evonorms0a=EvoNorm2dS0a, 29 | evonorms1=EvoNorm2dS1, 30 | evonorms1a=EvoNorm2dS1a, 31 | evonorms2=EvoNorm2dS2, 32 | evonorms2a=EvoNorm2dS2a, 33 | frn=FilterResponseNormAct2d, 34 | frntlu=FilterResponseNormTlu2d, 35 | inplaceabn=InplaceAbn, 36 | iabn=InplaceAbn, 37 | ) 38 | _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} 39 | # has act_layer arg to define act type 40 | _NORM_ACT_REQUIRES_ARG = { 41 | BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} 42 | 43 | 44 | def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): 45 | layer = get_norm_act_layer(layer_name, act_layer=act_layer) 46 | layer_instance = layer(num_features, apply_act=apply_act, **kwargs) 47 | if jit: 48 | layer_instance = torch.jit.script(layer_instance) 49 | return layer_instance 50 | 51 | 52 | def get_norm_act_layer(norm_layer, act_layer=None): 53 | assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) 54 | assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) 55 | norm_act_kwargs = {} 56 | 57 | # unbind partial fn, so args can be rebound later 58 | if isinstance(norm_layer, functools.partial): 59 | norm_act_kwargs.update(norm_layer.keywords) 60 | norm_layer = norm_layer.func 61 | 62 | if isinstance(norm_layer, str): 63 | layer_name = norm_layer.replace('_', 
'').lower().split('-')[0] 64 | norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) 65 | elif norm_layer in _NORM_ACT_TYPES: 66 | norm_act_layer = norm_layer 67 | elif isinstance(norm_layer, types.FunctionType): 68 | # if function type, must be a lambda/fn that creates a norm_act layer 69 | norm_act_layer = norm_layer 70 | else: 71 | type_name = norm_layer.__name__.lower() 72 | if type_name.startswith('batchnorm'): 73 | norm_act_layer = BatchNormAct2d 74 | elif type_name.startswith('groupnorm'): 75 | norm_act_layer = GroupNormAct 76 | elif type_name.startswith('groupnorm1'): 77 | norm_act_layer = functools.partial(GroupNormAct, num_groups=1) 78 | elif type_name.startswith('layernorm2d'): 79 | norm_act_layer = LayerNormAct2d 80 | elif type_name.startswith('layernorm'): 81 | norm_act_layer = LayerNormAct 82 | else: 83 | assert False, f"No equivalent norm_act layer for {type_name}" 84 | 85 | if norm_act_layer in _NORM_ACT_REQUIRES_ARG: 86 | # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation. 87 | # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types 88 | norm_act_kwargs.setdefault('act_layer', act_layer) 89 | if norm_act_kwargs: 90 | norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs) # bind/rebind args 91 | return norm_act_layer 92 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/fast_norm.py: -------------------------------------------------------------------------------- 1 | """ 'Fast' Normalization Functions 2 | 3 | For GroupNorm and LayerNorm these functions bypass typical AMP upcast to float32. 4 | 5 | Additionally, for LayerNorm, the APEX fused LN is used if available (which also does not upcast) 6 | 7 | Hacked together by / Copyright 2022 Ross Wightman 8 | """ 9 | from typing import List, Optional 10 | 11 | import torch 12 | from torch.nn import functional as F 13 | 14 | try: 15 | from apex.normalization.fused_layer_norm import fused_layer_norm_affine 16 | has_apex = True 17 | except ImportError: 18 | has_apex = False 19 | 20 | 21 | # fast (ie lower precision LN) can be disabled with this flag if issues crop up 22 | _USE_FAST_NORM = False # defaulting to False for now 23 | 24 | 25 | def is_fast_norm(): 26 | return _USE_FAST_NORM 27 | 28 | 29 | def set_fast_norm(enable=True): 30 | global _USE_FAST_NORM 31 | _USE_FAST_NORM = enable 32 | 33 | 34 | def fast_group_norm( 35 | x: torch.Tensor, 36 | num_groups: int, 37 | weight: Optional[torch.Tensor] = None, 38 | bias: Optional[torch.Tensor] = None, 39 | eps: float = 1e-5 40 | ) -> torch.Tensor: 41 | if torch.jit.is_scripting(): 42 | # currently cannot use is_autocast_enabled within torchscript 43 | return F.group_norm(x, num_groups, weight, bias, eps) 44 | 45 | if torch.is_autocast_enabled(): 46 | # normally native AMP casts GN inputs to float32 47 | # here we use the low precision autocast dtype 48 | # FIXME what to do re CPU autocast? 
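# (editorial note) After casting x/weight/bias to the autocast dtype below, the F.group_norm
# call is wrapped in torch.cuda.amp.autocast(enabled=False) so autocast does not upcast the
# already-downcast inputs back to float32 inside the op.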
49 | dt = torch.get_autocast_gpu_dtype() 50 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 51 | 52 | with torch.cuda.amp.autocast(enabled=False): 53 | return F.group_norm(x, num_groups, weight, bias, eps) 54 | 55 | 56 | def fast_layer_norm( 57 | x: torch.Tensor, 58 | normalized_shape: List[int], 59 | weight: Optional[torch.Tensor] = None, 60 | bias: Optional[torch.Tensor] = None, 61 | eps: float = 1e-5 62 | ) -> torch.Tensor: 63 | if torch.jit.is_scripting(): 64 | # currently cannot use is_autocast_enabled within torchscript 65 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 66 | 67 | if has_apex: 68 | return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps) 69 | 70 | if torch.is_autocast_enabled(): 71 | # normally native AMP casts LN inputs to float32 72 | # apex LN does not, this is behaving like Apex 73 | dt = torch.get_autocast_gpu_dtype() 74 | # FIXME what to do re CPU autocast? 75 | x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) 76 | 77 | with torch.cuda.amp.autocast(enabled=False): 78 | return F.layer_norm(x, normalized_shape, weight, bias, eps) 79 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/filter_response_norm.py: -------------------------------------------------------------------------------- 1 | """ Filter Response Norm in PyTorch 2 | 3 | Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 4 | 5 | Hacked together by / Copyright 2021 Ross Wightman 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .create_act import create_act_layer 11 | from .trace_utils import _assert 12 | 13 | 14 | def inv_instance_rms(x, eps: float = 1e-5): 15 | rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) 16 | return rms.expand(x.shape) 17 | 18 | 19 | class FilterResponseNormTlu2d(nn.Module): 20 | def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): 21 | super(FilterResponseNormTlu2d, self).__init__() 22 | self.apply_act = apply_act # apply activation (non-linearity) 23 | self.rms = rms 24 | self.eps = eps 25 | self.weight = nn.Parameter(torch.ones(num_features)) 26 | self.bias = nn.Parameter(torch.zeros(num_features)) 27 | self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None 28 | self.reset_parameters() 29 | 30 | def reset_parameters(self): 31 | nn.init.ones_(self.weight) 32 | nn.init.zeros_(self.bias) 33 | if self.tau is not None: 34 | nn.init.zeros_(self.tau) 35 | 36 | def forward(self, x): 37 | _assert(x.dim() == 4, 'expected 4D input') 38 | x_dtype = x.dtype 39 | v_shape = (1, -1, 1, 1) 40 | x = x * inv_instance_rms(x, self.eps) 41 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 42 | return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x 43 | 44 | 45 | class FilterResponseNormAct2d(nn.Module): 46 | def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): 47 | super(FilterResponseNormAct2d, self).__init__() 48 | if act_layer is not None and apply_act: 49 | self.act = create_act_layer(act_layer, inplace=inplace) 50 | else: 51 | self.act = nn.Identity() 52 | self.rms = rms 53 | self.eps = eps 54 | self.weight = nn.Parameter(torch.ones(num_features)) 55 | self.bias = nn.Parameter(torch.zeros(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.ones_(self.weight) 60 | 
nn.init.zeros_(self.bias) 61 | 62 | def forward(self, x): 63 | _assert(x.dim() == 4, 'expected 4D input') 64 | x_dtype = x.dtype 65 | v_shape = (1, -1, 1, 1) 66 | x = x * inv_instance_rms(x, self.eps) 67 | x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) 68 | return self.act(x) 69 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/gather_excite.py: -------------------------------------------------------------------------------- 1 | """ Gather-Excite Attention Block 2 | 3 | Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 4 | 5 | Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet 6 | 7 | I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another 8 | impl that covers all of the cases. 9 | 10 | NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation 11 | 12 | Hacked together by / Copyright 2021 Ross Wightman 13 | """ 14 | import math 15 | 16 | from torch import nn as nn 17 | import torch.nn.functional as F 18 | 19 | from .create_act import create_act_layer, get_act_layer 20 | from .create_conv2d import create_conv2d 21 | from .helpers import make_divisible 22 | from .mlp import ConvMlp 23 | 24 | 25 | class GatherExcite(nn.Module): 26 | """ Gather-Excite Attention Module 27 | """ 28 | def __init__( 29 | self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, 30 | rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, 31 | act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): 32 | super(GatherExcite, self).__init__() 33 | self.add_maxpool = add_maxpool 34 | act_layer = get_act_layer(act_layer) 35 | self.extent = extent 36 | if extra_params: 37 | self.gather = nn.Sequential() 38 | if extent == 0: 39 | assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' 40 | self.gather.add_module( 41 | 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) 42 | if norm_layer: 43 | self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) 44 | else: 45 | assert extent % 2 == 0 46 | num_conv = int(math.log2(extent)) 47 | for i in range(num_conv): 48 | self.gather.add_module( 49 | f'conv{i + 1}', 50 | create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) 51 | if norm_layer: 52 | self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) 53 | if i != num_conv - 1: 54 | self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) 55 | else: 56 | self.gather = None 57 | if self.extent == 0: 58 | self.gk = 0 59 | self.gs = 0 60 | else: 61 | assert extent % 2 == 0 62 | self.gk = self.extent * 2 - 1 63 | self.gs = self.extent 64 | 65 | if not rd_channels: 66 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
67 | self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() 68 | self.gate = create_act_layer(gate_layer) 69 | 70 | def forward(self, x): 71 | size = x.shape[-2:] 72 | if self.gather is not None: 73 | x_ge = self.gather(x) 74 | else: 75 | if self.extent == 0: 76 | # global extent 77 | x_ge = x.mean(dim=(2, 3), keepdims=True) 78 | if self.add_maxpool: 79 | # experimental codepath, may remove or change 80 | x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) 81 | else: 82 | x_ge = F.avg_pool2d( 83 | x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) 84 | if self.add_maxpool: 85 | # experimental codepath, may remove or change 86 | x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) 87 | x_ge = self.mlp(x_ge) 88 | if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: 89 | x_ge = F.interpolate(x_ge, size=size) 90 | return x * self.gate(x_ge) 91 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/global_context.py: -------------------------------------------------------------------------------- 1 | """ Global Context Attention Block 2 | 3 | Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` 4 | - https://arxiv.org/abs/1904.11492 5 | 6 | Official code consulted as reference: https://github.com/xvjiarui/GCNet 7 | 8 | Hacked together by / Copyright 2021 Ross Wightman 9 | """ 10 | from torch import nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .create_act import create_act_layer, get_act_layer 14 | from .helpers import make_divisible 15 | from .mlp import ConvMlp 16 | from .norm import LayerNorm2d 17 | 18 | 19 | class GlobalContext(nn.Module): 20 | 21 | def __init__(self, channels, use_attn=True, fuse_add=False, fuse_scale=True, init_last_zero=False, 22 | rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): 23 | super(GlobalContext, self).__init__() 24 | act_layer = get_act_layer(act_layer) 25 | 26 | self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None 27 | 28 | if rd_channels is None: 29 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
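# (editorial note) With the defaults above (rd_ratio=1./8), rd_channels is roughly channels / 8
# rounded by make_divisible; the ConvMlp blocks below squeeze the pooled context to rd_channels
# and project it back to channels before the add/scale fusion in forward().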
30 | if fuse_add: 31 | self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 32 | else: 33 | self.mlp_add = None 34 | if fuse_scale: 35 | self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) 36 | else: 37 | self.mlp_scale = None 38 | 39 | self.gate = create_act_layer(gate_layer) 40 | self.init_last_zero = init_last_zero 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self): 44 | if self.conv_attn is not None: 45 | nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') 46 | if self.mlp_add is not None: 47 | nn.init.zeros_(self.mlp_add.fc2.weight) 48 | 49 | def forward(self, x): 50 | B, C, H, W = x.shape 51 | 52 | if self.conv_attn is not None: 53 | attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) 54 | attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) 55 | context = x.reshape(B, C, H * W).unsqueeze(1) @ attn 56 | context = context.view(B, C, 1, 1) 57 | else: 58 | context = x.mean(dim=(2, 3), keepdim=True) 59 | 60 | if self.mlp_scale is not None: 61 | mlp_x = self.mlp_scale(context) 62 | x = x * self.gate(mlp_x) 63 | if self.mlp_add is not None: 64 | mlp_x = self.mlp_add(context) 65 | x = x + mlp_x 66 | 67 | return x 68 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/helpers.py: -------------------------------------------------------------------------------- 1 | """ Layer/Module Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from itertools import repeat 6 | import collections.abc 7 | 8 | 9 | # From PyTorch internals 10 | def _ntuple(n): 11 | def parse(x): 12 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 13 | return x 14 | return tuple(repeat(x, n)) 15 | return parse 16 | 17 | 18 | to_1tuple = _ntuple(1) 19 | to_2tuple = _ntuple(2) 20 | to_3tuple = _ntuple(3) 21 | to_4tuple = _ntuple(4) 22 | to_ntuple = _ntuple 23 | 24 | 25 | def make_divisible(v, divisor=8, min_value=None, round_limit=.9): 26 | min_value = min_value or divisor 27 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 28 | # Make sure that round down does not go down by more than 10%. 
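# (editorial worked example, values assumed): make_divisible(30, divisor=8) gives
# int(30 + 4) // 8 * 8 = 32 and returns it directly, while make_divisible(19, divisor=8)
# first gives 16, which is below round_limit * v = 0.9 * 19 = 17.1, so one divisor is
# added back and 24 is returned.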
29 | if new_v < round_limit * v: 30 | new_v += divisor 31 | return new_v 32 | 33 | 34 | def extend_tuple(x, n): 35 | # pdas a tuple to specified n by padding with last value 36 | if not isinstance(x, (tuple, list)): 37 | x = (x,) 38 | else: 39 | x = tuple(x) 40 | pad_n = n - len(x) 41 | if pad_n <= 0: 42 | return x[:n] 43 | return x + (x[-1],) * pad_n 44 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/inplace_abn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | try: 5 | from inplace_abn.functions import inplace_abn, inplace_abn_sync 6 | has_iabn = True 7 | except ImportError: 8 | has_iabn = False 9 | 10 | def inplace_abn(x, weight, bias, running_mean, running_var, 11 | training=True, momentum=0.1, eps=1e-05, activation="leaky_relu", activation_param=0.01): 12 | raise ImportError( 13 | "Please install InplaceABN:'pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.12'") 14 | 15 | def inplace_abn_sync(**kwargs): 16 | inplace_abn(**kwargs) 17 | 18 | 19 | class InplaceAbn(nn.Module): 20 | """Activated Batch Normalization 21 | 22 | This gathers a BatchNorm and an activation function in a single module 23 | 24 | Parameters 25 | ---------- 26 | num_features : int 27 | Number of feature channels in the input and output. 28 | eps : float 29 | Small constant to prevent numerical issues. 30 | momentum : float 31 | Momentum factor applied to compute running statistics. 32 | affine : bool 33 | If `True` apply learned scale and shift transformation after normalization. 34 | act_layer : str or nn.Module type 35 | Name or type of the activation functions, one of: `leaky_relu`, `elu` 36 | act_param : float 37 | Negative slope for the `leaky_relu` activation. 
38 | """ 39 | 40 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, 41 | act_layer="leaky_relu", act_param=0.01, drop_layer=None): 42 | super(InplaceAbn, self).__init__() 43 | self.num_features = num_features 44 | self.affine = affine 45 | self.eps = eps 46 | self.momentum = momentum 47 | if apply_act: 48 | if isinstance(act_layer, str): 49 | assert act_layer in ('leaky_relu', 'elu', 'identity', '') 50 | self.act_name = act_layer if act_layer else 'identity' 51 | else: 52 | # convert act layer passed as type to string 53 | if act_layer == nn.ELU: 54 | self.act_name = 'elu' 55 | elif act_layer == nn.LeakyReLU: 56 | self.act_name = 'leaky_relu' 57 | elif act_layer is None or act_layer == nn.Identity: 58 | self.act_name = 'identity' 59 | else: 60 | assert False, f'Invalid act layer {act_layer.__name__} for IABN' 61 | else: 62 | self.act_name = 'identity' 63 | self.act_param = act_param 64 | if self.affine: 65 | self.weight = nn.Parameter(torch.ones(num_features)) 66 | self.bias = nn.Parameter(torch.zeros(num_features)) 67 | else: 68 | self.register_parameter('weight', None) 69 | self.register_parameter('bias', None) 70 | self.register_buffer('running_mean', torch.zeros(num_features)) 71 | self.register_buffer('running_var', torch.ones(num_features)) 72 | self.reset_parameters() 73 | 74 | def reset_parameters(self): 75 | nn.init.constant_(self.running_mean, 0) 76 | nn.init.constant_(self.running_var, 1) 77 | if self.affine: 78 | nn.init.constant_(self.weight, 1) 79 | nn.init.constant_(self.bias, 0) 80 | 81 | def forward(self, x): 82 | output = inplace_abn( 83 | x, self.weight, self.bias, self.running_mean, self.running_var, 84 | self.training, self.momentum, self.eps, self.act_name, self.act_param) 85 | if isinstance(output, tuple): 86 | output = output[0] 87 | return output 88 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear layer (alternate definition) 2 | """ 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn as nn 6 | 7 | 8 | class Linear(nn.Linear): 9 | r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b` 10 | 11 | Wraps torch.nn.Linear to support AMP + torchscript usage by manually casting 12 | weight & bias to input.dtype to work around an issue w/ torch.addmm in this use case. 13 | """ 14 | def forward(self, input: torch.Tensor) -> torch.Tensor: 15 | if torch.jit.is_scripting(): 16 | bias = self.bias.to(dtype=input.dtype) if self.bias is not None else None 17 | return F.linear(input, self.weight.to(dtype=input.dtype), bias=bias) 18 | else: 19 | return F.linear(input, self.weight, self.bias) 20 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/median_pool.py: -------------------------------------------------------------------------------- 1 | """ Median Pool 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from .helpers import to_2tuple, to_4tuple 7 | 8 | 9 | class MedianPool2d(nn.Module): 10 | """ Median pool (usable as median filter when stride=1) module. 
11 | 12 | Args: 13 | kernel_size: size of pooling kernel, int or 2-tuple 14 | stride: pool stride, int or 2-tuple 15 | padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad 16 | same: override padding and enforce same padding, boolean 17 | """ 18 | def __init__(self, kernel_size=3, stride=1, padding=0, same=False): 19 | super(MedianPool2d, self).__init__() 20 | self.k = to_2tuple(kernel_size) 21 | self.stride = to_2tuple(stride) 22 | self.padding = to_4tuple(padding) # convert to l, r, t, b 23 | self.same = same 24 | 25 | def _padding(self, x): 26 | if self.same: 27 | ih, iw = x.size()[2:] 28 | if ih % self.stride[0] == 0: 29 | ph = max(self.k[0] - self.stride[0], 0) 30 | else: 31 | ph = max(self.k[0] - (ih % self.stride[0]), 0) 32 | if iw % self.stride[1] == 0: 33 | pw = max(self.k[1] - self.stride[1], 0) 34 | else: 35 | pw = max(self.k[1] - (iw % self.stride[1]), 0) 36 | pl = pw // 2 37 | pr = pw - pl 38 | pt = ph // 2 39 | pb = ph - pt 40 | padding = (pl, pr, pt, pb) 41 | else: 42 | padding = self.padding 43 | return padding 44 | 45 | def forward(self, x): 46 | x = F.pad(x, self._padding(x), mode='reflect') 47 | x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) 48 | x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] 49 | return x 50 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/mixed_conv2d.py: -------------------------------------------------------------------------------- 1 | """ PyTorch Mixed Convolution 2 | 3 | Paper: MixConv: Mixed Depthwise Convolutional Kernels (https://arxiv.org/abs/1907.09595) 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import torch 9 | from torch import nn as nn 10 | 11 | from .conv2d_same import create_conv2d_pad 12 | 13 | 14 | def _split_channels(num_chan, num_groups): 15 | split = [num_chan // num_groups for _ in range(num_groups)] 16 | split[0] += num_chan - sum(split) 17 | return split 18 | 19 | 20 | class MixedConv2d(nn.ModuleDict): 21 | """ Mixed Grouped Convolution 22 | 23 | Based on MDConv and GroupedConv in MixNet impl: 24 | https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size=3, 27 | stride=1, padding='', dilation=1, depthwise=False, **kwargs): 28 | super(MixedConv2d, self).__init__() 29 | 30 | kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] 31 | num_groups = len(kernel_size) 32 | in_splits = _split_channels(in_channels, num_groups) 33 | out_splits = _split_channels(out_channels, num_groups) 34 | self.in_channels = sum(in_splits) 35 | self.out_channels = sum(out_splits) 36 | for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): 37 | conv_groups = in_ch if depthwise else 1 38 | # use add_module to keep key space clean 39 | self.add_module( 40 | str(idx), 41 | create_conv2d_pad( 42 | in_ch, out_ch, k, stride=stride, 43 | padding=padding, dilation=dilation, groups=conv_groups, **kwargs) 44 | ) 45 | self.splits = in_splits 46 | 47 | def forward(self, x): 48 | x_split = torch.split(x, self.splits, 1) 49 | x_out = [c(x_split[i]) for i, c in enumerate(self.values())] 50 | x = torch.cat(x_out, 1) 51 | return x 52 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/mlp.py: -------------------------------------------------------------------------------- 1 | """ MLP module w/ 
dropout and configurable activation layer 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | from torch import nn as nn 6 | 7 | from .helpers import to_2tuple 8 | 9 | 10 | class Mlp(nn.Module): 11 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 12 | """ 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | bias = to_2tuple(bias) 18 | drop_probs = to_2tuple(drop) 19 | 20 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 21 | self.act = act_layer() 22 | self.drop1 = nn.Dropout(drop_probs[0]) 23 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 24 | self.drop2 = nn.Dropout(drop_probs[1]) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop1(x) 30 | x = self.fc2(x) 31 | x = self.drop2(x) 32 | return x 33 | 34 | 35 | class GluMlp(nn.Module): 36 | """ MLP w/ GLU style gating 37 | See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 38 | """ 39 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): 40 | super().__init__() 41 | out_features = out_features or in_features 42 | hidden_features = hidden_features or in_features 43 | assert hidden_features % 2 == 0 44 | bias = to_2tuple(bias) 45 | drop_probs = to_2tuple(drop) 46 | 47 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 48 | self.act = act_layer() 49 | self.drop1 = nn.Dropout(drop_probs[0]) 50 | self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) 51 | self.drop2 = nn.Dropout(drop_probs[1]) 52 | 53 | def init_weights(self): 54 | # override init of fc1 w/ gate portion set to weight near zero, bias=1 55 | fc1_mid = self.fc1.bias.shape[0] // 2 56 | nn.init.ones_(self.fc1.bias[fc1_mid:]) 57 | nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) 58 | 59 | def forward(self, x): 60 | x = self.fc1(x) 61 | x, gates = x.chunk(2, dim=-1) 62 | x = x * self.act(gates) 63 | x = self.drop1(x) 64 | x = self.fc2(x) 65 | x = self.drop2(x) 66 | return x 67 | 68 | 69 | class GatedMlp(nn.Module): 70 | """ MLP as used in gMLP 71 | """ 72 | def __init__( 73 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 74 | gate_layer=None, bias=True, drop=0.): 75 | super().__init__() 76 | out_features = out_features or in_features 77 | hidden_features = hidden_features or in_features 78 | bias = to_2tuple(bias) 79 | drop_probs = to_2tuple(drop) 80 | 81 | self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 82 | self.act = act_layer() 83 | self.drop1 = nn.Dropout(drop_probs[0]) 84 | if gate_layer is not None: 85 | assert hidden_features % 2 == 0 86 | self.gate = gate_layer(hidden_features) 87 | hidden_features = hidden_features // 2 # FIXME base reduction on gate property? 
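# (editorial note) Gate layers in the gMLP sense typically chunk the hidden dimension in half
# internally, so when a gate_layer is supplied fc2 below only receives hidden_features // 2
# input features.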
88 | else: 89 | self.gate = nn.Identity() 90 | self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) 91 | self.drop2 = nn.Dropout(drop_probs[1]) 92 | 93 | def forward(self, x): 94 | x = self.fc1(x) 95 | x = self.act(x) 96 | x = self.drop1(x) 97 | x = self.gate(x) 98 | x = self.fc2(x) 99 | x = self.drop2(x) 100 | return x 101 | 102 | 103 | class ConvMlp(nn.Module): 104 | """ MLP using 1x1 convs that keeps spatial dims 105 | """ 106 | def __init__( 107 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, 108 | norm_layer=None, bias=True, drop=0.): 109 | super().__init__() 110 | out_features = out_features or in_features 111 | hidden_features = hidden_features or in_features 112 | bias = to_2tuple(bias) 113 | 114 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) 115 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() 116 | self.act = act_layer() 117 | self.drop = nn.Dropout(drop) 118 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) 119 | 120 | def forward(self, x): 121 | x = self.fc1(x) 122 | x = self.norm(x) 123 | x = self.act(x) 124 | x = self.drop(x) 125 | x = self.fc2(x) 126 | return x 127 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/norm.py: -------------------------------------------------------------------------------- 1 | """ Normalization layers and wrappers 2 | 3 | Norm layer definitions that support fast norm and consistent channel arg order (always first arg). 4 | 5 | Hacked together by / Copyright 2022 Ross Wightman 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm 13 | 14 | 15 | class GroupNorm(nn.GroupNorm): 16 | def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True): 17 | # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN 18 | super().__init__(num_groups, num_channels, eps=eps, affine=affine) 19 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 20 | 21 | def forward(self, x): 22 | if self.fast_norm: 23 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 24 | else: 25 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 26 | 27 | 28 | class GroupNorm1(nn.GroupNorm): 29 | """ Group Normalization with 1 group. 
30 | Input: tensor in shape [B, C, *] 31 | """ 32 | 33 | def __init__(self, num_channels, **kwargs): 34 | super().__init__(1, num_channels, **kwargs) 35 | self.fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | if self.fast_norm: 39 | return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 40 | else: 41 | return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) 42 | 43 | 44 | class LayerNorm(nn.LayerNorm): 45 | """ LayerNorm w/ fast norm option 46 | """ 47 | def __init__(self, num_channels, eps=1e-6, affine=True): 48 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 49 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | if self._fast_norm: 53 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 54 | else: 55 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 56 | return x 57 | 58 | 59 | class LayerNorm2d(nn.LayerNorm): 60 | """ LayerNorm for channels of '2D' spatial NCHW tensors """ 61 | def __init__(self, num_channels, eps=1e-6, affine=True): 62 | super().__init__(num_channels, eps=eps, elementwise_affine=affine) 63 | self._fast_norm = is_fast_norm() # can't script unless we have these flags here (no globals) 64 | 65 | def forward(self, x: torch.Tensor) -> torch.Tensor: 66 | x = x.permute(0, 2, 3, 1) 67 | if self._fast_norm: 68 | x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 69 | else: 70 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 71 | x = x.permute(0, 3, 1, 2) 72 | return x 73 | 74 | 75 | def _is_contiguous(tensor: torch.Tensor) -> bool: 76 | # jit is oh so lovely :/ 77 | if torch.jit.is_scripting(): 78 | return tensor.is_contiguous() 79 | else: 80 | return tensor.is_contiguous(memory_format=torch.contiguous_format) 81 | 82 | 83 | @torch.jit.script 84 | def _layer_norm_cf(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 85 | s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) 86 | x = (x - u) * torch.rsqrt(s + eps) 87 | x = x * weight[:, None, None] + bias[:, None, None] 88 | return x 89 | 90 | 91 | def _layer_norm_cf_sqm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float): 92 | u = x.mean(dim=1, keepdim=True) 93 | s = ((x * x).mean(dim=1, keepdim=True) - (u * u)).clamp(0) 94 | x = (x - u) * torch.rsqrt(s + eps) 95 | x = x * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1) 96 | return x 97 | 98 | 99 | class LayerNormExp2d(nn.LayerNorm): 100 | """ LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). 101 | 102 | Experimental implementation w/ manual norm for tensors non-contiguous tensors. 103 | 104 | This improves throughput in some scenarios (tested on Ampere GPU), esp w/ channels_last 105 | layout. However, benefits are not always clear and can perform worse on other GPUs. 
106 | """ 107 | 108 | def __init__(self, num_channels, eps=1e-6): 109 | super().__init__(num_channels, eps=eps) 110 | 111 | def forward(self, x) -> torch.Tensor: 112 | if _is_contiguous(x): 113 | x = F.layer_norm( 114 | x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) 115 | else: 116 | x = _layer_norm_cf(x, self.weight, self.bias, self.eps) 117 | return x 118 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/padding.py: -------------------------------------------------------------------------------- 1 | """ Padding Helpers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import math 6 | from typing import List, Tuple 7 | 8 | import torch.nn.functional as F 9 | 10 | 11 | # Calculate symmetric padding for a convolution 12 | def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 13 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 14 | return padding 15 | 16 | 17 | # Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution 18 | def get_same_padding(x: int, k: int, s: int, d: int): 19 | return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0) 20 | 21 | 22 | # Can SAME padding for given args be done statically? 23 | def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_): 24 | return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 25 | 26 | 27 | # Dynamically pad input x with 'SAME' padding for conv with specified args 28 | def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0): 29 | ih, iw = x.size()[-2:] 30 | pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1]) 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value) 33 | return x 34 | 35 | 36 | def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]: 37 | dynamic = False 38 | if isinstance(padding, str): 39 | # for any string padding, the padding will be calculated for you, one of three ways 40 | padding = padding.lower() 41 | if padding == 'same': 42 | # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact 43 | if is_static_pad(kernel_size, **kwargs): 44 | # static case, no extra overhead 45 | padding = get_padding(kernel_size, **kwargs) 46 | else: 47 | # dynamic 'SAME' padding, has runtime/GPU memory overhead 48 | padding = 0 49 | dynamic = True 50 | elif padding == 'valid': 51 | # 'VALID' padding, same as padding=0 52 | padding = 0 53 | else: 54 | # Default to PyTorch style 'same'-ish symmetric padding 55 | padding = get_padding(kernel_size, **kwargs) 56 | return padding, dynamic 57 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 
4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from .helpers import to_2tuple 12 | from .trace_utils import _assert 13 | 14 | 15 | class PatchEmbed(nn.Module): 16 | """ 2D Image to Patch Embedding 17 | """ 18 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): 19 | super().__init__() 20 | img_size = to_2tuple(img_size) 21 | patch_size = to_2tuple(patch_size) 22 | self.img_size = img_size 23 | self.patch_size = patch_size 24 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 25 | self.num_patches = self.grid_size[0] * self.grid_size[1] 26 | self.flatten = flatten 27 | 28 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 29 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 30 | 31 | def forward(self, x): 32 | B, C, H, W = x.shape 33 | _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") 34 | _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") 35 | x = self.proj(x) 36 | if self.flatten: 37 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 38 | x = self.norm(x) 39 | return x 40 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/pool2d_same.py: -------------------------------------------------------------------------------- 1 | """ AvgPool2d w/ Same Padding 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from typing import List, Tuple, Optional 9 | 10 | from .helpers import to_2tuple 11 | from .padding import pad_same, get_padding_value 12 | 13 | 14 | def avg_pool2d_same(x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 15 | ceil_mode: bool = False, count_include_pad: bool = True): 16 | # FIXME how to deal with count_include_pad vs not for external padding? 
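# (editorial note) pad_same applies asymmetric TF-style 'SAME' padding to x first, which is
# why the padding argument handed to F.avg_pool2d below is fixed at (0, 0).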
17 | x = pad_same(x, kernel_size, stride) 18 | return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 19 | 20 | 21 | class AvgPool2dSame(nn.AvgPool2d): 22 | """ Tensorflow like 'SAME' wrapper for 2D average pooling 23 | """ 24 | def __init__(self, kernel_size: int, stride=None, padding=0, ceil_mode=False, count_include_pad=True): 25 | kernel_size = to_2tuple(kernel_size) 26 | stride = to_2tuple(stride) 27 | super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) 28 | 29 | def forward(self, x): 30 | x = pad_same(x, self.kernel_size, self.stride) 31 | return F.avg_pool2d( 32 | x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) 33 | 34 | 35 | def max_pool2d_same( 36 | x, kernel_size: List[int], stride: List[int], padding: List[int] = (0, 0), 37 | dilation: List[int] = (1, 1), ceil_mode: bool = False): 38 | x = pad_same(x, kernel_size, stride, value=-float('inf')) 39 | return F.max_pool2d(x, kernel_size, stride, (0, 0), dilation, ceil_mode) 40 | 41 | 42 | class MaxPool2dSame(nn.MaxPool2d): 43 | """ Tensorflow like 'SAME' wrapper for 2D max pooling 44 | """ 45 | def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False): 46 | kernel_size = to_2tuple(kernel_size) 47 | stride = to_2tuple(stride) 48 | dilation = to_2tuple(dilation) 49 | super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode) 50 | 51 | def forward(self, x): 52 | x = pad_same(x, self.kernel_size, self.stride, value=-float('inf')) 53 | return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode) 54 | 55 | 56 | def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): 57 | stride = stride or kernel_size 58 | padding = kwargs.pop('padding', '') 59 | padding, is_dynamic = get_padding_value(padding, kernel_size, stride=stride, **kwargs) 60 | if is_dynamic: 61 | if pool_type == 'avg': 62 | return AvgPool2dSame(kernel_size, stride=stride, **kwargs) 63 | elif pool_type == 'max': 64 | return MaxPool2dSame(kernel_size, stride=stride, **kwargs) 65 | else: 66 | assert False, f'Unsupported pool type {pool_type}' 67 | else: 68 | if pool_type == 'avg': 69 | return nn.AvgPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 70 | elif pool_type == 'max': 71 | return nn.MaxPool2d(kernel_size, stride=stride, padding=padding, **kwargs) 72 | else: 73 | assert False, f'Unsupported pool type {pool_type}' 74 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/separable_conv.py: -------------------------------------------------------------------------------- 1 | """ Depthwise Separable Conv Modules 2 | 3 | Basic DWS convs. Other variations of DWS exist with batch norm or activations between the 4 | DW and PW convs such as the Depthwise modules in MobileNetV2 / EfficientNet and Xception. 
5 | 6 | Hacked together by / Copyright 2020 Ross Wightman 7 | """ 8 | from torch import nn as nn 9 | 10 | from .create_conv2d import create_conv2d 11 | from .create_norm_act import get_norm_act_layer 12 | 13 | 14 | class SeparableConvNormAct(nn.Module): 15 | """ Separable Conv w/ trailing Norm and Activation 16 | """ 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 18 | channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, 19 | apply_act=True, drop_layer=None): 20 | super(SeparableConvNormAct, self).__init__() 21 | 22 | self.conv_dw = create_conv2d( 23 | in_channels, int(in_channels * channel_multiplier), kernel_size, 24 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 25 | 26 | self.conv_pw = create_conv2d( 27 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 28 | 29 | norm_act_layer = get_norm_act_layer(norm_layer, act_layer) 30 | norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} 31 | self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) 32 | 33 | @property 34 | def in_channels(self): 35 | return self.conv_dw.in_channels 36 | 37 | @property 38 | def out_channels(self): 39 | return self.conv_pw.out_channels 40 | 41 | def forward(self, x): 42 | x = self.conv_dw(x) 43 | x = self.conv_pw(x) 44 | x = self.bn(x) 45 | return x 46 | 47 | 48 | SeparableConvBnAct = SeparableConvNormAct 49 | 50 | 51 | class SeparableConv2d(nn.Module): 52 | """ Separable Conv 53 | """ 54 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, 55 | channel_multiplier=1.0, pw_kernel_size=1): 56 | super(SeparableConv2d, self).__init__() 57 | 58 | self.conv_dw = create_conv2d( 59 | in_channels, int(in_channels * channel_multiplier), kernel_size, 60 | stride=stride, dilation=dilation, padding=padding, depthwise=True) 61 | 62 | self.conv_pw = create_conv2d( 63 | int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) 64 | 65 | @property 66 | def in_channels(self): 67 | return self.conv_dw.in_channels 68 | 69 | @property 70 | def out_channels(self): 71 | return self.conv_pw.out_channels 72 | 73 | def forward(self, x): 74 | x = self.conv_dw(x) 75 | x = self.conv_pw(x) 76 | return x 77 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/space_to_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SpaceToDepth(nn.Module): 6 | def __init__(self, block_size=4): 7 | super().__init__() 8 | assert block_size == 4 9 | self.bs = block_size 10 | 11 | def forward(self, x): 12 | N, C, H, W = x.size() 13 | x = x.view(N, C, H // self.bs, self.bs, W // self.bs, self.bs) # (N, C, H//bs, bs, W//bs, bs) 14 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 15 | x = x.view(N, C * (self.bs ** 2), H // self.bs, W // self.bs) # (N, C*bs^2, H//bs, W//bs) 16 | return x 17 | 18 | 19 | @torch.jit.script 20 | class SpaceToDepthJit(object): 21 | def __call__(self, x: torch.Tensor): 22 | # assuming hard-coded that block_size==4 for acceleration 23 | N, C, H, W = x.size() 24 | x = x.view(N, C, H // 4, 4, W // 4, 4) # (N, C, H//bs, bs, W//bs, bs) 25 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 26 | x = x.view(N, C * 16, H // 4, W // 4) # 
(N, C*bs^2, H//bs, W//bs) 27 | return x 28 | 29 | 30 | class SpaceToDepthModule(nn.Module): 31 | def __init__(self, no_jit=False): 32 | super().__init__() 33 | if not no_jit: 34 | self.op = SpaceToDepthJit() 35 | else: 36 | self.op = SpaceToDepth() 37 | 38 | def forward(self, x): 39 | return self.op(x) 40 | 41 | 42 | class DepthToSpace(nn.Module): 43 | 44 | def __init__(self, block_size): 45 | super().__init__() 46 | self.bs = block_size 47 | 48 | def forward(self, x): 49 | N, C, H, W = x.size() 50 | x = x.view(N, self.bs, self.bs, C // (self.bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 51 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 52 | x = x.view(N, C // (self.bs ** 2), H * self.bs, W * self.bs) # (N, C//bs^2, H * bs, W * bs) 53 | return x 54 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/split_attn.py: -------------------------------------------------------------------------------- 1 | """ Split Attention Conv2d (for ResNeSt Models) 2 | 3 | Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955 4 | 5 | Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt 6 | 7 | Modified for torchscript compat, performance, and consistency with timm by Ross Wightman 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .helpers import make_divisible 14 | 15 | 16 | class RadixSoftmax(nn.Module): 17 | def __init__(self, radix, cardinality): 18 | super(RadixSoftmax, self).__init__() 19 | self.radix = radix 20 | self.cardinality = cardinality 21 | 22 | def forward(self, x): 23 | batch = x.size(0) 24 | if self.radix > 1: 25 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 26 | x = F.softmax(x, dim=1) 27 | x = x.reshape(batch, -1) 28 | else: 29 | x = torch.sigmoid(x) 30 | return x 31 | 32 | 33 | class SplitAttn(nn.Module): 34 | """Split-Attention (aka Splat) 35 | """ 36 | def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, 37 | dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, 38 | act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): 39 | super(SplitAttn, self).__init__() 40 | out_channels = out_channels or in_channels 41 | self.radix = radix 42 | mid_chs = out_channels * radix 43 | if rd_channels is None: 44 | attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) 45 | else: 46 | attn_chs = rd_channels * radix 47 | 48 | padding = kernel_size // 2 if padding is None else padding 49 | self.conv = nn.Conv2d( 50 | in_channels, mid_chs, kernel_size, stride, padding, dilation, 51 | groups=groups * radix, bias=bias, **kwargs) 52 | self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() 53 | self.drop = drop_layer() if drop_layer is not None else nn.Identity() 54 | self.act0 = act_layer(inplace=True) 55 | self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) 56 | self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() 57 | self.act1 = act_layer(inplace=True) 58 | self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) 59 | self.rsoftmax = RadixSoftmax(radix, groups) 60 | 61 | def forward(self, x): 62 | x = self.conv(x) 63 | x = self.bn0(x) 64 | x = self.drop(x) 65 | x = self.act0(x) 66 | 67 | B, RC, H, W = x.shape 68 | if self.radix > 1: 69 | x = x.reshape((B, self.radix, RC // self.radix, H, W)) 70 | x_gap = x.sum(dim=1) 71 | else: 72 | x_gap = x 73 | 
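# Shape sketch (illustrative numbers only): with in_channels=out_channels=64 and radix=2,
# the conv above yields x of shape (B, 128, H, W); after the radix reshape and sum,
# x_gap is the per-split aggregate of shape (B, 64, H, W). The assert below merely
# restates this invariant and does not change behaviour.
assert x_gap.shape == (B, RC // self.radix, H, W)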
x_gap = x_gap.mean((2, 3), keepdim=True) 74 | x_gap = self.fc1(x_gap) 75 | x_gap = self.bn1(x_gap) 76 | x_gap = self.act1(x_gap) 77 | x_attn = self.fc2(x_gap) 78 | 79 | x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1) 80 | if self.radix > 1: 81 | out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1) 82 | else: 83 | out = x * x_attn 84 | return out.contiguous() 85 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/split_batchnorm.py: -------------------------------------------------------------------------------- 1 | """ Split BatchNorm 2 | 3 | A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through 4 | a separate BN layer. The first split is passed through the parent BN layers with weight/bias 5 | keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn' 6 | namespace. 7 | 8 | This allows easily removing the auxiliary BN layers after training to efficiently 9 | achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2, 10 | 'Disentangled Learning via An Auxiliary BN' 11 | 12 | Hacked together by / Copyright 2020 Ross Wightman 13 | """ 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class SplitBatchNorm2d(torch.nn.BatchNorm2d): 19 | 20 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 21 | track_running_stats=True, num_splits=2): 22 | super().__init__(num_features, eps, momentum, affine, track_running_stats) 23 | assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)' 24 | self.num_splits = num_splits 25 | self.aux_bn = nn.ModuleList([ 26 | nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)]) 27 | 28 | def forward(self, input: torch.Tensor): 29 | if self.training: # aux BN only relevant while training 30 | split_size = input.shape[0] // self.num_splits 31 | assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits" 32 | split_input = input.split(split_size) 33 | x = [super().forward(split_input[0])] 34 | for i, a in enumerate(self.aux_bn): 35 | x.append(a(split_input[i + 1])) 36 | return torch.cat(x, dim=0) 37 | else: 38 | return super().forward(input) 39 | 40 | 41 | def convert_splitbn_model(module, num_splits=2): 42 | """ 43 | Recursively traverse module and its children to replace all instances of 44 | ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchnorm2d`. 
45 | Args: 46 | module (torch.nn.Module): input module 47 | num_splits: number of separate batchnorm layers to split input across 48 | Example:: 49 | >>> # model is an instance of torch.nn.Module 50 | >>> model = timm.models.convert_splitbn_model(model, num_splits=2) 51 | """ 52 | mod = module 53 | if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm): 54 | return module 55 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 56 | mod = SplitBatchNorm2d( 57 | module.num_features, module.eps, module.momentum, module.affine, 58 | module.track_running_stats, num_splits=num_splits) 59 | mod.running_mean = module.running_mean 60 | mod.running_var = module.running_var 61 | mod.num_batches_tracked = module.num_batches_tracked 62 | if module.affine: 63 | mod.weight.data = module.weight.data.clone().detach() 64 | mod.bias.data = module.bias.data.clone().detach() 65 | for aux in mod.aux_bn: 66 | aux.running_mean = module.running_mean.clone() 67 | aux.running_var = module.running_var.clone() 68 | aux.num_batches_tracked = module.num_batches_tracked.clone() 69 | if module.affine: 70 | aux.weight.data = module.weight.data.clone().detach() 71 | aux.bias.data = module.bias.data.clone().detach() 72 | for name, child in module.named_children(): 73 | mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits)) 74 | del module 75 | return mod 76 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/squeeze_excite.py: -------------------------------------------------------------------------------- 1 | """ Squeeze-and-Excitation Channel Attention 2 | 3 | An SE implementation originally based on PyTorch SE-Net impl. 4 | Has since evolved with additional functionality / configuration. 5 | 6 | Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 7 | 8 | Also included is Effective Squeeze-Excitation (ESE). 9 | Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 10 | 11 | Hacked together by / Copyright 2021 Ross Wightman 12 | """ 13 | from torch import nn as nn 14 | 15 | from .create_act import create_act_layer 16 | from .helpers import make_divisible 17 | 18 | 19 | class SEModule(nn.Module): 20 | """ SE Module as defined in original SE-Nets with a few additions 21 | Additions include: 22 | * divisor can be specified to keep channels % div == 0 (default: 8) 23 | * reduction channels can be specified directly by arg (if rd_channels is set) 24 | * reduction channels can be specified by float rd_ratio (default: 1/16) 25 | * global max pooling can be added to the squeeze aggregation 26 | * customizable activation, normalization, and gate layer 27 | """ 28 | def __init__( 29 | self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, 30 | bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): 31 | super(SEModule, self).__init__() 32 | self.add_maxpool = add_maxpool 33 | if not rd_channels: 34 | rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) 
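# Worked example (illustrative numbers only): with channels=512 and the default rd_ratio=1/16,
# rd_channels = make_divisible(32, 8) = 32, so the block squeezes 512 -> 32 -> 512 channels
# through the two 1x1 convs defined below.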
35 | self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=bias) 36 | self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() 37 | self.act = create_act_layer(act_layer, inplace=True) 38 | self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=bias) 39 | self.gate = create_act_layer(gate_layer) 40 | 41 | def forward(self, x): 42 | x_se = x.mean((2, 3), keepdim=True) 43 | if self.add_maxpool: 44 | # experimental codepath, may remove or change 45 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 46 | x_se = self.fc1(x_se) 47 | x_se = self.act(self.bn(x_se)) 48 | x_se = self.fc2(x_se) 49 | return x * self.gate(x_se) 50 | 51 | 52 | SqueezeExcite = SEModule # alias 53 | 54 | 55 | class EffectiveSEModule(nn.Module): 56 | """ 'Effective Squeeze-Excitation 57 | From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 58 | """ 59 | def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): 60 | super(EffectiveSEModule, self).__init__() 61 | self.add_maxpool = add_maxpool 62 | self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 63 | self.gate = create_act_layer(gate_layer) 64 | 65 | def forward(self, x): 66 | x_se = x.mean((2, 3), keepdim=True) 67 | if self.add_maxpool: 68 | # experimental codepath, may remove or change 69 | x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) 70 | x_se = self.fc(x_se) 71 | return x * self.gate(x_se) 72 | 73 | 74 | EffectiveSqueezeExcite = EffectiveSEModule # alias 75 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/test_time_pool.py: -------------------------------------------------------------------------------- 1 | """ Test Time Pooling (Average-Max Pool) 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman 4 | """ 5 | 6 | import logging 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from .adaptive_avgmax_pool import adaptive_avgmax_pool2d 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | class TestTimePoolHead(nn.Module): 17 | def __init__(self, base, original_pool=7): 18 | super(TestTimePoolHead, self).__init__() 19 | self.base = base 20 | self.original_pool = original_pool 21 | base_fc = self.base.get_classifier() 22 | if isinstance(base_fc, nn.Conv2d): 23 | self.fc = base_fc 24 | else: 25 | self.fc = nn.Conv2d( 26 | self.base.num_features, self.base.num_classes, kernel_size=1, bias=True) 27 | self.fc.weight.data.copy_(base_fc.weight.data.view(self.fc.weight.size())) 28 | self.fc.bias.data.copy_(base_fc.bias.data.view(self.fc.bias.size())) 29 | self.base.reset_classifier(0) # delete original fc layer 30 | 31 | def forward(self, x): 32 | x = self.base.forward_features(x) 33 | x = F.avg_pool2d(x, kernel_size=self.original_pool, stride=1) 34 | x = self.fc(x) 35 | x = adaptive_avgmax_pool2d(x, 1) 36 | return x.view(x.size(0), -1) 37 | 38 | 39 | def apply_test_time_pool(model, config, use_test_size=False): 40 | test_time_pool = False 41 | if not hasattr(model, 'default_cfg') or not model.default_cfg: 42 | return model, False 43 | if use_test_size and 'test_input_size' in model.default_cfg: 44 | df_input_size = model.default_cfg['test_input_size'] 45 | else: 46 | df_input_size = model.default_cfg['input_size'] 47 | if config['input_size'][-1] > df_input_size[-1] and config['input_size'][-2] > df_input_size[-2]: 48 | _logger.info('Target input size %s > pretrained default %s, using test time pooling' % 49 | (str(config['input_size'][-2:]), 
str(df_input_size[-2:]))) 50 | model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) 51 | test_time_pool = True 52 | return model, test_time_pool 53 | -------------------------------------------------------------------------------- /models/layers/maxvit/layers/trace_utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch import _assert 3 | except ImportError: 4 | def _assert(condition: bool, message: str): 5 | assert condition, message 6 | 7 | 8 | def _float_to_int(x: float) -> int: 9 | """ 10 | Symbolic tracing helper to substitute for inbuilt `int`. 11 | Hint: Inbuilt `int` can't accept an argument of type `Proxy` 12 | """ 13 | return int(x) 14 | -------------------------------------------------------------------------------- /models/layers/rnn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | 6 | 7 | class DWSConvLSTM2d(nn.Module): 8 | """LSTM with (depthwise-separable) Conv option in NCHW [channel-first] format. 9 | """ 10 | 11 | def __init__(self, 12 | dim: int, 13 | dws_conv: bool = True, # RVT uses False 14 | dws_conv_only_hidden: bool = True, 15 | dws_conv_kernel_size: int = 3, 16 | cell_update_dropout: float = 0.): 17 | super().__init__() 18 | assert isinstance(dws_conv, bool) 19 | assert isinstance(dws_conv_only_hidden, bool) 20 | self.dim = dim 21 | 22 | xh_dim = dim * 2 23 | gates_dim = dim * 4 24 | conv3x3_dws_dim = dim if dws_conv_only_hidden else xh_dim 25 | # `self.conv3x3_dws` is actually just Identity mapping in RVT 26 | self.conv3x3_dws = nn.Conv2d(in_channels=conv3x3_dws_dim, 27 | out_channels=conv3x3_dws_dim, 28 | kernel_size=dws_conv_kernel_size, 29 | padding=dws_conv_kernel_size // 2, 30 | groups=conv3x3_dws_dim) if dws_conv else nn.Identity() 31 | self.conv1x1 = nn.Conv2d(in_channels=xh_dim, 32 | out_channels=gates_dim, 33 | kernel_size=1) 34 | self.conv_only_hidden = dws_conv_only_hidden 35 | self.cell_update_dropout = nn.Dropout(p=cell_update_dropout) 36 | 37 | def forward(self, x: th.Tensor, h_and_c_previous: Optional[Tuple[th.Tensor, th.Tensor]] = None) \ 38 | -> Tuple[th.Tensor, th.Tensor]: 39 | """ 40 | :param x: (N C H W), new features extracted at current time step 41 | :param h_and_c_previous: ((N C H W), (N C H W)), prev RNN states 42 | :return: ((N C H W), (N C H W)) 43 | """ 44 | if h_and_c_previous is None: 45 | # generate zero states 46 | hidden = th.zeros_like(x) 47 | cell = th.zeros_like(x) 48 | h_and_c_previous = (hidden, cell) 49 | h_tm1, c_tm1 = h_and_c_previous 50 | 51 | if self.conv_only_hidden: 52 | h_tm1 = self.conv3x3_dws(h_tm1) 53 | xh = th.cat((x, h_tm1), dim=1) # (N 2C H W) 54 | if not self.conv_only_hidden: 55 | xh = self.conv3x3_dws(xh) 56 | mix = self.conv1x1(xh) 57 | 58 | gates, cell_input = th.tensor_split(mix, [self.dim * 3], dim=1) 59 | assert gates.shape[1] == cell_input.shape[1] * 3 60 | 61 | gates = th.sigmoid(gates) 62 | forget_gate, input_gate, output_gate = th.tensor_split(gates, 3, dim=1) 63 | assert forget_gate.shape == input_gate.shape == output_gate.shape 64 | 65 | cell_input = self.cell_update_dropout(th.tanh(cell_input)) 66 | 67 | c_t = forget_gate * c_tm1 + input_gate * cell_input 68 | h_t = output_gate * th.tanh(c_t) 69 | 70 | return h_t, c_t 71 | -------------------------------------------------------------------------------- /modules/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/modules/__init__.py -------------------------------------------------------------------------------- /modules/tracking/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import LinearTracker 2 | -------------------------------------------------------------------------------- /modules/tracking/tracker.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import numpy as np 4 | 5 | 6 | class Tracker(object): 7 | """Base class for multi-object video tracker.""" 8 | 9 | def __init__(self, img_hw: Tuple[int, int], iou_threshold: float = 0.45): 10 | """ 11 | Sets key parameters for the tracker. 12 | """ 13 | self.img_hw = img_hw 14 | self.iou_threshold = iou_threshold 15 | 16 | self.trackers, self.prev_trackers, self.bbox_idx2tracker = [], [], {} 17 | self.track_count, self.bbox_count = 0, 0 18 | self.done = False 19 | 20 | def update(self, 21 | dets: np.ndarray = np.empty((0, 4)), 22 | is_gt: np.ndarray = np.empty((0, ))): 23 | raise NotImplementedError 24 | 25 | def _del_tracker(self, idx: int, done: bool = True): 26 | """Delete self.trackers[idx], move it to self.del_trackers.""" 27 | tracker = self.trackers.pop(idx) 28 | tracker.finish(done=done) 29 | self.prev_trackers.append(tracker) 30 | for idx in tracker.bbox_idx: 31 | self.bbox_idx2tracker[idx] = tracker 32 | 33 | def finish(self): 34 | """Delete all remaining trackers.""" 35 | for idx in reversed(range(len(self.trackers))): 36 | # don't filter out unfinished tracklets! 37 | self._del_tracker(idx, done=False) 38 | self.done = True 39 | 40 | def new(self): 41 | """Create a new tracker.""" 42 | raise NotImplementedError 43 | 44 | def get_bbox_tracker(self, bbox_idx: int): 45 | """Get the bbox_tracker for the given bbox_idx.""" 46 | assert self.done, 'Please call Tracker.finish() first.' 47 | return self.bbox_idx2tracker[bbox_idx] 48 | -------------------------------------------------------------------------------- /modules/tracking/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | 6 | 7 | def greedy_matching(cost_matrix: np.ndarray, idx_lst: np.ndarray, 8 | thresh: float = 0.0) -> np.ndarray: 9 | cost_matrix = copy.deepcopy(cost_matrix) 10 | matched_indices = [] 11 | assert len(idx_lst) == cost_matrix.shape[0] 12 | for i in idx_lst: 13 | if cost_matrix[i].max() < thresh: 14 | continue 15 | j = np.argmax(cost_matrix[i]) 16 | cost_matrix[:, j] = -np.inf 17 | matched_indices.append([i, j]) 18 | return np.array(matched_indices) # (N, 2) 19 | 20 | 21 | def iou_batch_xywh(bb_test: np.ndarray, bb_gt: np.ndarray) -> np.ndarray: 22 | """ 23 | Computes IOU between two bboxes in the form [x,y,w,h,(cls_id)] 24 | both bbox are in shape (N, 4/5) where N is the number of bboxes 25 | If class_id is provided, take it into account by ignoring the IOU between 26 | bboxes of different classes. 27 | """ 28 | bb_gt = np.expand_dims(bb_gt, 0) 29 | bb_test = np.expand_dims(bb_test, 1) 30 | 31 | xx1 = np.maximum(bb_test[..., 0] - bb_test[..., 2] / 2., 32 | bb_gt[..., 0] - bb_gt[..., 2] / 2.) 33 | yy1 = np.maximum(bb_test[..., 1] - bb_test[..., 3] / 2., 34 | bb_gt[..., 1] - bb_gt[..., 3] / 2.) 
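# Broadcasting sketch: after the expand_dims above, bb_test is (N, 1, 4 or 5) and bb_gt is
# (1, M, 4 or 5), so each pairwise corner array is (N, M) and the returned IoU matrix `o`
# has one row per test box and one column per GT box. The assert below is illustrative only.
assert xx1.shape == (bb_test.shape[0], bb_gt.shape[1])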
35 | xx2 = np.minimum(bb_test[..., 0] + bb_test[..., 2] / 2., 36 | bb_gt[..., 0] + bb_gt[..., 2] / 2.) 37 | yy2 = np.minimum(bb_test[..., 1] + bb_test[..., 3] / 2., 38 | bb_gt[..., 1] + bb_gt[..., 3] / 2.) 39 | w = np.maximum(0., xx2 - xx1) 40 | h = np.maximum(0., yy2 - yy1) 41 | wh = w * h 42 | o = wh / ( 43 | bb_test[..., 2] * bb_test[..., 3] + bb_gt[..., 2] * bb_gt[..., 3] - wh) 44 | 45 | # set IoU of between different class objects to 0 46 | if bb_test.shape[-1] == 5 and bb_gt.shape[-1] == 5: 47 | o[bb_gt[..., 4] != bb_test[..., 4]] = 0. 48 | 49 | return o 50 | 51 | 52 | def xyxy2xywh(bbox: np.ndarray) -> np.ndarray: 53 | """ 54 | Takes a bbox in the form [x1,y1,x2,y2] and returns a new bbox in the form 55 | [x,y,w,h] where x,y is the center and w,h are the width and height 56 | """ 57 | x1, y1, x2, y2 = bbox 58 | bbox = [(x1 + x2) / 2., (y1 + y2) / 2., x2 - x1, y2 - y1] 59 | return np.array(bbox) 60 | 61 | 62 | def xywh2xyxy(bbox: np.ndarray) -> np.ndarray: 63 | """ 64 | Takes a bounding box in the form [x,y,w,h] and returns it in the form 65 | [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right 66 | """ 67 | x, y, w, h = bbox 68 | bbox = [x - w / 2., y - h / 2., x + w / 2., y + h / 2.] 69 | return np.array(bbox) 70 | 71 | 72 | def clamp_bbox(bbox: np.ndarray, 73 | img_hw: Tuple[int, int], 74 | format_: str = 'xyxy') -> np.ndarray: 75 | """ 76 | Clamp bbox to image boundaries. 77 | """ 78 | # bbox: (4,) or (1, 4) or (4, 1) 79 | bbox_shape = bbox.shape 80 | bbox = bbox.squeeze() # to (4,) 81 | H, W = img_hw 82 | assert format_ in ['xyxy', 'xywh'] 83 | if format_ == 'xywh': 84 | bbox = xywh2xyxy(bbox) 85 | x1_, y1_, x2_, y2_ = bbox 86 | x1 = np.clip(x1_, 0., W - 1.) 87 | x2 = np.clip(x2_, 0., W - 1.) 88 | y1 = np.clip(y1_, 0., H - 1.) 89 | y2 = np.clip(y2_, 0., H - 1.) 
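# Numeric sketch (assuming Gen1-sized frames, i.e. img_hw=(240, 304)): a raw x2_ of 310.0
# is clipped to W - 1 = 303.0 here, and the clamp_right flag computed below records that
# the box touched the right image border.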
90 | bbox = np.array([x1, y1, x2, y2]) 91 | clamp_top, clamp_down = (y1 != y1_), (y2 != y2_) 92 | clamp_left, clamp_right = (x1 != x1_), (x2 != x2_) 93 | if format_ == 'xywh': 94 | bbox = xyxy2xywh(bbox) 95 | bbox = bbox.reshape(bbox_shape) 96 | return bbox, clamp_top, clamp_down, clamp_left, clamp_right 97 | -------------------------------------------------------------------------------- /modules/utils/fetch.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from omegaconf import DictConfig 3 | 4 | from modules.data.genx import DataModule as genx_data_module 5 | from modules.detection import Module as rnn_det_module 6 | from modules.utils.tta import TTAModule 7 | from modules.pseudo_labeler import PseudoLabeler 8 | 9 | 10 | def fetch_model_module(config: DictConfig) -> pl.LightningModule: 11 | """Build model.""" 12 | model_str = config.model.name 13 | if model_str == 'rnndet': 14 | if config.get('tta', {}).get('enable', False): 15 | return TTAModule(config) 16 | return rnn_det_module(config) 17 | elif model_str == 'pseudo_labeler': 18 | return PseudoLabeler(config) 19 | raise NotImplementedError 20 | 21 | 22 | def fetch_data_module(config: DictConfig) -> pl.LightningDataModule: 23 | """Build dataloaders.""" 24 | batch_size_train = config.batch_size.train 25 | batch_size_eval = config.batch_size.eval 26 | num_workers_generic = config.hardware.get('num_workers', None) 27 | num_workers_train = config.hardware.num_workers.get('train', num_workers_generic) 28 | num_workers_eval = config.hardware.num_workers.get('eval', num_workers_generic) 29 | dataset_str = config.dataset.name 30 | if dataset_str in {'gen1', 'gen4'}: 31 | return genx_data_module(config.dataset, 32 | num_workers_train=num_workers_train, 33 | num_workers_eval=num_workers_eval, 34 | batch_size_train=batch_size_train, 35 | batch_size_eval=batch_size_eval) 36 | raise NotImplementedError 37 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/pretrained/.gitkeep -------------------------------------------------------------------------------- /utils/bbox.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Tuple 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch as th 7 | 8 | from data.genx_utils.labels import ObjectLabels 9 | 10 | 11 | def np_th_stack(values: List[Union[np.ndarray, th.Tensor]], axis: int = 0): 12 | """Stack a list of numpy arrays or tensors.""" 13 | if isinstance(values[0], np.ndarray): 14 | return np.stack(values, axis=axis) 15 | elif isinstance(values[0], th.Tensor): 16 | return torch.stack(values, dim=axis) 17 | else: 18 | raise ValueError(f'Unknown type {type(values[0])}') 19 | 20 | 21 | def np_th_concat(values: List[Union[np.ndarray, th.Tensor]], axis: int = 0): 22 | """Concat a list of numpy arrays or tensors.""" 23 | if isinstance(values[0], np.ndarray): 24 | return np.concatenate(values, axis=axis) 25 | elif isinstance(values[0], th.Tensor): 26 | return torch.cat(values, dim=axis) 27 | else: 28 | raise ValueError(f'Unknown type {type(values[0])}') 29 | 30 | 31 | def get_bbox_coords(bbox: Union[np.ndarray, th.Tensor], last4: bool = None): 32 | """Get the 4 coords (xyxy/xywh) from a bbox array or tensor.""" 33 | if isinstance(bbox, list): 
34 | bbox = np_th_stack(bbox, axis=0) 35 | if last4 is None: # infer from shape, buggy when bbox.shape == (4, 4) 36 | if bbox.shape[0] == 4: 37 | last4 = False 38 | elif bbox.shape[-1] == 4: 39 | last4 = True 40 | else: 41 | raise ValueError(f'Unknown shape {bbox.shape}') 42 | if last4: 43 | a, b, c, d = bbox[..., 0], bbox[..., 1], bbox[..., 2], bbox[..., 3] 44 | else: 45 | a, b, c, d = bbox 46 | return (a, b, c, d), last4 47 | 48 | 49 | def construct_bbox(abcd: Tuple[Union[np.ndarray, th.Tensor]], last4: bool): 50 | """Construct a bbox from 4 coords (xyxy/xywh).""" 51 | if last4: 52 | return np_th_stack(abcd, axis=-1) 53 | return np_th_stack(abcd, axis=0) 54 | 55 | 56 | def xywh2xyxy(xywh: Union[np.ndarray, th.Tensor], format_: str = 'center', last4: bool = None): 57 | """Convert bounding box from xywh to xyxy format.""" 58 | if isinstance(xywh, ObjectLabels): 59 | return xywh.get_xyxy() 60 | 61 | (x, y, w, h), last4 = get_bbox_coords(xywh, last4=last4) 62 | 63 | if format_ == 'center': 64 | x1, x2 = x - w / 2., x + w / 2. 65 | y1, y2 = y - h / 2., y + h / 2. 66 | elif format_ == 'corner': 67 | x1, x2 = x, x + w 68 | y1, y2 = y, y + h 69 | else: 70 | raise NotImplementedError(f'Unknown format {format_}') 71 | 72 | return construct_bbox((x1, y1, x2, y2), last4=last4) 73 | 74 | 75 | def xyxy2xywh(xyxy: Union[np.ndarray, th.Tensor], format_: str = 'center', last4: bool = None): 76 | """Convert bounding box from xyxy to xywh format.""" 77 | if isinstance(xyxy, ObjectLabels): 78 | return xyxy.get_xywh(format_=format_) 79 | 80 | (x1, y1, x2, y2), last4 = get_bbox_coords(xyxy, last4=last4) 81 | 82 | w, h = x2 - x1, y2 - y1 83 | if format_ == 'center': 84 | x, y = (x1 + x2) / 2., (y1 + y2) / 2. 85 | elif format_ == 'corner': 86 | x, y = x1, y1 87 | else: 88 | raise NotImplementedError(f'Unknown format {format_}') 89 | 90 | return construct_bbox((x, y, w, h), last4=last4) 91 | -------------------------------------------------------------------------------- /utils/evaluation/prophesee/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/__init__.py -------------------------------------------------------------------------------- /utils/evaluation/prophesee/evaluation.py: -------------------------------------------------------------------------------- 1 | from .io.box_filtering import filter_boxes 2 | from .metrics.coco_eval import evaluate_detection 3 | 4 | 5 | def evaluate_list(result_boxes_list, 6 | gt_boxes_list, 7 | height: int, 8 | width: int, 9 | camera: str = 'gen1', 10 | apply_bbox_filters: bool = True, 11 | downsampled_by_2: bool = False, 12 | return_aps: bool = True): 13 | assert camera in {'gen1', 'gen4'} 14 | 15 | if camera == 'gen1': 16 | classes = ("car", "pedestrian") 17 | elif camera == 'gen4': 18 | classes = ("pedestrian", "two-wheeler", "car") 19 | else: 20 | raise NotImplementedError 21 | 22 | if apply_bbox_filters: 23 | # Default values taken from: https://github.com/prophesee-ai/prophesee-automotive-dataset-toolbox/blob/0393adea2bf22d833893c8cb1d986fcbe4e6f82d/src/psee_evaluator.py#L23-L24 24 | min_box_diag = 60 if camera == 'gen4' else 30 25 | # In the supplementary mat, they say that min_box_side is 20 for gen4. 
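# For reference, the thresholds used in this block are: gen1 -> min_box_diag=30,
# min_box_side=10; gen4 -> min_box_diag=60, min_box_side=20; both are halved when
# downsampled_by_2=True, and boxes within the first 0.5 s of a sequence are skipped.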
26 | min_box_side = 20 if camera == 'gen4' else 10 27 | if downsampled_by_2: 28 | assert min_box_diag % 2 == 0 29 | min_box_diag //= 2 30 | assert min_box_side % 2 == 0 31 | min_box_side //= 2 32 | 33 | half_sec_us = int(5e5) 34 | filter_boxes_fn = lambda x: filter_boxes(x, half_sec_us, min_box_diag, min_box_side) 35 | 36 | gt_boxes_list = map(filter_boxes_fn, gt_boxes_list) 37 | # NOTE: We also filter the prediction to follow the prophesee protocol of evaluation. 38 | result_boxes_list = map(filter_boxes_fn, result_boxes_list) 39 | 40 | return evaluate_detection(gt_boxes_list, result_boxes_list, 41 | height=height, width=width, 42 | classes=classes, return_aps=return_aps) 43 | -------------------------------------------------------------------------------- /utils/evaluation/prophesee/evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, List, Optional, Dict 2 | from warnings import warn 3 | 4 | import numpy as np 5 | 6 | from utils.evaluation.prophesee.evaluation import evaluate_list 7 | 8 | LABELMAP = { 9 | 'gen1': ('car', 'ped'), 10 | 'gen4': ('ped', 'cyc', 'car'), 11 | } 12 | 13 | 14 | def get_labelmap(dst_name: str = None, num_cls: int = None) -> Tuple[str]: 15 | assert dst_name is None or num_cls is None 16 | if dst_name is not None: 17 | return LABELMAP[dst_name.lower()] 18 | elif num_cls is not None: 19 | assert num_cls in (2, 3), f'Invalid number of classes: {num_cls}' 20 | return LABELMAP['gen1'] if num_cls == 2 else LABELMAP['gen4'] 21 | else: 22 | raise NotImplementedError('Either dst_name or num_cls must be input') 23 | 24 | 25 | class PropheseeEvaluator: 26 | LABELS = 'lables' 27 | PREDICTIONS = 'predictions' 28 | 29 | def __init__(self, dataset: str, downsample_by_2: bool): 30 | super().__init__() 31 | assert dataset in {'gen1', 'gen4'} 32 | self.dataset = dataset 33 | self.label_map = get_labelmap(dataset) 34 | self.downsample_by_2 = downsample_by_2 35 | 36 | self._buffer = None 37 | self._buffer_empty = True 38 | self._reset_buffer() 39 | 40 | def _reset_buffer(self): 41 | self._buffer_empty = True 42 | self._buffer = { 43 | self.LABELS: list(), 44 | self.PREDICTIONS: list(), 45 | } 46 | 47 | def _add_to_buffer(self, key: str, value: List[np.ndarray]): 48 | assert isinstance(value, list) 49 | for entry in value: 50 | assert isinstance(entry, np.ndarray) 51 | self._buffer_empty = False 52 | assert self._buffer is not None 53 | self._buffer[key].extend(value) 54 | 55 | def _get_from_buffer(self, key: str) -> List[np.ndarray]: 56 | assert not self._buffer_empty 57 | assert self._buffer is not None 58 | return self._buffer[key] 59 | 60 | def add_predictions(self, predictions: List[np.ndarray]): 61 | self._add_to_buffer(self.PREDICTIONS, predictions) 62 | 63 | def add_labels(self, labels: List[np.ndarray]): 64 | self._add_to_buffer(self.LABELS, labels) 65 | 66 | def reset_buffer(self) -> None: 67 | # E.g. 
call in on_validation_epoch_start 68 | self._reset_buffer() 69 | 70 | def has_data(self): 71 | return not self._buffer_empty 72 | 73 | def evaluate_buffer(self, img_height: int, img_width: int, ret_pr_curve: bool = False) -> Optional[Dict[str, Any]]: 74 | # e.g call in on_validation_epoch_end 75 | if self._buffer_empty: 76 | warn("Attempt to use prophesee evaluation buffer, but it is empty", UserWarning, stacklevel=2) 77 | return 78 | 79 | labels = self._get_from_buffer(self.LABELS) 80 | predictions = self._get_from_buffer(self.PREDICTIONS) 81 | assert len(labels) == len(predictions) 82 | 83 | # we perform both per-category and overall evaluation 84 | # overall 85 | metrics = evaluate_list(result_boxes_list=predictions, 86 | gt_boxes_list=labels, 87 | height=img_height, 88 | width=img_width, 89 | apply_bbox_filters=True, 90 | downsampled_by_2=self.downsample_by_2, 91 | camera=self.dataset) 92 | # per-category 93 | for cls_id, cls_name in enumerate(self.label_map): 94 | lbls = [lbl[lbl['class_id'] == cls_id] for lbl in labels] 95 | preds = [pred[pred['class_id'] == cls_id] for pred in predictions] 96 | cls_metric = evaluate_list(result_boxes_list=preds, 97 | gt_boxes_list=lbls, 98 | height=img_height, 99 | width=img_width, 100 | apply_bbox_filters=True, 101 | downsampled_by_2=self.downsample_by_2, 102 | camera=self.dataset) 103 | cls_metric = {f'{k}_{cls_name}': v for k, v in cls_metric.items()} 104 | metrics.update(cls_metric) 105 | 106 | if not ret_pr_curve: 107 | del_keys = [k for k in metrics.keys() if 'PR' in k] 108 | for k in del_keys: 109 | del metrics[k] 110 | return metrics 111 | -------------------------------------------------------------------------------- /utils/evaluation/prophesee/io/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/io/__init__.py -------------------------------------------------------------------------------- /utils/evaluation/prophesee/io/box_filtering.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define same filtering that we apply in: 3 | "Learning to detect objects on a 1 Megapixel Event Camera" by Etienne Perot et al. 4 | 5 | Namely we apply 2 different filters: 6 | 1. skip all boxes before 0.5s (before we assume it is unlikely you have sufficient historic) 7 | 2. filter all boxes whose diagonal <= min_box_diag**2 and whose side <= min_box_side 8 | 9 | 10 | 11 | Copyright: (c) 2019-2020 Prophesee 12 | """ 13 | from __future__ import print_function 14 | 15 | import numpy as np 16 | 17 | 18 | def filter_boxes(boxes, skip_ts=int(5e5), min_box_diag=60, min_box_side=20): 19 | """Filters boxes according to the paper rule. 
20 | 21 | To note: the default represents our threshold when evaluating GEN4 resolution (1280x720) 22 | To note: we assume the initial time of the video is always 0 23 | 24 | Args: 25 | boxes (np.ndarray): structured box array with fields ['t','x','y','w','h','class_id','track_id','class_confidence'] 26 | (example BBOX_DTYPE is provided in src/box_loading.py) 27 | 28 | Returns: 29 | boxes: filtered boxes 30 | """ 31 | ts = boxes['t'] 32 | width = boxes['w'] 33 | height = boxes['h'] 34 | diag_square = width ** 2 + height ** 2 35 | mask = (ts > skip_ts) * (diag_square >= min_box_diag ** 2) * (width >= min_box_side) * (height >= min_box_side) 36 | return boxes[mask] 37 | -------------------------------------------------------------------------------- /utils/evaluation/prophesee/io/npy_events_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Defines some tools to handle events, mimicking dat_events_tools.py. 5 | In particular : 6 | -> defines functions to read events from binary .npy files using numpy 7 | -> defines functions to write events to binary .dat files using numpy (TODO later) 8 | 9 | Copyright: (c) 2015-2019 Prophesee 10 | """ 11 | from __future__ import print_function 12 | 13 | import numpy as np 14 | 15 | 16 | def stream_td_data(file_handle, buffer, dtype, ev_count=-1): 17 | """ 18 | Streams data from opened file_handle 19 | args : 20 | - file_handle: file object 21 | - buffer: pre-allocated buffer to fill with events 22 | - dtype: expected fields 23 | - ev_count: number of events 24 | """ 25 | dat = np.fromfile(file_handle, dtype=dtype, count=ev_count) 26 | count = len(dat['t']) 27 | for name, _ in dtype: 28 | buffer[name][:count] = dat[name] 29 | 30 | 31 | def parse_header(fhandle): 32 | """ 33 | Parses the header of a .npy file 34 | Args: 35 | - f file handle to a .npy file 36 | return : 37 | - int position of the file cursor after the header 38 | - int type of event 39 | - int size of event in bytes 40 | - size (height, width) tuple of int or (None, None) 41 | """ 42 | version = np.lib.format.read_magic(fhandle) 43 | shape, fortran, dtype = np.lib.format._read_array_header(fhandle, version) 44 | assert not fortran, "Fortran order arrays not supported" 45 | # Get the number of elements in one 'row' by taking 46 | # a product over all other dimensions. 
47 | if len(shape) == 0: 48 | count = 1 49 | else: 50 | count = np.multiply.reduce(shape, dtype=np.int64) 51 | ev_size = dtype.itemsize 52 | assert ev_size != 0 53 | start = fhandle.tell() 54 | # turn numpy.dtype into an iterable list 55 | ev_type = [(x, str(dtype.fields[x][0])) for x in dtype.names] 56 | # filter name to have only t and not ts 57 | ev_type = [(name if name != "ts" else "t", desc) for name, desc in ev_type] 58 | ev_type = [(name if name != "confidence" else "class_confidence", desc) for name, desc in ev_type] 59 | size = (None, None) 60 | size = (None, None) 61 | 62 | return start, ev_type, ev_size, size 63 | -------------------------------------------------------------------------------- /utils/evaluation/prophesee/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/metrics/__init__.py -------------------------------------------------------------------------------- /utils/evaluation/prophesee/visualize/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wuziyi616/LEOD/d783b254bcfe2a5bd12b621f34f014491a82bb4b/utils/evaluation/prophesee/visualize/__init__.py -------------------------------------------------------------------------------- /utils/padding.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Tuple 2 | 3 | import torch as th 4 | import torch.nn.functional as F 5 | 6 | 7 | class InputPadderFromShape: 8 | """Pad input to desired height and width.""" 9 | 10 | def __init__(self, desired_hw: Tuple[int, int], mode: str = 'constant', 11 | value: int = 0, type: str = 'corner'): 12 | """ 13 | :param desired_hw: Desired height and width 14 | :param mode: See torch.nn.functional.pad 15 | :param value: See torch.nn.functional.pad 16 | :param type: "corner": add zero to bottom and right 17 | """ 18 | assert isinstance(desired_hw, tuple) 19 | assert len(desired_hw) == 2 20 | assert desired_hw[0] % 4 == 0, 'Required for token mask padding' 21 | assert desired_hw[1] % 4 == 0, 'Required for token mask padding' 22 | assert type in {'corner'} 23 | 24 | self.desired_hw = desired_hw 25 | self.mode = mode 26 | self.value = value 27 | self.type = type 28 | self._pad_ev_repr = None 29 | self._pad_token_mask = None 30 | 31 | @staticmethod 32 | def _pad_tensor_impl(input_tensor: th.Tensor, desired_hw: Tuple[int, int], 33 | mode: str, value: Any) -> Tuple[th.Tensor, List[int]]: 34 | assert isinstance(input_tensor, th.Tensor) 35 | 36 | ht, wd = input_tensor.shape[-2:] 37 | ht_des, wd_des = desired_hw 38 | assert ht <= ht_des 39 | assert wd <= wd_des 40 | 41 | pad_left = 0 42 | pad_right = wd_des - wd 43 | pad_top = 0 44 | pad_bottom = ht_des - ht 45 | 46 | pad = [pad_left, pad_right, pad_top, pad_bottom] 47 | return F.pad(input_tensor, pad=pad, mode=mode, value=value if 48 | mode == 'constant' else None), pad 49 | 50 | def pad_tensor_ev_repr(self, ev_repr: th.Tensor) -> th.Tensor: 51 | padded_ev_repr, pad = self._pad_tensor_impl( 52 | input_tensor=ev_repr, desired_hw=self.desired_hw, 53 | mode=self.mode, value=self.value) 54 | if self._pad_ev_repr is None: 55 | self._pad_ev_repr = pad 56 | else: 57 | assert self._pad_ev_repr == pad 58 | return padded_ev_repr 59 | 60 | def pad_token_mask(self, token_mask: th.Tensor): 61 | assert isinstance(token_mask, th.Tensor) 62 | 63 | desired_hw = 
tuple(x // 4 for x in self.desired_hw) 64 | padded_token_mask, pad = self._pad_tensor_impl( 65 | input_tensor=token_mask, desired_hw=desired_hw, 66 | mode='constant', value=0) 67 | if self._pad_token_mask is None: 68 | self._pad_token_mask = pad 69 | else: 70 | assert self._pad_token_mask == pad 71 | return padded_token_mask 72 | -------------------------------------------------------------------------------- /utils/preprocessing.py: -------------------------------------------------------------------------------- 1 | from .helpers import subsample_list 2 | 3 | 4 | def _blosc_opts(complevel=1, complib='blosc:zstd', shuffle='byte'): 5 | shuffle = 2 if shuffle == 'bit' else 1 if shuffle == 'byte' else 0 6 | compressors = ['blosclz', 'lz4', 'lz4hc', 'snappy', 'zlib', 'zstd'] 7 | complib = ['blosc:' + c for c in compressors].index(complib) 8 | args = { 9 | 'compression': 32001, 10 | 'compression_opts': (0, 0, 0, 0, complevel, shuffle, complib), 11 | } 12 | if shuffle > 0: 13 | # Do not use h5py shuffle if blosc shuffle is enabled. 14 | args['shuffle'] = False 15 | return args 16 | 17 | 18 | def subsample_sequence(split_path, ratio): 19 | """Subsample the sequence under a folder by a given ratio.""" 20 | seq_dirs = sorted([p for p in split_path.iterdir()]) 21 | print(f'Found {len(seq_dirs)} sequences in {str(split_path)}') 22 | # may need to sub-sample training seqs 23 | if 0. < ratio < 1.: 24 | num = round(len(seq_dirs) * ratio) 25 | seq_dirs = subsample_list(seq_dirs, num) 26 | assert len(seq_dirs) == num 27 | print(f'Using {ratio*100}% of data --> {len(seq_dirs)} sequences') 28 | return seq_dirs 29 | -------------------------------------------------------------------------------- /utils/timers.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import time 3 | from functools import wraps 4 | 5 | import numpy as np 6 | import torch 7 | 8 | cuda_timers = {} 9 | timers = {} 10 | 11 | 12 | class CudaTimer: 13 | def __init__(self, device: torch.device, timer_name: str): 14 | assert isinstance(device, torch.device) 15 | assert isinstance(timer_name, str) 16 | self.timer_name = timer_name 17 | if self.timer_name not in cuda_timers: 18 | cuda_timers[self.timer_name] = [] 19 | 20 | self.device = device 21 | self.start = None 22 | self.end = None 23 | 24 | def __enter__(self): 25 | torch.cuda.synchronize(device=self.device) 26 | self.start = time.time() 27 | return self 28 | 29 | def __exit__(self, *args): 30 | assert self.start is not None 31 | torch.cuda.synchronize(device=self.device) 32 | end = time.time() 33 | cuda_timers[self.timer_name].append(end - self.start) 34 | 35 | 36 | def cuda_timer_decorator(device: torch.device, timer_name: str): 37 | def decorator(func): 38 | @wraps(func) 39 | def wrapper(*args, **kwargs): 40 | with CudaTimer(device=device, timer_name=timer_name): 41 | out = func(*args, **kwargs) 42 | return out 43 | 44 | return wrapper 45 | 46 | return decorator 47 | 48 | 49 | class TimerDummy: 50 | def __init__(self, *args, **kwargs): 51 | pass 52 | 53 | def __enter__(self): 54 | pass 55 | 56 | def __exit__(self, *args): 57 | pass 58 | 59 | 60 | class Timer: 61 | def __init__(self, timer_name=''): 62 | self.timer_name = timer_name 63 | if self.timer_name not in timers: 64 | timers[self.timer_name] = [] 65 | 66 | def __enter__(self): 67 | self.start = time.time() 68 | return self 69 | 70 | def __exit__(self, *args): 71 | end = time.time() 72 | time_diff_s = end - self.start # measured in seconds 73 | 
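# Typical usage (sketch; the timer name and the model call are illustrative only):
#   with Timer('forward'):
#       out = model(x)
# Each elapsed time lands in the module-level `timers` dict (appended just below) and is
# summarised by print_timing_info() at interpreter exit via the atexit hook at the bottom
# of this file.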
timers[self.timer_name].append(time_diff_s) 74 | 75 | 76 | def print_timing_info(): 77 | print('== Timing statistics ==') 78 | skip_warmup = 10 79 | for timer_name, timing_values in [*cuda_timers.items(), *timers.items()]: 80 | if len(timing_values) <= skip_warmup: 81 | continue 82 | values = timing_values[skip_warmup:] 83 | timing_value_s_mean = np.mean(np.array(values)) 84 | timing_value_s_median = np.median(np.array(values)) 85 | timing_value_ms_mean = timing_value_s_mean * 1000 86 | timing_value_ms_median = timing_value_s_median * 1000 87 | if timing_value_ms_mean > 1000: 88 | print('{}: mean={:.2f} s, median={:.2f} s'.format(timer_name, timing_value_s_mean, timing_value_s_median)) 89 | else: 90 | print( 91 | '{}: mean={:.2f} ms, median={:.2f} ms'.format(timer_name, timing_value_ms_mean, timing_value_ms_median)) 92 | 93 | 94 | # this will print all the timer values upon termination of any program that imported this file 95 | atexit.register(print_timing_info) 96 | -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 4 | os.environ["OMP_NUM_THREADS"] = "1" 5 | os.environ["OPENBLAS_NUM_THREADS"] = "1" 6 | os.environ["MKL_NUM_THREADS"] = "1" 7 | os.environ["VECLIB_MAXIMUM_THREADS"] = "1" 8 | os.environ["NUMEXPR_NUM_THREADS"] = "1" 9 | import hdf5plugin # resolve a weird h5py error 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | 13 | import torch 14 | from torch.backends import cuda, cudnn 15 | 16 | cuda.matmul.allow_tf32 = True 17 | cudnn.allow_tf32 = True 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | 20 | import hydra 21 | from omegaconf import DictConfig, OmegaConf 22 | import pytorch_lightning as pl 23 | from pytorch_lightning.loggers import CSVLogger 24 | from pytorch_lightning.callbacks import ModelSummary 25 | 26 | from config.modifier import dynamically_modify_train_config 27 | from modules.utils.fetch import fetch_data_module, fetch_model_module 28 | 29 | 30 | @hydra.main(config_path='config', config_name='val', version_base='1.2') 31 | def main(config: DictConfig): 32 | dynamically_modify_train_config(config) 33 | # Just to check whether config can be resolved 34 | OmegaConf.to_container(config, resolve=True, throw_on_missing=True) 35 | 36 | print('------ Configuration ------') 37 | # print(OmegaConf.to_yaml(config)) 38 | _ = OmegaConf.to_yaml(config) 39 | print('---------------------------') 40 | 41 | # --------------------- 42 | # GPU options 43 | # --------------------- 44 | gpus = config.hardware.gpus 45 | assert isinstance(gpus, int), 'no more than 1 GPU supported' 46 | gpus = [gpus] 47 | 48 | # --------------------- 49 | # Data 50 | # --------------------- 51 | if 'T4' in torch.cuda.get_device_name() and \ 52 | config.tta.enable and config.tta.hflip: 53 | if config.dataset.name == 'gen1': 54 | config.batch_size.eval = 12 # to avoid OOM on T4 GPU 55 | else: 56 | config.batch_size.eval = 6 57 | if config.reverse: 58 | config.dataset.reverse_event_order = True 59 | print('Testing on event sequences with reversed temporal order.') 60 | data_module = fetch_data_module(config=config) 61 | 62 | # --------------------- 63 | # Logging 64 | # --------------------- 65 | logger = CSVLogger(save_dir='./validation_logs') 66 | 67 | # --------------------- 68 | # Model 69 | # --------------------- 70 | module = fetch_model_module(config=config).eval() 71 | 
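# fetch_model_module (modules/utils/fetch.py above) returns the recurrent detector, its
# TTA wrapper, or the pseudo-labeler depending on config.model.name and config.tta.enable;
# load_weight below then restores the weights referenced by the `checkpoint` config entry.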
module.load_weight(config.checkpoint) 72 | 73 | # --------------------- 74 | # Callbacks and Misc 75 | # --------------------- 76 | callbacks = [ModelSummary(max_depth=2)] 77 | 78 | # --------------------- 79 | # Validation 80 | # --------------------- 81 | 82 | trainer = pl.Trainer( 83 | accelerator='gpu', 84 | callbacks=callbacks, 85 | default_root_dir=None, 86 | devices=gpus, 87 | logger=logger, 88 | log_every_n_steps=100, 89 | precision=config.training.precision, 90 | move_metrics_to_cpu=False, 91 | ) 92 | with torch.inference_mode(): 93 | trainer.test(model=module, datamodule=data_module) 94 | print(f'Evaluating {config.checkpoint=} finished.') 95 | print(f'Conf_thresh: {config.model.postprocess.confidence_threshold}') 96 | 97 | 98 | if __name__ == '__main__': 99 | main() 100 | --------------------------------------------------------------------------------