├── .gitmodules
├── README.md
├── configs
│   ├── base_r50_4x_clip.yaml
│   ├── lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.yaml
│   ├── lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.yaml
│   ├── lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.yaml
│   ├── lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.yaml
│   ├── lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.yaml
│   ├── lvis-base_r50_4x_clip_gpt3_descriptions.yaml
│   ├── lvis-base_r50_4x_clip_image_exemplars_agg.yaml
│   ├── lvis-base_r50_4x_clip_image_exemplars_avg.yaml
│   ├── lvis-base_r50_4x_clip_multi_modal_agg.yaml
│   └── lvis-base_r50_4x_clip_multi_modal_avg.yaml
├── datasets
│   ├── README.md
│   └── metadata
│       ├── lvis_gpt3_text-davinci-002_descriptions_author.json
│       ├── lvis_gpt3_text-davinci-002_features_author.npy
│       ├── lvis_image_exemplar_dict_K-005_author.json
│       ├── lvis_image_exemplar_features_agg_K-005_author.npy
│       ├── lvis_image_exemplar_features_avg_K-005_author.npy
│       ├── lvis_multi-modal_agg_K-005_author.npy
│       ├── lvis_multi-modal_avg_K-005_author.npy
│       ├── lvis_v1_clip_a+cname.npy
│       └── lvis_v1_train_cat_info.json
├── docs
│   ├── INSTALL.md
│   ├── MODEL_ZOO.md
│   └── teaser.jpg
├── mmovod
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── config.cpython-38.pyc
│   │   └── custom_solver.cpython-38.pyc
│   ├── config.py
│   ├── custom_solver.py
│   ├── data
│   │   ├── __pycache__
│   │   │   ├── custom_build_augmentation.cpython-38.pyc
│   │   │   ├── custom_dataset_dataloader.cpython-38.pyc
│   │   │   ├── custom_dataset_mapper.cpython-38.pyc
│   │   │   └── tar_dataset.cpython-38.pyc
│   │   ├── custom_build_augmentation.py
│   │   ├── custom_dataset_dataloader.py
│   │   ├── custom_dataset_mapper.py
│   │   ├── datasets
│   │   │   ├── __pycache__
│   │   │   │   ├── cc.cpython-38.pyc
│   │   │   │   ├── coco_zeroshot.cpython-38.pyc
│   │   │   │   ├── imagenet.cpython-38.pyc
│   │   │   │   ├── lvis_22k_categories.cpython-38.pyc
│   │   │   │   ├── lvis_v1.cpython-38.pyc
│   │   │   │   ├── objects365.cpython-38.pyc
│   │   │   │   ├── oid.cpython-38.pyc
│   │   │   │   └── register_oid.cpython-38.pyc
│   │   │   ├── imagenet.py
│   │   │   └── lvis_v1.py
│   │   └── transforms
│   │       ├── __pycache__
│   │       │   ├── custom_augmentation_impl.cpython-38.pyc
│   │       │   └── custom_transform.cpython-38.pyc
│   │       ├── custom_augmentation_impl.py
│   │       └── custom_transform.py
│   ├── modeling
│   │   ├── __pycache__
│   │   │   ├── debug.cpython-38.pyc
│   │   │   └── utils.cpython-38.pyc
│   │   ├── backbone
│   │   │   ├── __pycache__
│   │   │   │   ├── swintransformer.cpython-38.pyc
│   │   │   │   └── timm.cpython-38.pyc
│   │   │   ├── swintransformer.py
│   │   │   └── timm.py
│   │   ├── debug.py
│   │   ├── meta_arch
│   │   │   ├── __pycache__
│   │   │   │   ├── custom_rcnn.cpython-38.pyc
│   │   │   │   └── d2_deformable_detr.cpython-38.pyc
│   │   │   └── custom_rcnn.py
│   │   ├── roi_heads
│   │   │   ├── __pycache__
│   │   │   │   ├── detic_fast_rcnn.cpython-38.pyc
│   │   │   │   ├── detic_roi_heads.cpython-38.pyc
│   │   │   │   ├── res5_roi_heads.cpython-38.pyc
│   │   │   │   └── zero_shot_classifier.cpython-38.pyc
│   │   │   ├── detic_fast_rcnn.py
│   │   │   ├── detic_roi_heads.py
│   │   │   ├── res5_roi_heads.py
│   │   │   └── zero_shot_classifier.py
│   │   ├── text
│   │   │   ├── __pycache__
│   │   │   │   └── text_encoder.cpython-38.pyc
│   │   │   └── text_encoder.py
│   │   └── utils.py
│   └── predictor.py
├── requirements.txt
├── tools
│   ├── collate_exemplar_dict.py
│   ├── convert-thirdparty-pretrained-model-to-d2.py
│   ├── create_imagenetlvis_json.py
│   ├── dump_clip_features.py
│   ├── dump_clip_features_lvis_sentences.py
│   ├── generate_descriptions.py
│   ├── generate_vision_classifier_agg.py
│   ├── get_exemplars_tta.py
│   ├── get_lvis_cat_info.py
│   ├── norm_feat_sum_norm.py
│   ├── remove_lvis_rare.py
│   ├── sample_exemplars.py
│   └── unzip_imagenet_lvis.py
└── train_net_auto.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/CenterNet2"]
2 | path = third_party/CenterNet2
3 | url = https://github.com/xingyizhou/CenterNet2.git
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Modal Classifiers for Open-Vocabulary Object Detection
2 |
3 |
4 |
5 | > [**Multi-Modal Classifiers for Open-Vocabulary Object Detection**](https://arxiv.org/abs/2306.05493),
6 | > Prannay Kaul, Weidi Xie, Andrew Zisserman
7 | > *ICML 2023 ([arXiv 2306.05493](https://arxiv.org/abs/2306.05493))*
8 |
9 |
10 | ## Updates
11 |
12 | - **June 2023** Code and checkpoints for LVIS models in the main paper are released. Training code for the visual aggregator will follow soon.
13 |
14 |
15 | ## Installation
16 |
17 | See [installation instructions](docs/INSTALL.md).
18 |
19 |
20 | ## Benchmark evaluation and training
21 |
22 | Please first [prepare datasets](datasets/README.md), then check our [MODEL ZOO](docs/MODEL_ZOO.md) to reproduce results in our paper.
23 |
24 |
25 | ## License
26 |
27 | See [Detic](https://github.com/facebookresearch/Detic). Our code is based on this repository.
28 |
29 | ## Citation
30 |
31 | If you find this project useful for your research, please use the following BibTeX entry.
32 |
33 | @inproceedings{Kaul2023,
34 | title={Multi-Modal Classifiers for Open-Vocabulary Object Detection},
35 | author={Kaul, Prannay and Xie, Weidi and Zisserman, Andrew},
36 | booktitle={ICML},
37 | year={2023}
38 | }
39 |
--------------------------------------------------------------------------------
/configs/base_r50_4x_clip.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | META_ARCHITECTURE: "CustomRCNN"
3 | MASK_ON: True
4 | PROPOSAL_GENERATOR:
5 | NAME: "CenterNet"
6 | WEIGHTS: "checkpoints/resnet50_miil_21k.pkl"
7 | BACKBONE:
8 | NAME: build_p67_timm_fpn_backbone
9 | TIMM:
10 | BASE_NAME: resnet50_in21k
11 | FPN:
12 | IN_FEATURES: ["layer3", "layer4", "layer5"]
13 | PIXEL_MEAN: [123.675, 116.280, 103.530]
14 | PIXEL_STD: [58.395, 57.12, 57.375]
15 | ROI_HEADS:
16 | NAME: DeticCascadeROIHeads
17 | IN_FEATURES: ["p3", "p4", "p5"]
18 | IOU_THRESHOLDS: [0.6]
19 | NUM_CLASSES: 1203
20 | SCORE_THRESH_TEST: 0.02
21 | NMS_THRESH_TEST: 0.5
22 | ROI_BOX_CASCADE_HEAD:
23 | IOUS: [0.6, 0.7, 0.8]
24 | ROI_BOX_HEAD:
25 | NAME: "FastRCNNConvFCHead"
26 | NUM_FC: 2
27 | POOLER_RESOLUTION: 7
28 | CLS_AGNOSTIC_BBOX_REG: True
29 | MULT_PROPOSAL_SCORE: True
30 | USE_SIGMOID_CE: True
31 | USE_FED_LOSS: True
32 | ROI_MASK_HEAD:
33 | NAME: "MaskRCNNConvUpsampleHead"
34 | NUM_CONV: 4
35 | POOLER_RESOLUTION: 14
36 | CLS_AGNOSTIC_MASK: True
37 | CENTERNET:
38 | NUM_CLASSES: 1203
39 | REG_WEIGHT: 1.
40 | NOT_NORM_REG: True
41 | ONLY_PROPOSAL: True
42 | WITH_AGN_HM: True
43 | INFERENCE_TH: 0.0001
44 | PRE_NMS_TOPK_TRAIN: 4000
45 | POST_NMS_TOPK_TRAIN: 2000
46 | PRE_NMS_TOPK_TEST: 1000
47 | POST_NMS_TOPK_TEST: 256
48 | NMS_TH_TRAIN: 0.9
49 | NMS_TH_TEST: 0.9
50 | POS_WEIGHT: 0.5
51 | NEG_WEIGHT: 0.5
52 | IGNORE_HIGH_FP: 0.85
53 | DATASETS:
54 | TRAIN: ("lvis_v1_train",)
55 | TEST: ("lvis_v1_val",)
56 | DATALOADER:
57 | SAMPLER_TRAIN: "RepeatFactorTrainingSampler"
58 | REPEAT_THRESHOLD: 0.001
59 | NUM_WORKERS: 8
60 | TEST:
61 | DETECTIONS_PER_IMAGE: 300
62 | SOLVER:
63 | LR_SCHEDULER_NAME: "WarmupCosineLR"
64 | CHECKPOINT_PERIOD: 30000
65 | WARMUP_ITERS: 10000
66 | WARMUP_FACTOR: 0.0001
67 | USE_CUSTOM_SOLVER: True
68 | OPTIMIZER: "ADAMW"
69 | MAX_ITER: 90000
70 | IMS_PER_BATCH: 64
71 | BASE_LR: 0.0002
72 | CLIP_GRADIENTS:
73 | ENABLED: True
74 | INPUT:
75 | FORMAT: RGB
76 | CUSTOM_AUG: EfficientDetResizeCrop
77 | TRAIN_SIZE: 640
78 | OUTPUT_DIR: "./output/mm-ovod/auto"
79 | EVAL_PROPOSAL_AR: False
80 | VERSION: 2
81 | FP16: True
82 |
--------------------------------------------------------------------------------
/configs/lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | IMAGE_LABEL_LOSS: 'max_size'
6 | USE_BIAS: -2.0
7 | ZEROSHOT_WEIGHT_PATH: "datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy"
8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_gpt3_descriptions/model_final.pth"
9 | SOLVER:
10 | MAX_ITER: 90000
11 | IMS_PER_BATCH: 64
12 | BASE_LR: 0.0002
13 | WARMUP_ITERS: 1000
14 | WARMUP_FACTOR: 0.001
15 | DATASETS:
16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
17 | DATALOADER:
18 | SAMPLER_TRAIN: "MultiDatasetSampler"
19 | DATASET_RATIO: [1, 4]
20 | USE_DIFF_BS_SIZE: True
21 | DATASET_BS: [16, 64]
22 | DATASET_INPUT_SIZE: [640, 320]
23 | USE_RFS: [True, False]
24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
25 | FILTER_EMPTY_ANNOTATIONS: False
26 | MULTI_DATASET_GROUPING: True
27 | DATASET_ANN: ['box', 'image']
28 | NUM_WORKERS: 8
29 | WITH_IMAGE_LABELS: True
30 |
--------------------------------------------------------------------------------
/configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | IMAGE_LABEL_LOSS: 'max_size'
6 | USE_BIAS: -2.0
7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy'
8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_agg/model_final.pth"
9 | SOLVER:
10 | MAX_ITER: 90000
11 | IMS_PER_BATCH: 64
12 | BASE_LR: 0.0002
13 | WARMUP_ITERS: 1000
14 | WARMUP_FACTOR: 0.001
15 | DATASETS:
16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
17 | DATALOADER:
18 | SAMPLER_TRAIN: "MultiDatasetSampler"
19 | DATASET_RATIO: [1, 4]
20 | USE_DIFF_BS_SIZE: True
21 | DATASET_BS: [16, 64]
22 | DATASET_INPUT_SIZE: [640, 320]
23 | USE_RFS: [True, False]
24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
25 | FILTER_EMPTY_ANNOTATIONS: False
26 | MULTI_DATASET_GROUPING: True
27 | DATASET_ANN: ['box', 'image']
28 | NUM_WORKERS: 8
29 | WITH_IMAGE_LABELS: True
30 |
--------------------------------------------------------------------------------
/configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | IMAGE_LABEL_LOSS: 'max_size'
6 | USE_BIAS: -2.0
7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy'
8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_avg/model_final.pth"
9 | SOLVER:
10 | MAX_ITER: 90000
11 | IMS_PER_BATCH: 64
12 | BASE_LR: 0.0002
13 | WARMUP_ITERS: 1000
14 | WARMUP_FACTOR: 0.001
15 | DATASETS:
16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
17 | DATALOADER:
18 | SAMPLER_TRAIN: "MultiDatasetSampler"
19 | DATASET_RATIO: [1, 4]
20 | USE_DIFF_BS_SIZE: True
21 | DATASET_BS: [16, 64]
22 | DATASET_INPUT_SIZE: [640, 320]
23 | USE_RFS: [True, False]
24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
25 | FILTER_EMPTY_ANNOTATIONS: False
26 | MULTI_DATASET_GROUPING: True
27 | DATASET_ANN: ['box', 'image']
28 | NUM_WORKERS: 8
29 | WITH_IMAGE_LABELS: True
30 |
--------------------------------------------------------------------------------
/configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | IMAGE_LABEL_LOSS: 'max_size'
6 | USE_BIAS: -2.0
7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_agg_K-005_author.npy'
8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_multi_modal_agg/model_final.pth"
9 | SOLVER:
10 | MAX_ITER: 90000
11 | IMS_PER_BATCH: 64
12 | BASE_LR: 0.0002
13 | WARMUP_ITERS: 1000
14 | WARMUP_FACTOR: 0.001
15 | DATASETS:
16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
17 | DATALOADER:
18 | SAMPLER_TRAIN: "MultiDatasetSampler"
19 | DATASET_RATIO: [1, 4]
20 | USE_DIFF_BS_SIZE: True
21 | DATASET_BS: [16, 64]
22 | DATASET_INPUT_SIZE: [640, 320]
23 | USE_RFS: [True, False]
24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
25 | FILTER_EMPTY_ANNOTATIONS: False
26 | MULTI_DATASET_GROUPING: True
27 | DATASET_ANN: ['box', 'image']
28 | NUM_WORKERS: 8
29 | WITH_IMAGE_LABELS: True
30 |
--------------------------------------------------------------------------------
/configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | IMAGE_LABEL_LOSS: 'max_size'
6 | USE_BIAS: -2.0
7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_avg_K-005_author.npy'
8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_multi_modal_avg/model_final.pth"
9 | SOLVER:
10 | MAX_ITER: 90000
11 | IMS_PER_BATCH: 64
12 | BASE_LR: 0.0002
13 | WARMUP_ITERS: 1000
14 | WARMUP_FACTOR: 0.001
15 | DATASETS:
16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1")
17 | DATALOADER:
18 | SAMPLER_TRAIN: "MultiDatasetSampler"
19 | DATASET_RATIO: [1, 4]
20 | USE_DIFF_BS_SIZE: True
21 | DATASET_BS: [16, 64]
22 | DATASET_INPUT_SIZE: [640, 320]
23 | USE_RFS: [True, False]
24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]]
25 | FILTER_EMPTY_ANNOTATIONS: False
26 | MULTI_DATASET_GROUPING: True
27 | DATASET_ANN: ['box', 'image']
28 | NUM_WORKERS: 8
29 | WITH_IMAGE_LABELS: True
30 |
--------------------------------------------------------------------------------
/configs/lvis-base_r50_4x_clip_gpt3_descriptions.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | ZEROSHOT_WEIGHT_PATH: "datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy"
6 | USE_BIAS: -2.0
7 | DATASETS:
8 | TRAIN: ("lvis_v1_train_norare",)
9 |
--------------------------------------------------------------------------------
/configs/lvis-base_r50_4x_clip_image_exemplars_agg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy'
6 | USE_BIAS: -2.0
7 | DATASETS:
8 | TRAIN: ("lvis_v1_train_norare",)
9 |
--------------------------------------------------------------------------------
/configs/lvis-base_r50_4x_clip_image_exemplars_avg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy'
6 | USE_BIAS: -2.0
7 | DATASETS:
8 | TRAIN: ("lvis_v1_train_norare",)
9 |
--------------------------------------------------------------------------------
/configs/lvis-base_r50_4x_clip_multi_modal_agg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_agg_K-005_author.npy'
6 | USE_BIAS: -2.0
7 | DATASETS:
8 | TRAIN: ("lvis_v1_train_norare",)
9 |
--------------------------------------------------------------------------------
/configs/lvis-base_r50_4x_clip_multi_modal_avg.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "base_r50_4x_clip.yaml"
2 | MODEL:
3 | ROI_BOX_HEAD:
4 | USE_ZEROSHOT_CLS: True
5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_avg_K-005_author.npy'
6 | USE_BIAS: -2.0
7 | DATASETS:
8 | TRAIN: ("lvis_v1_train_norare",)
9 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Prepare datasets for MM-OVOD (adapted from Detic)
2 |
3 | The basic training of our model uses [LVIS](https://www.lvisdataset.org/) (which uses [COCO](https://cocodataset.org/) images) and [ImageNet-21K](https://www.image-net.org/download.php).
4 |
5 | Before starting any processing, please download the (selected) datasets from the official websites and place or sym-link them under `${mm-ovod_ROOT}/datasets/` as shown below.
6 |
7 | ```
8 | ${mm-ovod_ROOT}/datasets/
9 | metadata/
10 | lvis/
11 | coco/
12 | imagenet/
13 | VisualGenome/
14 | ```
15 | `metadata/` is our preprocessed meta-data (included in the repo). See the [Metadata](#Metadata) section below for details.
16 | Please follow the instructions below to pre-process the individual datasets.
17 |
18 | ### COCO and LVIS
19 |
20 | First, download the COCO images and LVIS annotations, and place them in the following way:
21 |
22 | ```
23 | lvis/
24 | lvis_v1_train.json
25 | lvis_v1_val.json
26 | coco/
27 | train2017/
28 | val2017/
29 | ```
30 |
31 | Next, prepare the open-vocabulary LVIS training set using
32 |
33 | ```
34 | python tools/remove_lvis_rare.py --ann datasets/lvis/lvis_v1_train.json
35 | ```
36 |
37 | This will generate `datasets/lvis/lvis_v1_train_norare.json`.
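Conceptually, the script drops annotations belonging to LVIS "rare" categories so that only base-class boxes remain; a minimal sketch of that idea (assuming the standard LVIS v1 json layout, not the actual `tools/remove_lvis_rare.py`) is:
```python
# Sketch only: remove annotations of LVIS "rare" categories to obtain the
# open-vocabulary (base-class) training set. The real script may also keep
# track of extra fields; everything here is illustrative.
import json

with open("datasets/lvis/lvis_v1_train.json") as f:
    data = json.load(f)

rare_ids = {c["id"] for c in data["categories"] if c["frequency"] == "r"}
data["annotations"] = [
    a for a in data["annotations"] if a["category_id"] not in rare_ids
]

with open("datasets/lvis/lvis_v1_train_norare.json", "w") as f:
    json.dump(data, f)
```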
38 |
39 | ### ImageNet-21K
40 |
41 | The imagenet folder should look like the structure below after running the data-processing
42 | [script](https://github.com/Alibaba-MIIL/ImageNet21K/blob/main/dataset_preprocessing/processing_script.sh) from ImageNet-21K Pretraining for the Masses (be sure to use the Fall 2011 version).
43 | After the script has run, please rename the folders to match the structure below:
44 | ```
45 | imagenet/
46 | imagenet21k_P/
47 | train/
48 | n00005787/
49 | n00005787_*.JPEG
50 | n00006484/
51 | n00006484_*.JPEG
52 | ...
53 | val/
54 | n00005787/
55 | n00005787_*.JPEG
56 | n00006484/
57 | n00006484_*.JPEG
58 | ...
59 | imagenet21k_small_classes/
60 | n00004475/
61 | n00004475_*.JPEG
62 | n00006024/
63 | n00006024_*.JPEG
64 | ...
65 | ```
66 |
67 | The subset of ImageNet that overlaps with LVIS (IN-L in the paper) will be created from this directory
68 | structure.
69 |
70 | ~~~
71 | cd ${mm-ovod_ROOT}/datasets/
72 | mkdir imagenet/annotations
73 | python tools/create_imagenetlvis_json.py --imagenet-path datasets/imagenet/imagenet21k_P --out-path datasets/imagenet/annotations/imagenet_lvis_image_info.json
74 | ~~~
75 | This creates `datasets/imagenet/annotations/imagenet_lvis_image_info.json`.
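The overlap with LVIS is found through WordNet: each LVIS category carries a `synset` field, and its WordNet offset names the corresponding ImageNet-21K class folder. A rough sketch of that mapping (assuming `nltk` with the WordNet corpus installed; this is not the actual `tools/create_imagenetlvis_json.py`) is:
```python
# Sketch only: map LVIS category synsets (e.g. "stop_sign.n.01") to the
# "nXXXXXXXX"-style WordNet ids that name ImageNet-21K class folders.
import json
from nltk.corpus import wordnet as wn  # requires nltk.download("wordnet")

with open("datasets/lvis/lvis_v1_val.json") as f:
    categories = json.load(f)["categories"]

lvis_wnids = {}
for cat in categories:
    try:
        offset = wn.synset(cat["synset"]).offset()
    except Exception:
        continue  # a few LVIS synsets may not resolve in the local WordNet
    lvis_wnids[cat["id"]] = "n{:08d}".format(offset)

print(len(lvis_wnids), "LVIS categories mapped to ImageNet-style wnids")
```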
76 |
77 |
78 | ### VisualGenome
79 |
80 | Some of our image exemplars are sourced from VisualGenome, so download the dataset and ensure the following
81 | files are present in the structure below:
82 | ```
83 | VisualGenome/
84 | VG_100K/
85 | *.jpg
86 | VG_100K_2/
87 | *.jpg
88 | objects.json
89 | image_data.json
90 | ```
91 |
92 | ### Metadata
93 |
94 | ```
95 | metadata/
96 | lvis_v1_train_cat_info.json
97 | lvis_gpt3_text-davinci-002_descriptions_author.json
98 | lvis_gpt3_text-davinci-002_features_author.npy
99 | lvis_image_exemplar_dict_K-005_author.json
100 | lvis_image_exemplar_features_agg_K-005_author.npy
101 | lvis_image_exemplar_features_avg_K-005_author.npy
102 | lvis_multi-modal_agg_K-005_author.npy
103 | lvis_multi-modal_avg_K-005_author.npy
104 | lvis_v1_clip_a+cname.npy
105 | ```
106 |
107 | `lvis_v1_train_cat_info.json` is used by the Federated loss.
108 | This is created by
109 | ~~~
110 | python tools/get_lvis_cat_info.py --ann datasets/lvis/lvis_v1_train.json
111 | ~~~
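The file simply records how often each category occurs in the training annotations, which the federated loss uses when sampling negative categories. A minimal sketch of that counting (field names are illustrative and may differ from what `tools/get_lvis_cat_info.py` writes) is:
```python
# Sketch only: per-category image and instance counts for the federated loss.
import json
from collections import defaultdict

with open("datasets/lvis/lvis_v1_train.json") as f:
    data = json.load(f)

images_per_cat = defaultdict(set)
instances_per_cat = defaultdict(int)
for ann in data["annotations"]:
    images_per_cat[ann["category_id"]].add(ann["image_id"])
    instances_per_cat[ann["category_id"]] += 1

cat_info = [
    {
        "id": cat["id"],
        "name": cat["name"],
        "image_count": len(images_per_cat[cat["id"]]),
        "instance_count": instances_per_cat[cat["id"]],
    }
    for cat in data["categories"]
]
with open("datasets/metadata/lvis_v1_train_cat_info.json", "w") as f:
    json.dump(cat_info, f)
```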
112 |
113 | `lvis_gpt3_text-davinci-002_descriptions_author.json` contains the descriptions for each LVIS class
114 | we obtained using the (now deprecated) text-davinci-002 model from OpenAI.
115 |
116 | Users may create their own descriptions by:
117 | ~~~
118 | python tools/generate_descriptions.py --ann-path datasets/lvis/lvis_v1_val.json --openai-model text-davinci-003
119 | ~~~
120 | which will create a file called `lvis_gpt3_text-davinci-003_descriptions_own.json`.
121 | Be sure to include your own OpenAI API key at the top of `tools/generate_descriptions.py`.
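For a sense of what the script does, here is a rough sketch of querying a GPT-3 completion model for per-class descriptions; the prompt wording, sampling options, and the legacy `openai<1.0` Completions API shown here are assumptions, not the exact code in `tools/generate_descriptions.py`.
```python
# Sketch only: ask a GPT-3 completion model for short descriptions of a class.
import openai  # assumes the legacy openai<1.0 package

openai.api_key = "YOUR_OPENAI_API_KEY"  # placeholder

def describe(category_name, n=3):
    prompt = f"Describe what a {category_name} looks like:"
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=prompt,
        max_tokens=60,
        temperature=0.7,
        n=n,
    )
    return [choice["text"].strip() for choice in response["choices"]]

print(describe("lawn mower"))
```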
122 |
123 | `lvis_gpt3_text-davinci-002_features_author.npy` contains the CLIP embeddings for each class in the LVIS
124 | dataset, computed from the descriptions we generated using GPT-3. They are created by:
125 | ~~~
126 | python tools/dump_clip_features_lvis_sentences.py --descriptions-path datasets/metadata/lvis_gpt3_text-davinci-002_descriptions_author.json --ann-path datasets/lvis/lvis_v1_val.json --out-path datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy
127 | ~~~
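The core operation is to encode each class's descriptions with the CLIP text encoder and average them into a single classifier vector per class. A minimal sketch (assuming the `clip` package with the ViT-B/32 model and a json mapping class names to lists of descriptions; the real script also has to respect LVIS category ordering) is:
```python
# Sketch only: one L2-normalised CLIP text embedding per class, averaged over
# that class's GPT-3 descriptions. Output path and ordering are illustrative.
import json
import clip
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)

with open("datasets/metadata/lvis_gpt3_text-davinci-002_descriptions_author.json") as f:
    descriptions = json.load(f)  # assumed: {class_name: [description, ...]}

classifiers = []
with torch.no_grad():
    for name, sentences in descriptions.items():
        tokens = clip.tokenize(sentences, truncate=True).to(device)
        feats = model.encode_text(tokens).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)  # normalise each sentence
        mean = feats.mean(dim=0)
        classifiers.append((mean / mean.norm()).cpu().numpy())

np.save("lvis_descriptions_clip_features_own.npy", np.stack(classifiers))
```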
128 |
129 | `lvis_image_exemplar_dict_K-005_author.json` is the dictionary of image exemplars for each LVIS class
130 | used in the paper to produce our results.
131 | One can create their own as follows:
132 | ~~~
133 | python tools/sample_exemplars.py --lvis-ann-path datasets/lvis/lvis_v1_val.json --exemplar-dict-path datasets/metadata/exemplar_dict.json -K 5 --out-path datasets/metadata/lvis_image_exemplar_dict_K-005_own.json
134 | ~~~
135 |
136 | `lvis_image_exemplar_features_agg_K-005_author.npy` contains the CLIP embeddings for each class in the LVIS dataset
137 | when using image exemplars AND our trained visual aggregator for combining multiple exemplars.
138 | One can create their own using our trained visual aggregator as follows (see [INSTALL.md](../docs/INSTALL.md) for downloading the
139 | visual aggregator weights):
140 | ~~~
141 | python tools/generate_vision_classifier_agg.py --exemplar-list datasets/metadata/lvis_image_exemplar_dict_K-005_own.json --out-path datasets/metadata/lvis_image_exemplar_features_agg_K-005_own.npy --num-augs 5 --tta --load-path checkpoints/visual_aggregator/visual_aggregator_ckpt_4_transformer.pth
142 | ~~~
143 |
144 | `lvis_image_exemplar_features_avg_K-005_author.npy` contains the CLIP embeddings for each class in the LVIS dataset
145 | when using image exemplars AND averaging the CLIP embeddings of multiple exemplars (not using our trained
146 | aggregator).
147 | One can create their own for example:
148 | ~~~
149 | python tools/get_exemplars_tta.py --ann-path datasets/metadata/lvis_image_exemplar_dict_K-005_own.json --output-path datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy --num-augs 5
150 | ~~~
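The averaging variant is the image-side analogue of the text case above: encode each exemplar crop with the CLIP image encoder, L2-normalise, and average per class. A minimal sketch (assuming a list of exemplar image paths per class and ignoring the test-time augmentation controlled by `--num-augs`) is:
```python
# Sketch only: average the CLIP image embeddings of K exemplar crops of a class.
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

def class_vector(image_paths):
    imgs = torch.stack([preprocess(Image.open(p).convert("RGB")) for p in image_paths])
    with torch.no_grad():
        feats = model.encode_image(imgs.to(device)).float()
    feats = feats / feats.norm(dim=-1, keepdim=True)
    mean = feats.mean(dim=0)
    return (mean / mean.norm()).cpu()
```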
151 |
152 | `lvis_multi-modal_agg_K-005_author.npy` contains the CLIP embeddings for each class in the LVIS dataset
153 | when using image exemplars AND descriptions AND our trained visual aggregator for combining multiple exemplars.
154 | One can create their own for example:
155 | ~~~
156 | python tools/norm_feat_sum_norm.py --feat1-path datasets/metadata/lvis_gpt3_text-davinci-002_features_own.npy --feat2-path datasets/metadata/lvis_image_exemplar_features_agg_K-005_own.npy --out-path datasets/metadata/lvis_multi-modal_agg_K-005_own.npy
157 | ~~~
158 |
159 | `lvis_multi-modal_avg_K-005_author.npy` contains the CLIP embeddings for each class in the LVIS dataset
160 | when using image exemplars AND descriptions AND averaging the CLIP embeddings of multiple exemplars (not using our trained aggregator).
161 | One can create their own for example:
162 | ~~~
163 | python tools/norm_feat_sum_norm.py --feat1-path datasets/metadata/lvis_gpt3_text-davinci-002_features_own.npy --feat2-path datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy --out-path datasets/metadata/lvis_multi-modal_avg_K-005_own.npy
164 | ~~~
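As the tool name suggests, the multi-modal classifier is obtained by L2-normalising the text-based and vision-based classifiers, summing them, and re-normalising. A minimal sketch of that fusion (assuming both `.npy` files hold arrays of shape `(num_classes, 512)`; the real script may handle extra rows or options) is:
```python
# Sketch only: fuse text and vision classifiers by normalise -> sum -> normalise.
import numpy as np

text = np.load("datasets/metadata/lvis_gpt3_text-davinci-002_features_own.npy")
vision = np.load("datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy")

def l2norm(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

fused = l2norm(l2norm(text) + l2norm(vision))
np.save("datasets/metadata/lvis_multi-modal_avg_K-005_own.npy", fused)
```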
165 |
166 | `lvis_v1_clip_a+cname.npy` contains the pre-computed CLIP embeddings for each class in the LVIS dataset (from Detic).
167 | They are created by:
168 | ~~~
169 | python tools/dump_clip_features.py --ann datasets/lvis/lvis_v1_val.json --out_path metadata/lvis_v1_clip_a+cname.npy
170 | ~~~
171 |
172 | ### Collating Image Exemplars
173 |
174 | We provide the exact image exemplars used in our results (5 per LVIS category) in the metadata folder described
175 | above.
176 | However, if you wish to create your own, you first need to create a full dictionary of exemplars for each
177 | LVIS category.
178 | This is done by:
179 | ~~~
180 | python tools/collate_exemplar_dict.py --lvis-dir datasets/lvis --imagenet-dir datasets/imagenet --visual-genome-dir datasets/VisualGenome --output-path datasets/metadata/exemplar_dict.json
181 | ~~~
182 | This will create `datasets/metadata/exemplar_dict.json` which is a dictionary of exemplars with
183 | at least 10 exemplars per LVIS category.
--------------------------------------------------------------------------------
/datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy
--------------------------------------------------------------------------------
/datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy
--------------------------------------------------------------------------------
/datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy
--------------------------------------------------------------------------------
/datasets/metadata/lvis_multi-modal_agg_K-005_author.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_multi-modal_agg_K-005_author.npy
--------------------------------------------------------------------------------
/datasets/metadata/lvis_multi-modal_avg_K-005_author.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_multi-modal_avg_K-005_author.npy
--------------------------------------------------------------------------------
/datasets/metadata/lvis_v1_clip_a+cname.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_v1_clip_a+cname.npy
--------------------------------------------------------------------------------
/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | ### Requirements
4 | - Linux or macOS with Python ≥ 3.6
5 | - PyTorch ≥ 1.8.
6 | Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note: please check that your
7 | PyTorch version matches the one required by Detectron2.
8 | - Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html).
9 |
10 |
11 | ### Author conda environment setup
12 | ```bash
13 | conda create --name mm-ovod python=3.8 -y
14 | conda activate mm-ovod
15 | conda install pytorch torchvision=0.9.2 torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia
16 |
17 | # under your working directory
18 | git clone git@github.com:facebookresearch/detectron2.git
19 | cd detectron2
20 | git checkout 2b98c273b240b54d2d0ee6853dc331c4f2ca87b9
21 | pip install -e .
22 |
23 | cd ..
24 | git clone https://github.com/prannaykaul/mm-ovod.git --recurse-submodules
25 | cd mm-ovod
26 | pip install -r requirements.txt
27 | pip uninstall pillow
28 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd
29 | ```
30 |
31 | Our project (like Detic) uses a submodule: [CenterNet2](https://github.com/xingyizhou/CenterNet2.git). If you forget to add `--recurse-submodules`, do `git submodule init` and then `git submodule update`.
32 |
33 |
34 | ### Downloading pre-trained ResNet-50 backbone
35 | We use the ResNet-50 backbone pre-trained on ImageNet-21k-P from [here](
36 | https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth). Download it, place it in the `${mm-ovod_ROOT}/checkpoints` folder, and convert it for use with Detectron2 using the following commands:
37 | ```bash
38 | cd ${mm-ovod_ROOT}
39 | mkdir checkpoints
40 | cd checkpoints
41 | wget https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth
42 | python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path resnet50_miil_21k.pth
43 | ```
44 |
45 | ### Downloading pre-trained visual aggregator
46 | The pre-trained model for the visual aggregator is required if you want to use your own image exemplars to produce a vision-based
47 | classifier. The model can be downloaded from [here](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/visual_aggregator_ckpt_4_transformer.pth.tar) and should be placed in the `${mm-ovod_ROOT}/checkpoints` folder.
48 | ```bash
49 | cd ${mm-ovod_ROOT}
50 | mkdir checkpoints
51 | cd checkpoints
52 | wget https://robots.ox.ac.uk/~prannay/public_models/mm-ovod/visual_aggregator_ckpt_4_transformer.pth.tar
53 | tar -xf visual_aggregator_ckpt_4_transformer.pth.tar
54 | rm visual_aggregator_ckpt_4_transformer.pth.tar
55 | ```
56 |
--------------------------------------------------------------------------------
/docs/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # Multi-Modal Open-Vocabulary Object Detection Model Zoo
2 |
3 | ## Introduction
4 |
5 | This file documents a collection of models reported in our paper.
6 | Training in all cases is done with 4 32GB V100 GPUs.
7 |
8 | #### How to Read the Tables
9 |
10 | The "Name" column contains a link to the config file.
11 | To train a model, run
12 |
13 | ```
14 | python train_net_auto.py --num-gpus 4 --config-file /path/to/config/name.yaml
15 | ```
16 |
17 | To evaluate a model using trained or pre-trained weights, run
18 |
19 | ```
20 | python train_net_auto.py --num-gpus 4 --config-file /path/to/config/name.yaml --eval-only MODEL.WEIGHTS /path/to/weight.pth
21 | ```
22 |
23 |
24 | ## Open-vocabulary LVIS
25 |
26 | | Name | APr | mAP | Weights |
27 | |------------------------------------------------------------------------------------------------------------------------|:----:|:----:|------------------------------------------------------------------|
28 | | [lvis-base_r50_4x_clip_gpt3_descriptions](../configs/lvis-base_r50_4x_clip_gpt3_descriptions.yaml) | 19.3 | 30.3 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_gpt3_descriptions.pth.tar) |
29 | | [lvis-base_r50_4x_clip_image_exemplars_avg](../configs/lvis-base_r50_4x_clip_image_exemplars_avg.yaml) | 14.8 | 28.8 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_avg.pth.tar) |
30 | | [lvis-base_r50_4x_clip_image_exemplars_agg](../configs/lvis-base_r50_4x_clip_image_exemplars_agg.yaml) | 18.3 | 29.2 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_agg.pth.tar) |
31 | | [lvis-base_r50_4x_clip_multi_modal_avg](../configs/lvis-base_r50_4x_clip_multi_modal_avg.yaml) | 20.7 | 30.5 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_multi_modal_avg.pth.tar) |
32 | | [lvis-base_r50_4x_clip_multi_modal_agg](../configs/lvis-base_r50_4x_clip_multi_modal_agg.yaml) | 19.2 | 30.6 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_multi_modal_agg.pth.tar) |
33 | | [lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions](../configs/lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.yaml) | 25.8 | 32.6 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.pth.tar) |
34 | | [lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg](../configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.yaml) | 21.6 | 31.3 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.pth.tar) |
35 | | [lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg](../configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.yaml) | 23.8 | 31.3 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.pth.tar) |
36 | | [lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg](../configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.yaml) | 26.5 | 32.8 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.pth.tar) |
37 | | [lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg](../configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.yaml) | 27.3 | 33.1 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.pth.tar) |
38 |
39 |
40 | #### Note
41 |
42 | - The open-vocabulary LVIS setup is LVIS without rare class annotations in training. We evaluate rare classes as novel classes in testing.
43 |
44 | - All models use [CLIP](https://github.com/openai/CLIP) embeddings as classifiers. This makes the box-supervised models have non-zero mAP on novel classes.
45 |
46 | - The models with `in-l` use the overlap classes between ImageNet-21K and LVIS as image-labeled data.
47 |
48 | - The models trained with `in-l` require the corresponding models _without_ `in-l` (indicated by MODEL.WEIGHTS in the config files). Please train or download the model without `in-l` and place it under `${mm-ovod_ROOT}/output/..` before training the `in-l` model (check the config file).
49 |
50 |
--------------------------------------------------------------------------------
/docs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/docs/teaser.jpg
--------------------------------------------------------------------------------
/mmovod/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from .modeling.meta_arch import custom_rcnn
3 | from .modeling.roi_heads import detic_roi_heads
4 | from .modeling.roi_heads import res5_roi_heads
5 | from .modeling.backbone import swintransformer
6 | from .modeling.backbone import timm
7 |
8 |
9 | from .data.datasets import lvis_v1
10 | from .data.datasets import imagenet
11 | # from .data.datasets import objects365
12 |
--------------------------------------------------------------------------------
/mmovod/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/__pycache__/custom_solver.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/__pycache__/custom_solver.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | from detectron2.config import CfgNode as CN
3 |
4 |
5 | def add_mmovod_config(cfg):
6 | _C = cfg
7 |
8 | _C.WITH_IMAGE_LABELS = False # Turn on co-training with classification data
9 |
10 | # Open-vocabulary classifier
11 | _C.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS = False # Use fixed classifier for open-vocabulary detection
12 | _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'datasets/metadata/lvis_v1_clip_a+cname.npy'
13 | _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM = 512
14 | _C.MODEL.ROI_BOX_HEAD.NORM_WEIGHT = True
15 | _C.MODEL.ROI_BOX_HEAD.NORM_TEMP = 50.0
16 | _C.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS = False
17 | _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use
18 |
19 | _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False # CenterNet2
20 | _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False
21 | _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01
22 | _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False # Federated Loss
23 | _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = \
24 | 'datasets/metadata/lvis_v1_train_cat_info.json'
25 | _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50
26 | _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5
27 |
28 | # Classification data configs
29 | _C.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS = 'max_size' # max, softmax, sum
30 | _C.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT = 0.1
31 | _C.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE = 1.0
32 | _C.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX = False # Used for image-box loss and caption loss
33 | _C.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS = 128 # num proposals for image-labeled data
34 | _C.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP = False # Used for WSDDN
35 | _C.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT = 1.0 # Caption loss weight
36 | _C.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT = 0.125 # Caption loss hyper-parameter
37 | _C.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP = False # Used for WSDDN
38 | _C.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS = False # Used when USE_SIGMOID_CE is False
39 |
40 | _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0
41 | _C.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False # For demo only
42 |
43 | # Caption losses
44 | _C.MODEL.CAP_BATCH_RATIO = 4 # Ratio between detection data and caption data
45 | _C.MODEL.WITH_CAPTION = False
46 | _C.MODEL.SYNC_CAPTION_BATCH = False # synchronize across GPUs to enlarge # "classes"
47 |
48 | # dynamic class sampling when training with 21K classes
49 | _C.MODEL.DYNAMIC_CLASSIFIER = False
50 | _C.MODEL.NUM_SAMPLE_CATS = 50
51 |
52 | # Different classifiers in testing, used in cross-dataset evaluation
53 | _C.MODEL.RESET_CLS_TESTS = False
54 | _C.MODEL.TEST_CLASSIFIERS = []
55 | _C.MODEL.TEST_NUM_CLASSES = []
56 |
57 | # Backbones
58 | _C.MODEL.SWIN = CN()
59 | _C.MODEL.SWIN.SIZE = 'T' # 'T', 'S', 'B'
60 | _C.MODEL.SWIN.USE_CHECKPOINT = False
61 | _C.MODEL.SWIN.OUT_FEATURES = (1, 2, 3) # FPN stride 8 - 32
62 |
63 | _C.MODEL.TIMM = CN()
64 | _C.MODEL.TIMM.BASE_NAME = 'resnet50'
65 | _C.MODEL.TIMM.OUT_LEVELS = (3, 4, 5)
66 | _C.MODEL.TIMM.NORM = 'FrozenBN'
67 | _C.MODEL.TIMM.FREEZE_AT = 0
68 | _C.MODEL.TIMM.PRETRAINED = False
69 | _C.MODEL.DATASET_LOSS_WEIGHT = []
70 |
71 | # Multi-dataset dataloader
72 | _C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio
73 | _C.DATALOADER.USE_RFS = [False, False]
74 | _C.DATALOADER.MULTI_DATASET_GROUPING = False # Always true when multi-dataset is enabled
75 | _C.DATALOADER.DATASET_ANN = ['box', 'box'] # Annotation type of each dataset
76 | _C.DATALOADER.USE_DIFF_BS_SIZE = False # Use different batchsize for each dataset
77 | _C.DATALOADER.DATASET_BS = [8, 32] # Used when USE_DIFF_BS_SIZE is on PER GPU!!!!
78 | _C.DATALOADER.DATASET_INPUT_SIZE = [896, 384] # Used when USE_DIFF_BS_SIZE is on
79 | _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.5, 1.5)] # Used when USE_DIFF_BS_SIZE is on
80 | _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (320, 400)] # Used when USE_DIFF_BS_SIZE is on
81 | _C.DATALOADER.DATASET_MAX_SIZES = [1333, 667] # Used when USE_DIFF_BS_SIZE is on
82 | _C.DATALOADER.USE_TAR_DATASET = False # for ImageNet-21K, directly reading from unzipped files
83 | _C.DATALOADER.TARFILE_PATH = 'datasets/imagenet/metadata-22k/tar_files.npy'
84 | _C.DATALOADER.TAR_INDEX_DIR = 'datasets/imagenet/metadata-22k/tarindex_npy'
85 |
86 | _C.SOLVER.USE_CUSTOM_SOLVER = False
87 | _C.SOLVER.OPTIMIZER = 'SGD'
88 |
89 | _C.INPUT.CUSTOM_AUG = ''
90 | _C.INPUT.TRAIN_SIZE = 640
91 | _C.INPUT.TEST_SIZE = 640
92 | _C.INPUT.SCALE_RANGE = (0.1, 2.)
93 | # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE
94 | _C.INPUT.TEST_INPUT_TYPE = 'default'
95 |
96 | _C.FIND_UNUSED_PARAM = True
97 | _C.EVAL_PRED_AR = False
98 | _C.EVAL_PROPOSAL_AR = False
99 | _C.EVAL_CAT_SPEC_AR = False
100 | _C.IS_DEBUG = False
101 | _C.QUICK_DEBUG = False
102 | _C.FP16 = False
103 | _C.EVAL_AP_FIX = False
104 | _C.GEN_PSEDO_LABELS = False
105 | _C.SAVE_DEBUG_PATH = 'output/save_debug/'
106 |
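These keys only exist once `add_mmovod_config` has been applied to a fresh detectron2 config before merging one of the yaml files above. A minimal usage sketch follows; the CenterNet2 import is an assumption based on the `third_party/CenterNet2` submodule, and `train_net_auto.py` is where the real wiring lives.
```python
# Sketch only: merge the extra keys defined above into a detectron2 config.
from detectron2.config import get_cfg
from centernet.config import add_centernet_config  # assumed, from the submodule
from mmovod.config import add_mmovod_config

cfg = get_cfg()
add_centernet_config(cfg)  # CENTERNET.* keys used by base_r50_4x_clip.yaml
add_mmovod_config(cfg)     # keys defined in this file
cfg.merge_from_file("configs/lvis-base_r50_4x_clip_gpt3_descriptions.yaml")
cfg.freeze()
print(cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH)
```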
--------------------------------------------------------------------------------
/mmovod/custom_solver.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from enum import Enum
3 | import itertools
4 | from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union
5 | import torch
6 |
7 | from detectron2.config import CfgNode
8 |
9 | from detectron2.solver.build import maybe_add_gradient_clipping
10 |
11 | def match_name_keywords(n, name_keywords):
12 | out = False
13 | for b in name_keywords:
14 | if b in n:
15 | out = True
16 | break
17 | return out
18 |
19 | def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
20 | """
21 | Build an optimizer from config.
22 | """
23 | params: List[Dict[str, Any]] = []
24 | memo: Set[torch.nn.parameter.Parameter] = set()
25 |
26 | optimizer_type = cfg.SOLVER.OPTIMIZER
27 | for key, value in model.named_parameters(recurse=True):
28 | if not value.requires_grad:
29 | continue
30 | # Avoid duplicating parameters
31 | if value in memo:
32 | continue
33 | memo.add(value)
34 | lr = cfg.SOLVER.BASE_LR
35 | weight_decay = cfg.SOLVER.WEIGHT_DECAY
36 | param = {"params": [value], "lr": lr}
37 | if optimizer_type != 'ADAMW':
38 | param['weight_decay'] = weight_decay
39 | params += [param]
40 |
41 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class
42 | # detectron2 doesn't have full model gradient clipping now
43 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE
44 | enable = (
45 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED
46 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model"
47 | and clip_norm_val > 0.0
48 | )
49 |
50 | class FullModelGradientClippingOptimizer(optim):
51 | def step(self, closure=None):
52 | all_params = itertools.chain(*[x["params"] for x in self.param_groups])
53 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val)
54 | super().step(closure=closure)
55 |
56 | return FullModelGradientClippingOptimizer if enable else optim
57 |
58 | if optimizer_type == 'SGD':
59 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)(
60 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
61 | nesterov=cfg.SOLVER.NESTEROV
62 | )
63 | elif optimizer_type == 'ADAMW':
64 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)(
65 | params, cfg.SOLVER.BASE_LR,
66 | weight_decay=cfg.SOLVER.WEIGHT_DECAY
67 | )
68 | else:
69 | raise NotImplementedError(f"no optimizer type {optimizer_type}")
70 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model":
71 | optimizer = maybe_add_gradient_clipping(cfg, optimizer)
72 | return optimizer
73 |
--------------------------------------------------------------------------------
/mmovod/data/__pycache__/custom_build_augmentation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/custom_build_augmentation.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/__pycache__/custom_dataset_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/custom_dataset_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/__pycache__/custom_dataset_mapper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/custom_dataset_mapper.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/__pycache__/tar_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/tar_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/custom_build_augmentation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import numpy as np
4 | import pycocotools.mask as mask_util
5 | import torch
6 | from fvcore.common.file_io import PathManager
7 | from PIL import Image
8 |
9 |
10 | from detectron2.data import transforms as T
11 | from .transforms.custom_augmentation_impl import EfficientDetResizeCrop
12 |
13 | def build_custom_augmentation(cfg, is_train, scale=None, size=None, \
14 | min_size=None, max_size=None):
15 | """
16 | Create a list of default :class:`Augmentation` from config.
17 | Now it includes resizing and flipping.
18 |
19 | Returns:
20 | list[Augmentation]
21 | """
22 | if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge':
23 | if is_train:
24 | min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size
25 | max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size
26 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
27 | else:
28 | min_size = cfg.INPUT.MIN_SIZE_TEST
29 | max_size = cfg.INPUT.MAX_SIZE_TEST
30 | sample_style = "choice"
31 | augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
32 | elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
33 | if is_train:
34 | scale = cfg.INPUT.SCALE_RANGE if scale is None else scale
35 | size = cfg.INPUT.TRAIN_SIZE if size is None else size
36 | else:
37 | scale = (1, 1)
38 | size = cfg.INPUT.TEST_SIZE
39 | augmentation = [EfficientDetResizeCrop(size, scale)]
40 | else:
41 | assert 0, cfg.INPUT.CUSTOM_AUG
42 |
43 | if is_train:
44 | augmentation.append(T.RandomFlip())
45 | return augmentation
46 |
47 |
48 | build_custom_transform_gen = build_custom_augmentation
49 | """
50 | Alias for backward-compatibility.
51 | """
--------------------------------------------------------------------------------
/mmovod/data/custom_dataset_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License)
3 | import copy
4 | import logging
5 | import numpy as np
6 | import operator
7 | import torch
8 | import torch.utils.data
9 | import json
10 | from detectron2.utils.comm import get_world_size
11 | from detectron2.utils.logger import _log_api_usage, log_first_n
12 |
13 | from detectron2.config import configurable
14 | from detectron2.data import samplers
15 | from torch.utils.data.sampler import BatchSampler, Sampler
16 | from detectron2.data.common import DatasetFromList, MapDataset
17 | from detectron2.data.dataset_mapper import DatasetMapper
18 | from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader
19 | from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler
20 | from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram
21 | from detectron2.data.build import filter_images_with_only_crowd_annotations
22 | from detectron2.data.build import filter_images_with_few_keypoints
23 | from detectron2.data.build import check_metadata_consistency
24 | from detectron2.data.catalog import MetadataCatalog, DatasetCatalog
25 | from detectron2.utils import comm
26 | import itertools
27 | import math
28 | from collections import defaultdict
29 | from typing import Optional
30 |
31 |
32 | def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
33 | sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
34 | if 'MultiDataset' in sampler_name:
35 | dataset_dicts = get_detection_dataset_dicts_with_source(
36 | cfg.DATASETS.TRAIN,
37 | filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
38 | min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
39 | if cfg.MODEL.KEYPOINT_ON else 0,
40 | proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
41 | ann_types=cfg.DATALOADER.DATASET_ANN,
42 | )
43 | else:
44 | dataset_dicts = get_detection_dataset_dicts(
45 | cfg.DATASETS.TRAIN,
46 | filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
47 | min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
48 | if cfg.MODEL.KEYPOINT_ON else 0,
49 | proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
50 | )
51 |
52 | if mapper is None:
53 | mapper = DatasetMapper(cfg, True)
54 |
55 | if sampler is not None:
56 | pass
57 | elif sampler_name == "TrainingSampler":
58 | sampler = TrainingSampler(len(dataset))
59 | elif sampler_name == "MultiDatasetSampler":
60 | sampler = MultiDatasetSampler(
61 | dataset_dicts,
62 | dataset_ratio = cfg.DATALOADER.DATASET_RATIO,
63 | use_rfs = cfg.DATALOADER.USE_RFS,
64 | dataset_ann = cfg.DATALOADER.DATASET_ANN,
65 | repeat_threshold = cfg.DATALOADER.REPEAT_THRESHOLD,
66 | )
67 | elif sampler_name == "RepeatFactorTrainingSampler":
68 | repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
69 | dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
70 | )
71 | sampler = RepeatFactorTrainingSampler(repeat_factors)
72 | else:
73 | raise ValueError("Unknown training sampler: {}".format(sampler_name))
74 |
75 | return {
76 | "dataset": dataset_dicts,
77 | "sampler": sampler,
78 | "mapper": mapper,
79 | "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
80 | "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
81 | "num_workers": cfg.DATALOADER.NUM_WORKERS,
82 | 'multi_dataset_grouping': cfg.DATALOADER.MULTI_DATASET_GROUPING,
83 | 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
84 | 'dataset_bs': cfg.DATALOADER.DATASET_BS,
85 | 'num_datasets': len(cfg.DATASETS.TRAIN)
86 | }
87 |
88 |
89 | @configurable(from_config=_custom_train_loader_from_config)
90 | def build_custom_train_loader(
91 | dataset, *, mapper, sampler,
92 | total_batch_size=16,
93 | aspect_ratio_grouping=True,
94 | num_workers=0,
95 | num_datasets=1,
96 | multi_dataset_grouping=False,
97 | use_diff_bs_size=False,
98 | dataset_bs=[]
99 | ):
100 | """
101 | Modified from detectron2.data.build.build_custom_train_loader, but supports
102 | different samplers
103 | """
104 | if isinstance(dataset, list):
105 | dataset = DatasetFromList(dataset, copy=False)
106 | if mapper is not None:
107 | dataset = MapDataset(dataset, mapper)
108 | if sampler is None:
109 | sampler = TrainingSampler(len(dataset))
110 | assert isinstance(sampler, torch.utils.data.sampler.Sampler)
111 | if multi_dataset_grouping:
112 | return build_multi_dataset_batch_data_loader(
113 | use_diff_bs_size,
114 | dataset_bs,
115 | dataset,
116 | sampler,
117 | total_batch_size,
118 | num_datasets=num_datasets,
119 | num_workers=num_workers,
120 | )
121 | else:
122 | return build_batch_data_loader(
123 | dataset,
124 | sampler,
125 | total_batch_size,
126 | aspect_ratio_grouping=aspect_ratio_grouping,
127 | num_workers=num_workers,
128 | )
129 |
130 |
131 | def build_multi_dataset_batch_data_loader(
132 | use_diff_bs_size, dataset_bs,
133 | dataset, sampler, total_batch_size, num_datasets, num_workers=0
134 | ):
135 | """
136 | """
137 | world_size = get_world_size()
138 | assert (
139 | total_batch_size > 0 and total_batch_size % world_size == 0
140 | ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
141 | total_batch_size, world_size
142 | )
143 |
144 | batch_size = total_batch_size // world_size
145 | data_loader = torch.utils.data.DataLoader(
146 | dataset,
147 | sampler=sampler,
148 | num_workers=num_workers,
149 | batch_sampler=None,
150 | collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
151 | worker_init_fn=worker_init_reset_seed,
152 | ) # yield individual mapped dict
153 | if use_diff_bs_size:
154 | return DIFFMDAspectRatioGroupedDataset(
155 | data_loader, dataset_bs, num_datasets)
156 | else:
157 | return MDAspectRatioGroupedDataset(
158 | data_loader, batch_size, num_datasets)
159 |
160 |
161 | def filter_images_with_only_crowd_annotations_detic(dataset_dicts):
162 | """
163 | Filter out images with none annotations or only crowd annotations
164 | (i.e., images without non-crowd annotations).
165 | A common training-time preprocessing on COCO dataset.
166 |
167 | Args:
168 | dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
169 |
170 | Returns:
171 | list[dict]: the same format, but filtered.
172 | """
173 | num_before = len(dataset_dicts)
174 |
175 | def valid(anns, ann_type):
176 | if ann_type != "box":
177 | return True
178 | for ann in anns:
179 | if ann.get("iscrowd", 0) == 0:
180 | return True
181 | return False
182 |
183 | dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"], x['ann_type'])]
184 | num_after = len(dataset_dicts)
185 | logger = logging.getLogger(__name__)
186 | logger.info(
187 | "Removed {} images with no usable annotations. {} images left.".format(
188 | num_before - num_after, num_after
189 | )
190 | )
191 | return dataset_dicts
192 |
193 |
194 | def get_detection_dataset_dicts_with_source(
195 | dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None, ann_types=[],
196 | ):
197 | assert len(dataset_names)
198 | dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
199 | for dataset_name, dicts in zip(dataset_names, dataset_dicts):
200 | assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
201 |
202 | for source_id, (dataset_name, dicts) in enumerate(zip(
203 | dataset_names, dataset_dicts)):
204 | assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
205 | for d in dicts:
206 | d['dataset_source'] = source_id
207 | d['ann_type'] = ann_types[source_id]
208 |
209 | if "annotations" in dicts[0]:
210 | try:
211 | class_names = MetadataCatalog.get(dataset_name).thing_classes
212 | check_metadata_consistency("thing_classes", dataset_name)
213 | print_instances_class_histogram(dicts, class_names)
214 | except AttributeError: # class names are not available for this dataset
215 | pass
216 |
217 | assert proposal_files is None
218 |
219 | dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
220 |
221 | has_instances = "annotations" in dataset_dicts[0]
222 | if filter_empty and has_instances:
223 | dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
224 | if min_keypoints > 0 and has_instances:
225 | dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
226 |
227 | return dataset_dicts
228 |
229 |
230 | class MultiDatasetSampler(Sampler):
231 | def __init__(
232 | self,
233 | dataset_dicts,
234 | dataset_ratio,
235 | use_rfs,
236 | dataset_ann,
237 | repeat_threshold=0.001,
238 | seed: Optional[int] = None,
239 | ):
240 | """
241 | """
242 | sizes = [0 for _ in range(len(dataset_ratio))]
243 | for d in dataset_dicts:
244 | sizes[d['dataset_source']] += 1
245 | print('dataset sizes', sizes)
246 | self.sizes = sizes
247 | assert len(dataset_ratio) == len(sizes), \
248 | 'length of dataset ratio {} should be equal to number of datasets {}'.format(
249 | len(dataset_ratio), len(sizes)
250 | )
251 | if seed is None:
252 | seed = comm.shared_random_seed()
253 | self._seed = int(seed)
254 | self._rank = comm.get_rank()
255 | self._world_size = comm.get_world_size()
256 |
257 | self.dataset_ids = torch.tensor(
258 | [d['dataset_source'] for d in dataset_dicts], dtype=torch.long)
259 |
260 | dataset_weight = (
261 | [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio)
262 | for i, (r, s) in enumerate(zip(dataset_ratio, sizes))]
263 | )
264 | dataset_weight = torch.cat(dataset_weight)
265 |
266 | rfs_factors = []
267 | st = 0
268 | for i, s in enumerate(sizes):
269 | if use_rfs[i]:
270 | if dataset_ann[i] == 'box':
271 | rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency
272 | else:
273 | rfs_func = repeat_factors_from_tag_frequency
274 | rfs_factor = rfs_func(
275 | dataset_dicts[st: st + s],
276 | repeat_thresh=repeat_threshold)
277 | rfs_factor = rfs_factor * (s / rfs_factor.sum())
278 | else:
279 | rfs_factor = torch.ones(s)
280 | rfs_factors.append(rfs_factor)
281 | st = st + s
282 | rfs_factors = torch.cat(rfs_factors)
283 |
284 | self.weights = dataset_weight * rfs_factors
285 | self.sample_epoch_size = len(self.weights)
286 |
287 | def __iter__(self):
288 | start = self._rank
289 | yield from itertools.islice(
290 | self._infinite_indices(), start, None, self._world_size)
291 |
292 | def _infinite_indices(self):
293 | g = torch.Generator()
294 | g.manual_seed(self._seed)
295 | while True:
296 | ids = torch.multinomial(
297 | self.weights, self.sample_epoch_size, generator=g,
298 | replacement=True)
299 | nums = ([(self.dataset_ids[ids] == i).sum().int().item()
300 | for i in range(len(self.sizes))])
301 | yield from ids
302 |
303 |
304 | class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
305 | def __init__(self, dataset, batch_size, num_datasets):
306 |         """Batch consecutive samples that share both a dataset source and
307 |         an aspect-ratio bucket (w > h vs. w <= h)."""
308 | self.dataset = dataset
309 | self.batch_size = batch_size
310 | self._buckets = [[] for _ in range(2 * num_datasets)]
311 |
312 | def __iter__(self):
313 | for d in self.dataset:
314 | w, h = d["width"], d["height"]
315 | aspect_ratio_bucket_id = 0 if w > h else 1
316 | bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
317 | bucket = self._buckets[bucket_id]
318 | bucket.append(d)
319 | if len(bucket) == self.batch_size:
320 | yield bucket[:]
321 | del bucket[:]
322 |
323 |
324 | class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset):
325 | def __init__(self, dataset, batch_sizes, num_datasets):
326 |         """Same grouping as MDAspectRatioGroupedDataset, but with a separate
327 |         batch size per dataset source."""
328 | self.dataset = dataset
329 | self.batch_sizes = batch_sizes
330 | self._buckets = [[] for _ in range(2 * num_datasets)]
331 |
332 | def __iter__(self):
333 | for d in self.dataset:
334 | w, h = d["width"], d["height"]
335 | aspect_ratio_bucket_id = 0 if w > h else 1
336 | bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id
337 | bucket = self._buckets[bucket_id]
338 | bucket.append(d)
339 | if len(bucket) == self.batch_sizes[d['dataset_source']]:
340 | yield bucket[:]
341 | del bucket[:]
342 |
343 |
344 | def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh):
345 |     """Compute LVIS-style repeat factors from image-level tag frequencies
346 |     ('pos_category_ids') instead of box annotations."""
347 | category_freq = defaultdict(int)
348 | for dataset_dict in dataset_dicts:
349 | cat_ids = dataset_dict['pos_category_ids']
350 | for cat_id in cat_ids:
351 | category_freq[cat_id] += 1
352 | num_images = len(dataset_dicts)
353 | for k, v in category_freq.items():
354 | category_freq[k] = v / num_images
355 |
356 | category_rep = {
357 | cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq))
358 | for cat_id, cat_freq in category_freq.items()
359 | }
360 |
361 | rep_factors = []
362 | for dataset_dict in dataset_dicts:
363 | cat_ids = dataset_dict['pos_category_ids']
364 | rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0)
365 | rep_factors.append(rep_factor)
366 |
367 | return torch.tensor(rep_factors, dtype=torch.float32)
--------------------------------------------------------------------------------
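Note: `MultiDatasetSampler` above draws each index from a multinomial whose weight is the product of a dataset-level ratio term and an optional per-image repeat factor. The following standalone sketch (illustrative sizes and ratios only, no Detectron2 required) reproduces that weighting to show how sampled counts follow the ratios rather than the raw dataset sizes.

import torch

# Illustrative numbers; in training these come from the registered datasets
# and the dataloader config.
sizes = [1000, 4000]                    # images per dataset source
dataset_ratio = [1.0, 4.0]              # desired sampling ratio between sources

# Dataset-level weight, as in MultiDatasetSampler.__init__:
# each image of source i gets max(sizes) / sizes[i] * ratio[i] / sum(ratio).
dataset_weight = torch.cat([
    torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio)
    for r, s in zip(dataset_ratio, sizes)
])

# Per-image repeat factors (all ones here; RFS would up-weight rare classes).
rfs_factors = torch.ones(sum(sizes))
weights = dataset_weight * rfs_factors

# Sampling with replacement, as in _infinite_indices().
g = torch.Generator().manual_seed(42)
ids = torch.multinomial(weights, len(weights), replacement=True, generator=g)
dataset_ids = torch.cat(
    [torch.full((s,), i, dtype=torch.long) for i, s in enumerate(sizes)])
print([(dataset_ids[ids] == i).sum().item() for i in range(len(sizes))])
# -> roughly a 1:4 split, independent of the raw dataset sizes
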
/mmovod/data/custom_dataset_mapper.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import copy
3 | import logging
4 | import numpy as np
5 | from typing import List, Optional, Union
6 | import torch
7 | import pycocotools.mask as mask_util
8 |
9 | from detectron2.config import configurable
10 |
11 | from detectron2.data import detection_utils as utils
12 | from detectron2.data.detection_utils import transform_keypoint_annotations
13 | from detectron2.data import transforms as T
14 | from detectron2.data.dataset_mapper import DatasetMapper
15 | from detectron2.structures import Boxes, BoxMode, Instances
16 | from detectron2.structures import Keypoints, PolygonMasks, BitMasks
17 | from fvcore.transforms.transform import TransformList
18 | from .custom_build_augmentation import build_custom_augmentation
19 |
20 | __all__ = ["CustomDatasetMapper"]
21 |
22 | class CustomDatasetMapper(DatasetMapper):
23 | @configurable
24 | def __init__(self, is_train: bool,
25 | with_ann_type=False,
26 | dataset_ann=[],
27 | use_diff_bs_size=False,
28 | dataset_augs=[],
29 | is_debug=False,
30 | use_tar_dataset=False,
31 | tarfile_path='',
32 | tar_index_dir='',
33 | **kwargs):
34 | """
35 | add image labels
36 | """
37 | self.with_ann_type = with_ann_type
38 | self.dataset_ann = dataset_ann
39 | self.use_diff_bs_size = use_diff_bs_size
40 | if self.use_diff_bs_size and is_train:
41 | self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs]
42 | self.is_debug = is_debug
43 | self.use_tar_dataset = use_tar_dataset
44 | if self.use_tar_dataset:
45 | raise NotImplementedError('Using tar dataset not supported')
46 | super().__init__(is_train, **kwargs)
47 |
48 |
49 | @classmethod
50 | def from_config(cls, cfg, is_train: bool = True):
51 | ret = super().from_config(cfg, is_train)
52 | ret.update({
53 | 'with_ann_type': cfg.WITH_IMAGE_LABELS,
54 | 'dataset_ann': cfg.DATALOADER.DATASET_ANN,
55 | 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE,
56 | 'is_debug': cfg.IS_DEBUG,
57 | 'use_tar_dataset': cfg.DATALOADER.USE_TAR_DATASET,
58 | 'tarfile_path': cfg.DATALOADER.TARFILE_PATH,
59 | 'tar_index_dir': cfg.DATALOADER.TAR_INDEX_DIR,
60 | })
61 | if ret['use_diff_bs_size'] and is_train:
62 | if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop':
63 | dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE
64 | dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE
65 | ret['dataset_augs'] = [
66 | build_custom_augmentation(cfg, True, scale, size) \
67 | for scale, size in zip(dataset_scales, dataset_sizes)]
68 | else:
69 | assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge'
70 | min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES
71 | max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES
72 | ret['dataset_augs'] = [
73 | build_custom_augmentation(
74 | cfg, True, min_size=mi, max_size=ma) \
75 | for mi, ma in zip(min_sizes, max_sizes)]
76 | else:
77 | ret['dataset_augs'] = []
78 |
79 | return ret
80 |
81 | def __call__(self, dataset_dict):
82 | """
83 | include image labels
84 | """
85 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
86 | # USER: Write your own image loading if it's not from a file
87 | if 'file_name' in dataset_dict:
88 | ori_image = utils.read_image(
89 | dataset_dict["file_name"], format=self.image_format)
90 | else:
91 | ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]]
92 | ori_image = utils._apply_exif_orientation(ori_image)
93 | ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format)
94 | utils.check_image_size(dataset_dict, ori_image)
95 |
96 | # USER: Remove if you don't do semantic/panoptic segmentation.
97 | if "sem_seg_file_name" in dataset_dict:
98 | sem_seg_gt = utils.read_image(
99 | dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
100 | else:
101 | sem_seg_gt = None
102 |
103 | if self.is_debug:
104 | dataset_dict['dataset_source'] = 0
105 |
106 | not_full_labeled = 'dataset_source' in dataset_dict and \
107 | self.with_ann_type and \
108 | self.dataset_ann[dataset_dict['dataset_source']] != 'box'
109 |
110 | aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=sem_seg_gt)
111 | if self.use_diff_bs_size and self.is_train:
112 | transforms = \
113 | self.dataset_augs[dataset_dict['dataset_source']](aug_input)
114 | else:
115 | transforms = self.augmentations(aug_input)
116 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg
117 |
118 | image_shape = image.shape[:2] # h, w
119 | dataset_dict["image"] = torch.as_tensor(
120 | np.ascontiguousarray(image.transpose(2, 0, 1)))
121 |
122 | if sem_seg_gt is not None:
123 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
124 |
125 | # USER: Remove if you don't use pre-computed proposals.
126 | # Most users would not need this feature.
127 | if self.proposal_topk is not None:
128 | utils.transform_proposals(
129 | dataset_dict, image_shape, transforms,
130 | proposal_topk=self.proposal_topk
131 | )
132 |
133 | if not self.is_train:
134 | # USER: Modify this if you want to keep them for some reason.
135 | dataset_dict.pop("annotations", None)
136 | dataset_dict.pop("sem_seg_file_name", None)
137 | return dataset_dict
138 |
139 | if "annotations" in dataset_dict:
140 | # USER: Modify this if you want to keep them for some reason.
141 | for anno in dataset_dict["annotations"]:
142 | if not self.use_instance_mask:
143 | anno.pop("segmentation", None)
144 | if not self.use_keypoint:
145 | anno.pop("keypoints", None)
146 |
147 | # USER: Implement additional transformations if you have other types of data
148 | all_annos = [
149 | (utils.transform_instance_annotations(
150 | obj, transforms, image_shape,
151 | keypoint_hflip_indices=self.keypoint_hflip_indices,
152 | ), obj.get("iscrowd", 0))
153 | for obj in dataset_dict.pop("annotations")
154 | ]
155 | annos = [ann[0] for ann in all_annos if ann[1] == 0]
156 | instances = utils.annotations_to_instances(
157 | annos, image_shape, mask_format=self.instance_mask_format
158 | )
159 |
160 | del all_annos
161 | if self.recompute_boxes:
162 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
163 | dataset_dict["instances"] = utils.filter_empty_instances(instances)
164 | if self.with_ann_type:
165 | dataset_dict["pos_category_ids"] = dataset_dict.get(
166 | 'pos_category_ids', [])
167 | dataset_dict["ann_type"] = \
168 | self.dataset_ann[dataset_dict['dataset_source']]
169 | if self.is_debug and (('pos_category_ids' not in dataset_dict) or \
170 | (dataset_dict['pos_category_ids'] == [])):
171 | dataset_dict['pos_category_ids'] = [x for x in sorted(set(
172 | dataset_dict['instances'].gt_classes.tolist()
173 | ))]
174 | return dataset_dict
175 |
--------------------------------------------------------------------------------
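Note: when `USE_DIFF_BS_SIZE` is enabled, `CustomDatasetMapper.__call__` selects an augmentation pipeline by the sample's `dataset_source` instead of using the single `self.augmentations`. A minimal sketch of that selection; the resize sizes below are illustrative, not the values from this repo's configs.

import numpy as np
from detectron2.data import transforms as T

# One AugmentationList per dataset source, mirroring ret['dataset_augs'].
dataset_augs = [
    T.AugmentationList([T.ResizeShortestEdge(
        short_edge_length=(640, 672, 704), max_size=1333, sample_style="choice")]),
    T.AugmentationList([T.ResizeShortestEdge(
        short_edge_length=(320, 352, 384), max_size=667, sample_style="choice")]),
]

image = np.zeros((480, 640, 3), dtype=np.uint8)   # stand-in for utils.read_image(...)
dataset_source = 1                                # from dataset_dict['dataset_source']

aug_input = T.AugInput(image)
transforms = dataset_augs[dataset_source](aug_input)   # same call as in __call__()
print(aug_input.image.shape)                           # resized with the source-1 pipeline
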
/mmovod/data/datasets/__pycache__/cc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/cc.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/coco_zeroshot.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/coco_zeroshot.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/imagenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/imagenet.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/lvis_22k_categories.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/lvis_22k_categories.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/lvis_v1.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/lvis_v1.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/objects365.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/objects365.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/oid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/oid.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/__pycache__/register_oid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/register_oid.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # import logging
3 | import os
4 |
5 | from detectron2.data import DatasetCatalog, MetadataCatalog
6 | from detectron2.data.datasets.lvis import get_lvis_instances_meta
7 | from .lvis_v1 import custom_load_lvis_json
8 |
9 |
10 | def custom_register_imagenet_instances(name, metadata, json_file, image_root):
11 |     """Register an ImageNet-LVIS split (LVIS-format json with image-level
12 |     labels) in the DatasetCatalog and MetadataCatalog."""
13 | DatasetCatalog.register(name, lambda: custom_load_lvis_json(
14 | json_file, image_root, name))
15 | MetadataCatalog.get(name).set(
16 | json_file=json_file, image_root=image_root,
17 | evaluator_type="imagenet", **metadata
18 | )
19 |
20 |
21 | _CUSTOM_SPLITS_IMAGENET = {
22 | "imagenet_lvis_v1": (
23 | "imagenet/imagenet21k_P/",
24 | "imagenet/annotations/imagenet_lvis_image_info.json",
25 | ),
26 | }
27 |
28 |
29 | for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET.items():
30 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
31 | custom_register_imagenet_instances(
32 | key,
33 | get_lvis_instances_meta('lvis_v1'),
34 | os.path.join(_root, json_file) if "://" not in json_file else json_file,
35 | os.path.join(_root, image_root),
36 | )
37 |
--------------------------------------------------------------------------------
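Note: once this module has been imported, the split registered above is reachable through the standard Detectron2 catalogs. A sketch; it assumes the `mmovod` package is importable (e.g. the repo root is on `PYTHONPATH`) and that the ImageNet-LVIS json and images exist under `$DETECTRON2_DATASETS`.

from detectron2.data import DatasetCatalog, MetadataCatalog

import mmovod.data.datasets.imagenet  # noqa: F401  -- runs the registration loop above

dicts = DatasetCatalog.get("imagenet_lvis_v1")   # lazily calls custom_load_lvis_json
meta = MetadataCatalog.get("imagenet_lvis_v1")
print(len(dicts), meta.evaluator_type)           # evaluator_type == "imagenet"
print(sorted(dicts[0].keys()))                   # 'pos_category_ids' holds image labels, when present in the json
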
/mmovod/data/datasets/lvis_v1.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import os
4 |
5 | from fvcore.common.timer import Timer
6 | from detectron2.structures import BoxMode
7 | from fvcore.common.file_io import PathManager
8 | from detectron2.data import DatasetCatalog, MetadataCatalog
9 | from detectron2.data.datasets.lvis import get_lvis_instances_meta
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | __all__ = ["custom_load_lvis_json", "custom_register_lvis_instances"]
14 |
15 |
16 | def custom_register_lvis_instances(name, metadata, json_file, image_root):
17 |     """Register a custom LVIS split in the DatasetCatalog and
18 |     MetadataCatalog."""
19 | DatasetCatalog.register(name, lambda: custom_load_lvis_json(
20 | json_file, image_root, name))
21 | MetadataCatalog.get(name).set(
22 | json_file=json_file, image_root=image_root,
23 | evaluator_type="lvis", **metadata
24 | )
25 |
26 |
27 | def custom_load_lvis_json(json_file, image_root, dataset_name=None):
28 | '''
29 | Modifications:
30 | use `file_name`
31 | convert neg_category_ids
32 | add pos_category_ids
33 | '''
34 | from lvis import LVIS
35 |
36 | json_file = PathManager.get_local_path(json_file)
37 |
38 | timer = Timer()
39 | lvis_api = LVIS(json_file)
40 | if timer.seconds() > 1:
41 | logger.info("Loading {} takes {:.2f} seconds.".format(
42 | json_file, timer.seconds()))
43 |
44 | catid2contid = {x['id']: i for i, x in enumerate(
45 | sorted(lvis_api.dataset['categories'], key=lambda x: x['id']))}
46 | if len(lvis_api.dataset['categories']) == 1203:
47 | for x in lvis_api.dataset['categories']:
48 | assert catid2contid[x['id']] == x['id'] - 1
49 | img_ids = sorted(lvis_api.imgs.keys())
50 | imgs = lvis_api.load_imgs(img_ids)
51 | anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids]
52 |
53 | ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image]
54 | assert len(set(ann_ids)) == len(ann_ids), \
55 | "Annotation ids in '{}' are not unique".format(json_file)
56 |
57 | imgs_anns = list(zip(imgs, anns))
58 | logger.info("Loaded {} images in the LVIS v1 format from {}".format(
59 | len(imgs_anns), json_file))
60 |
61 | dataset_dicts = []
62 |
63 | for (img_dict, anno_dict_list) in imgs_anns:
64 | record = {}
65 | if "file_name" in img_dict:
66 | file_name = img_dict["file_name"]
67 | if img_dict["file_name"].startswith("COCO"):
68 | file_name = file_name[-16:]
69 | record["file_name"] = os.path.join(image_root, file_name)
70 | elif 'coco_url' in img_dict:
71 | # e.g., http://images.cocodataset.org/train2017/000000391895.jpg
72 | file_name = img_dict["coco_url"][30:]
73 | record["file_name"] = os.path.join(image_root, file_name)
74 | elif 'tar_index' in img_dict:
75 | record['tar_index'] = img_dict['tar_index']
76 |
77 | record["height"] = img_dict["height"]
78 | record["width"] = img_dict["width"]
79 | record["not_exhaustive_category_ids"] = img_dict.get(
80 | "not_exhaustive_category_ids", [])
81 | record["neg_category_ids"] = img_dict.get("neg_category_ids", [])
82 | # NOTE: modified by Xingyi: convert to 0-based
83 | record["neg_category_ids"] = [
84 | catid2contid[x] for x in record["neg_category_ids"]]
85 | if 'pos_category_ids' in img_dict:
86 | record['pos_category_ids'] = [
87 | catid2contid[x] for x in img_dict.get("pos_category_ids", [])]
88 | if 'captions' in img_dict:
89 | record['captions'] = img_dict['captions']
90 | if 'caption_features' in img_dict:
91 | record['caption_features'] = img_dict['caption_features']
92 | image_id = record["image_id"] = img_dict["id"]
93 |
94 | objs = []
95 | for anno in anno_dict_list:
96 | assert anno["image_id"] == image_id
97 | if anno.get('iscrowd', 0) > 0:
98 | continue
99 | obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS}
100 | obj["category_id"] = catid2contid[anno['category_id']]
101 | if 'segmentation' in anno:
102 | segm = anno["segmentation"]
103 | valid_segm = [poly for poly in segm
104 | if len(poly) % 2 == 0 and len(poly) >= 6]
105 | # assert len(segm) == len(
106 | # valid_segm
107 | # ), "Annotation contains an invalid polygon with < 3 points"
108 | if not len(segm) == len(valid_segm):
109 | print('Annotation contains an invalid polygon with < 3 points')
110 | assert len(segm) > 0
111 | obj["segmentation"] = segm
112 | objs.append(obj)
113 | record["annotations"] = objs
114 | dataset_dicts.append(record)
115 |
116 | return dataset_dicts
117 |
118 |
119 | _CUSTOM_SPLITS_LVIS = {
120 | "lvis_v1_train_norare": ("coco/", "lvis/lvis_v1_train_norare.json"),
121 | "lvis_v1_val_rareonly": ("coco/", "lvis/lvis_v1_val_rare_only.json"),
122 | "lvis_v1_val_norare": ("coco/", "lvis/lvis_v1_val_norare.json"),
123 | "lvis_v1_train_rareonly": ("coco/", "lvis/lvis_v1_train_rare_only.json"),
124 | }
125 |
126 |
127 | for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items():
128 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
129 | custom_register_lvis_instances(
130 | key,
131 | get_lvis_instances_meta(key),
132 | os.path.join(_root, json_file) if "://" not in json_file else json_file,
133 | os.path.join(_root, image_root),
134 | )
135 |
--------------------------------------------------------------------------------
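Note: `custom_load_lvis_json` remaps the 1-based (and possibly non-contiguous) LVIS category ids to 0-based contiguous ids via `catid2contid`; both `neg_category_ids` and `pos_category_ids` are stored in the remapped space. A toy sketch of that mapping:

# Stand-in for lvis_api.dataset['categories'].
categories = [{'id': 3}, {'id': 1}, {'id': 7}]

catid2contid = {
    x['id']: i for i, x in enumerate(sorted(categories, key=lambda x: x['id']))
}
print(catid2contid)                                 # {1: 0, 3: 1, 7: 2}

neg_category_ids = [7, 1]                           # as stored in the json (original ids)
print([catid2contid[c] for c in neg_category_ids])  # [2, 0] -- what ends up in the record
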
/mmovod/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/transforms/__pycache__/custom_transform.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/transforms/__pycache__/custom_transform.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/data/transforms/custom_augmentation_impl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
4 | # Modified by Xingyi Zhou
5 | # The original code is under Apache-2.0 License
6 | import numpy as np
7 | import sys
8 | from fvcore.transforms.transform import (
9 | BlendTransform,
10 | CropTransform,
11 | HFlipTransform,
12 | NoOpTransform,
13 | Transform,
14 | VFlipTransform,
15 | )
16 | from PIL import Image
17 |
18 | from detectron2.data.transforms.augmentation import Augmentation
19 | from .custom_transform import EfficientDetResizeCropTransform
20 |
21 | __all__ = [
22 | "EfficientDetResizeCrop",
23 | ]
24 |
25 | class EfficientDetResizeCrop(Augmentation):
26 |     """
27 |     Randomly scale a square target size, resize the image so it fits inside the scaled target,
28 |     then take a random crop of the target size (EfficientDet-style resize-and-crop).
29 |     """
30 |
31 | def __init__(
32 | self, size, scale, interp=Image.BILINEAR
33 | ):
34 |         """size: output crop size; scale: (min, max) range for the random
35 |         scale factor; interp: PIL interpolation mode."""
36 | super().__init__()
37 | self.target_size = (size, size)
38 | self.scale = scale
39 | self.interp = interp
40 |
41 | def get_transform(self, img):
42 | # Select a random scale factor.
43 | scale_factor = np.random.uniform(*self.scale)
44 | scaled_target_height = scale_factor * self.target_size[0]
45 | scaled_target_width = scale_factor * self.target_size[1]
46 | # Recompute the accurate scale_factor using rounded scaled image size.
47 | width, height = img.shape[1], img.shape[0]
48 | img_scale_y = scaled_target_height / height
49 | img_scale_x = scaled_target_width / width
50 | img_scale = min(img_scale_y, img_scale_x)
51 |
52 | # Select non-zero random offset (x, y) if scaled image is larger than target size
53 | scaled_h = int(height * img_scale)
54 | scaled_w = int(width * img_scale)
55 | offset_y = scaled_h - self.target_size[0]
56 | offset_x = scaled_w - self.target_size[1]
57 | offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1))
58 | offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1))
59 | return EfficientDetResizeCropTransform(
60 | scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp)
61 |
--------------------------------------------------------------------------------
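Note: `get_transform` above first scales the square target by a random factor, resizes the image so it fits inside that scaled target, and then draws a random crop offset. A NumPy-only sketch of that arithmetic with illustrative values (the real `size`/`scale` come from the training configs):

import numpy as np

rng = np.random.default_rng(0)
size, scale = 896, (0.1, 2.0)             # illustrative, not the repo's config values
target_size = (size, size)
height, width = 480, 640                  # original image size

scale_factor = rng.uniform(*scale)
scaled_target_h = scale_factor * target_size[0]
scaled_target_w = scale_factor * target_size[1]

# Resize so the whole image fits inside the scaled target (same min() as above).
img_scale = min(scaled_target_h / height, scaled_target_w / width)
scaled_h, scaled_w = int(height * img_scale), int(width * img_scale)

# Random crop offsets, non-zero only when the resized image exceeds the target.
offset_y = int(max(0.0, scaled_h - target_size[0]) * rng.uniform(0, 1))
offset_x = int(max(0.0, scaled_w - target_size[1]) * rng.uniform(0, 1))
print(img_scale, (scaled_h, scaled_w), (offset_y, offset_x))
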
/mmovod/data/transforms/custom_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py
4 | # Modified by Xingyi Zhou
5 | # The original code is under Apache-2.0 License
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from fvcore.transforms.transform import (
10 | CropTransform,
11 | HFlipTransform,
12 | NoOpTransform,
13 | Transform,
14 | TransformList,
15 | )
16 | from PIL import Image
17 |
18 | try:
19 | import cv2 # noqa
20 | except ImportError:
21 | # OpenCV is an optional dependency at the moment
22 | pass
23 |
24 | __all__ = [
25 | "EfficientDetResizeCropTransform",
26 | ]
27 |
28 | class EfficientDetResizeCropTransform(Transform):
29 |     """Resize an image by `img_scale`, then crop a `target_size` window at
30 |     (`offset_x`, `offset_y`). Produced by EfficientDetResizeCrop."""
31 |
32 | def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, \
33 | target_size, interp=None):
34 | """
35 | Args:
36 |             scaled_h, scaled_w (int): resized image size; img_scale (float): resize factor
37 |             offset_y, offset_x (int): crop offset; target_size (tuple): output crop size (h, w)
38 | interp: PIL interpolation methods, defaults to bilinear.
39 | """
40 | # TODO decide on PIL vs opencv
41 | super().__init__()
42 | if interp is None:
43 | interp = Image.BILINEAR
44 | self._set_attributes(locals())
45 |
46 | def apply_image(self, img, interp=None):
47 | assert len(img.shape) <= 4
48 |
49 | if img.dtype == np.uint8:
50 | pil_image = Image.fromarray(img)
51 | interp_method = interp if interp is not None else self.interp
52 | pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method)
53 | ret = np.asarray(pil_image)
54 | right = min(self.scaled_w, self.offset_x + self.target_size[1])
55 | lower = min(self.scaled_h, self.offset_y + self.target_size[0])
56 | if len(ret.shape) <= 3:
57 | ret = ret[self.offset_y: lower, self.offset_x: right]
58 | else:
59 | ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
60 | else:
61 | # PIL only supports uint8
62 | img = torch.from_numpy(img)
63 | shape = list(img.shape)
64 | shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:]
65 | img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw
66 | _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"}
67 | mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp]
68 | img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False)
69 | shape[:2] = (self.scaled_h, self.scaled_w)
70 | ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c)
71 | right = min(self.scaled_w, self.offset_x + self.target_size[1])
72 | lower = min(self.scaled_h, self.offset_y + self.target_size[0])
73 | if len(ret.shape) <= 3:
74 | ret = ret[self.offset_y: lower, self.offset_x: right]
75 | else:
76 | ret = ret[..., self.offset_y: lower, self.offset_x: right, :]
77 | return ret
78 |
79 |
80 | def apply_coords(self, coords):
81 | coords[:, 0] = coords[:, 0] * self.img_scale
82 | coords[:, 1] = coords[:, 1] * self.img_scale
83 | coords[:, 0] -= self.offset_x
84 | coords[:, 1] -= self.offset_y
85 | return coords
86 |
87 |
88 | def apply_segmentation(self, segmentation):
89 | segmentation = self.apply_image(segmentation, interp=Image.NEAREST)
90 | return segmentation
91 |
92 |
93 | def inverse(self):
94 | raise NotImplementedError
95 |
96 |
97 | def inverse_apply_coords(self, coords):
98 | coords[:, 0] += self.offset_x
99 | coords[:, 1] += self.offset_y
100 | coords[:, 0] = coords[:, 0] / self.img_scale
101 | coords[:, 1] = coords[:, 1] / self.img_scale
102 | return coords
103 |
104 |
105 | def inverse_apply_box(self, box: np.ndarray) -> np.ndarray:
106 |         """Map XYXY boxes from the transformed image back to the original
107 |         image coordinates (inverse of apply_coords on the box corners)."""
108 | idxs = np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten()
109 | coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2)
110 | coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2))
111 | minxy = coords.min(axis=1)
112 | maxxy = coords.max(axis=1)
113 | trans_boxes = np.concatenate((minxy, maxxy), axis=1)
114 | return trans_boxes
--------------------------------------------------------------------------------
/mmovod/modeling/__pycache__/debug.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/__pycache__/debug.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/backbone/__pycache__/swintransformer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/backbone/__pycache__/swintransformer.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/backbone/__pycache__/timm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/backbone/__pycache__/timm.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/backbone/timm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | import math
5 | from os.path import join
6 | import numpy as np
7 | import copy
8 | from functools import partial
9 |
10 | import torch
11 | from torch import nn
12 | import torch.utils.model_zoo as model_zoo
13 | import torch.nn.functional as F
14 | import fvcore.nn.weight_init as weight_init
15 |
16 | from detectron2.modeling.backbone import FPN
17 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
18 | from detectron2.layers.batch_norm import get_norm, FrozenBatchNorm2d
19 | from detectron2.modeling.backbone import Backbone
20 |
21 | from timm import create_model
22 | from timm.models.helpers import build_model_with_cfg
23 | from timm.models.registry import register_model
24 | from timm.models.resnet import ResNet, Bottleneck
25 | from timm.models.resnet import default_cfgs as default_cfgs_resnet
26 | from timm.models.convnext import ConvNeXt, default_cfgs, checkpoint_filter_fn
27 |
28 |
29 | @register_model
30 | def convnext_tiny_21k(pretrained=False, **kwargs):
31 | model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
32 | cfg = default_cfgs['convnext_tiny']
33 | cfg['url'] = 'https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth'
34 | model = build_model_with_cfg(
35 | ConvNeXt, 'convnext_tiny', pretrained,
36 | default_cfg=cfg,
37 | pretrained_filter_fn=checkpoint_filter_fn,
38 | feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
39 | **model_args)
40 | return model
41 |
42 | class CustomResNet(ResNet):
43 | def __init__(self, **kwargs):
44 | self.out_indices = kwargs.pop('out_indices')
45 | super().__init__(**kwargs)
46 |
47 |
48 | def forward(self, x):
49 | x = self.conv1(x)
50 | x = self.bn1(x)
51 | x = self.act1(x)
52 | x = self.maxpool(x)
53 | ret = [x]
54 | x = self.layer1(x)
55 | ret.append(x)
56 | x = self.layer2(x)
57 | ret.append(x)
58 | x = self.layer3(x)
59 | ret.append(x)
60 | x = self.layer4(x)
61 | ret.append(x)
62 | return [ret[i] for i in self.out_indices]
63 |
64 |
65 | def load_pretrained(self, cached_file):
66 | data = torch.load(cached_file, map_location='cpu')
67 | if 'state_dict' in data:
68 | self.load_state_dict(data['state_dict'])
69 | else:
70 | self.load_state_dict(data)
71 |
72 |
73 | model_params = {
74 | 'resnet50_in21k': dict(block=Bottleneck, layers=[3, 4, 6, 3]),
75 | }
76 |
77 |
78 | def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs):
79 | params = model_params[variant]
80 | default_cfgs_resnet['resnet50_in21k'] = \
81 | copy.deepcopy(default_cfgs_resnet['resnet50'])
82 | default_cfgs_resnet['resnet50_in21k']['url'] = \
83 | 'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth'
84 | default_cfgs_resnet['resnet50_in21k']['num_classes'] = 11221
85 |
86 | return build_model_with_cfg(
87 | CustomResNet, variant, pretrained,
88 | default_cfg=default_cfgs_resnet[variant],
89 | out_indices=out_indices,
90 | pretrained_custom_load=True,
91 | **params,
92 | **kwargs)
93 |
94 |
95 | class LastLevelP6P7_P5(nn.Module):
96 |     """Generate the extra FPN levels P6 and P7 from P5 with two stride-2
97 |     3x3 convolutions."""
98 | def __init__(self, in_channels, out_channels):
99 | super().__init__()
100 | self.num_levels = 2
101 | self.in_feature = "p5"
102 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
103 | self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
104 | for module in [self.p6, self.p7]:
105 | weight_init.c2_xavier_fill(module)
106 |
107 | def forward(self, c5):
108 | p6 = self.p6(c5)
109 | p7 = self.p7(F.relu(p6))
110 | return [p6, p7]
111 |
112 |
113 | def freeze_module(x):
114 |     """Freeze all parameters of `x` and convert its BatchNorm layers to
115 |     FrozenBatchNorm2d."""
116 | for p in x.parameters():
117 | p.requires_grad = False
118 | FrozenBatchNorm2d.convert_frozen_batchnorm(x)
119 | return x
120 |
121 |
122 | class TIMM(Backbone):
123 | def __init__(self, base_name, out_levels, freeze_at=0, norm='FrozenBN', pretrained=False):
124 | super().__init__()
125 | out_indices = [x - 1 for x in out_levels]
126 | if base_name in model_params:
127 | self.base = create_timm_resnet(
128 | base_name, out_indices=out_indices,
129 | pretrained=False)
130 | elif 'eff' in base_name or 'resnet' in base_name or 'regnet' in base_name:
131 | self.base = create_model(
132 | base_name, features_only=True,
133 | out_indices=out_indices, pretrained=pretrained)
134 | elif 'convnext' in base_name:
135 | drop_path_rate = 0.2 \
136 | if ('tiny' in base_name or 'small' in base_name) else 0.3
137 | self.base = create_model(
138 | base_name, features_only=True,
139 | out_indices=out_indices, pretrained=pretrained,
140 | drop_path_rate=drop_path_rate)
141 | else:
142 | assert 0, base_name
143 | feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction']) \
144 | for i, f in enumerate(self.base.feature_info)]
145 | self._out_features = ['layer{}'.format(x) for x in out_levels]
146 | self._out_feature_channels = {
147 | 'layer{}'.format(l): feature_info[l - 1]['num_chs'] for l in out_levels}
148 | self._out_feature_strides = {
149 | 'layer{}'.format(l): feature_info[l - 1]['reduction'] for l in out_levels}
150 | self._size_divisibility = max(self._out_feature_strides.values())
151 | if 'resnet' in base_name:
152 | self.freeze(freeze_at)
153 | if norm == 'FrozenBN':
154 | self = FrozenBatchNorm2d.convert_frozen_batchnorm(self)
155 |
156 | def freeze(self, freeze_at=0):
157 |         """Freeze the ResNet stem (freeze_at >= 1) and the first residual
158 |         stage (freeze_at >= 2)."""
159 | if freeze_at >= 1:
160 |             print('Freezing', self.base.conv1)
161 | self.base.conv1 = freeze_module(self.base.conv1)
162 | if freeze_at >= 2:
163 |             print('Freezing', self.base.layer1)
164 | self.base.layer1 = freeze_module(self.base.layer1)
165 |
166 | def forward(self, x):
167 | features = self.base(x)
168 | ret = {k: v for k, v in zip(self._out_features, features)}
169 | return ret
170 |
171 | @property
172 | def size_divisibility(self):
173 | return self._size_divisibility
174 |
175 |
176 | @BACKBONE_REGISTRY.register()
177 | def build_timm_backbone(cfg, input_shape):
178 | model = TIMM(
179 | cfg.MODEL.TIMM.BASE_NAME,
180 | cfg.MODEL.TIMM.OUT_LEVELS,
181 | freeze_at=cfg.MODEL.TIMM.FREEZE_AT,
182 | norm=cfg.MODEL.TIMM.NORM,
183 | pretrained=cfg.MODEL.TIMM.PRETRAINED,
184 | )
185 | return model
186 |
187 |
188 | @BACKBONE_REGISTRY.register()
189 | def build_p67_timm_fpn_backbone(cfg, input_shape):
190 |     """Build a TIMM backbone wrapped in an FPN, with extra P6/P7 levels
191 |     generated from P5."""
192 | bottom_up = build_timm_backbone(cfg, input_shape)
193 | in_features = cfg.MODEL.FPN.IN_FEATURES
194 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
195 | backbone = FPN(
196 | bottom_up=bottom_up,
197 | in_features=in_features,
198 | out_channels=out_channels,
199 | norm=cfg.MODEL.FPN.NORM,
200 | top_block=LastLevelP6P7_P5(out_channels, out_channels),
201 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
202 | )
203 | return backbone
204 |
205 | @BACKBONE_REGISTRY.register()
206 | def build_p35_timm_fpn_backbone(cfg, input_shape):
207 |     """Build a TIMM backbone wrapped in an FPN over P3-P5 only (no extra
208 |     top levels)."""
209 | bottom_up = build_timm_backbone(cfg, input_shape)
210 |
211 | in_features = cfg.MODEL.FPN.IN_FEATURES
212 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
213 | backbone = FPN(
214 | bottom_up=bottom_up,
215 | in_features=in_features,
216 | out_channels=out_channels,
217 | norm=cfg.MODEL.FPN.NORM,
218 | top_block=None,
219 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
220 | )
221 | return backbone
--------------------------------------------------------------------------------
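Note: `TIMM` wraps any timm feature-extraction model as a Detectron2 `Backbone` whose outputs are named `layer{l}`. A sketch that builds one directly, bypassing the `cfg`/registry path; it uses `resnet18` purely for illustration (the file above defines a `resnet50_in21k` variant for the actual models) and assumes the `mmovod` package is importable.

import torch
from mmovod.modeling.backbone.timm import TIMM   # path assumed importable

# out_levels (3, 4, 5) -> stride 8/16/32 features named layer3..layer5.
backbone = TIMM('resnet18', out_levels=(3, 4, 5), freeze_at=0, pretrained=False)
print(backbone._out_feature_channels, backbone.size_divisibility)

x = torch.zeros(1, 3, 224, 224)
feats = backbone(x)
print({k: tuple(v.shape) for k, v in feats.items()})
# e.g. {'layer3': (1, 128, 28, 28), 'layer4': (1, 256, 14, 14), 'layer5': (1, 512, 7, 7)}
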
/mmovod/modeling/debug.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import cv2
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | import os
7 |
8 | COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype(
9 | np.uint8).reshape(1300, 1, 1, 3)
10 |
11 | def _get_color_image(heatmap):
12 | heatmap = heatmap.reshape(
13 | heatmap.shape[0], heatmap.shape[1], heatmap.shape[2], 1)
14 | if heatmap.shape[0] == 1:
15 | color_map = (heatmap * np.ones((1, 1, 1, 3), np.uint8) * 255).max(
16 | axis=0).astype(np.uint8) # H, W, 3
17 | else:
18 | color_map = (heatmap * COLORS[:heatmap.shape[0]]).max(axis=0).astype(np.uint8) # H, W, 3
19 |
20 | return color_map
21 |
22 | def _blend_image(image, color_map, a=0.7):
23 | color_map = cv2.resize(color_map, (image.shape[1], image.shape[0]))
24 | ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8)
25 | return ret
26 |
27 | def _blend_image_heatmaps(image, color_maps, a=0.7):
28 | merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32)
29 | for color_map in color_maps:
30 | color_map = cv2.resize(color_map, (image.shape[1], image.shape[0]))
31 | merges = np.maximum(merges, color_map)
32 | ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8)
33 | return ret
34 |
35 | def _decompose_level(x, shapes_per_level, N):
36 | '''
37 | x: LNHiWi x C
38 | '''
39 | x = x.view(x.shape[0], -1)
40 | ret = []
41 | st = 0
42 | for l in range(len(shapes_per_level)):
43 | ret.append([])
44 | h = shapes_per_level[l][0].int().item()
45 | w = shapes_per_level[l][1].int().item()
46 | for i in range(N):
47 | ret[l].append(x[st + h * w * i:st + h * w * (i + 1)].view(
48 | h, w, -1).permute(2, 0, 1))
49 | st += h * w * N
50 | return ret
51 |
52 | def _imagelist_to_tensor(images):
53 | images = [x for x in images]
54 | image_sizes = [x.shape[-2:] for x in images]
55 | h = max([size[0] for size in image_sizes])
56 | w = max([size[1] for size in image_sizes])
57 | S = 32
58 | h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S
59 | images = [F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) \
60 | for x in images]
61 | images = torch.stack(images)
62 | return images
63 |
64 |
65 | def _ind2il(ind, shapes_per_level, N):
66 | r = ind
67 | l = 0
68 | S = 0
69 | while r - S >= N * shapes_per_level[l][0] * shapes_per_level[l][1]:
70 | S += N * shapes_per_level[l][0] * shapes_per_level[l][1]
71 | l += 1
72 | i = (r - S) // (shapes_per_level[l][0] * shapes_per_level[l][1])
73 | return i, l
74 |
75 | def debug_train(
76 | images, gt_instances, flattened_hms, reg_targets, labels, pos_inds,
77 | shapes_per_level, locations, strides):
78 | '''
79 | images: N x 3 x H x W
80 | flattened_hms: LNHiWi x C
81 | shapes_per_level: L x 2 [(H_i, W_i)]
82 | locations: LNHiWi x 2
83 | '''
84 | reg_inds = torch.nonzero(
85 | reg_targets.max(dim=1)[0] > 0).squeeze(1)
86 | N = len(images)
87 | images = _imagelist_to_tensor(images)
88 | repeated_locations = [torch.cat([loc] * N, dim=0) \
89 | for loc in locations]
90 | locations = torch.cat(repeated_locations, dim=0)
91 | gt_hms = _decompose_level(flattened_hms, shapes_per_level, N)
92 | masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1))
93 | masks[pos_inds] = 1
94 | masks = _decompose_level(masks, shapes_per_level, N)
95 | for i in range(len(images)):
96 | image = images[i].detach().cpu().numpy().transpose(1, 2, 0)
97 | color_maps = []
98 | for l in range(len(gt_hms)):
99 | color_map = _get_color_image(
100 | gt_hms[l][i].detach().cpu().numpy())
101 | color_maps.append(color_map)
102 | cv2.imshow('gthm_{}'.format(l), color_map)
103 | blend = _blend_image_heatmaps(image.copy(), color_maps)
104 | if gt_instances is not None:
105 | bboxes = gt_instances[i].gt_boxes.tensor
106 | for j in range(len(bboxes)):
107 | bbox = bboxes[j]
108 | cv2.rectangle(
109 | blend,
110 | (int(bbox[0]), int(bbox[1])),
111 | (int(bbox[2]), int(bbox[3])),
112 | (0, 0, 255), 3, cv2.LINE_AA)
113 |
114 | for j in range(len(pos_inds)):
115 | image_id, l = _ind2il(pos_inds[j], shapes_per_level, N)
116 | if image_id != i:
117 | continue
118 | loc = locations[pos_inds[j]]
119 | cv2.drawMarker(
120 | blend, (int(loc[0]), int(loc[1])), (0, 255, 255),
121 | markerSize=(l + 1) * 16)
122 |
123 | for j in range(len(reg_inds)):
124 | image_id, l = _ind2il(reg_inds[j], shapes_per_level, N)
125 | if image_id != i:
126 | continue
127 | ltrb = reg_targets[reg_inds[j]]
128 | ltrb *= strides[l]
129 | loc = locations[reg_inds[j]]
130 | bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]),
131 | (loc[0] + ltrb[2]), (loc[1] + ltrb[3])]
132 | cv2.rectangle(
133 | blend,
134 | (int(bbox[0]), int(bbox[1])),
135 | (int(bbox[2]), int(bbox[3])),
136 | (255, 0, 0), 1, cv2.LINE_AA)
137 | cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1)
138 |
139 | cv2.imshow('blend', blend)
140 | cv2.waitKey()
141 |
142 |
143 | def debug_test(
144 | images, logits_pred, reg_pred, agn_hm_pred=[], preds=[],
145 | vis_thresh=0.3, debug_show_name=False, mult_agn=False):
146 | '''
147 | images: N x 3 x H x W
148 | class_target: LNHiWi x C
149 | cat_agn_heatmap: LNHiWi
150 | shapes_per_level: L x 2 [(H_i, W_i)]
151 | '''
152 | N = len(images)
153 | for i in range(len(images)):
154 | image = images[i].detach().cpu().numpy().transpose(1, 2, 0)
155 | result = image.copy().astype(np.uint8)
156 | pred_image = image.copy().astype(np.uint8)
157 | color_maps = []
158 | L = len(logits_pred)
159 | for l in range(L):
160 | if logits_pred[0] is not None:
161 | stride = min(image.shape[0], image.shape[1]) / min(
162 | logits_pred[l][i].shape[1], logits_pred[l][i].shape[2])
163 | else:
164 | stride = min(image.shape[0], image.shape[1]) / min(
165 | agn_hm_pred[l][i].shape[1], agn_hm_pred[l][i].shape[2])
166 | stride = stride if stride < 60 else 64 if stride < 100 else 128
167 | if logits_pred[0] is not None:
168 | if mult_agn:
169 | logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i]
170 | color_map = _get_color_image(
171 | logits_pred[l][i].detach().cpu().numpy())
172 | color_maps.append(color_map)
173 | cv2.imshow('predhm_{}'.format(l), color_map)
174 |
175 | if debug_show_name:
176 | from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
177 | cat2name = [x['name'] for x in LVIS_CATEGORIES]
178 | for j in range(len(preds[i].scores) if preds is not None else 0):
179 | if preds[i].scores[j] > vis_thresh:
180 | bbox = preds[i].proposal_boxes[j] \
181 | if preds[i].has('proposal_boxes') else \
182 | preds[i].pred_boxes[j]
183 | bbox = bbox.tensor[0].detach().cpu().numpy().astype(np.int32)
184 | cat = int(preds[i].pred_classes[j]) \
185 | if preds[i].has('pred_classes') else 0
186 | cl = COLORS[cat, 0, 0]
187 | cv2.rectangle(
188 | pred_image, (int(bbox[0]), int(bbox[1])),
189 | (int(bbox[2]), int(bbox[3])),
190 | (int(cl[0]), int(cl[1]), int(cl[2])), 2, cv2.LINE_AA)
191 | if debug_show_name:
192 | txt = '{}{:.1f}'.format(
193 | cat2name[cat] if cat > 0 else '',
194 | preds[i].scores[j])
195 | font = cv2.FONT_HERSHEY_SIMPLEX
196 | cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
197 | cv2.rectangle(
198 | pred_image,
199 | (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
200 | (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
201 | (int(cl[0]), int(cl[1]), int(cl[2])), -1)
202 | cv2.putText(
203 | pred_image, txt, (int(bbox[0]), int(bbox[1] - 2)),
204 | font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)
205 |
206 |
207 | if agn_hm_pred[l] is not None:
208 | agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy()
209 | agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(
210 | 1, 1, 3)).astype(np.uint8)
211 | cv2.imshow('agn_hm_{}'.format(l), agn_hm_)
212 | blend = _blend_image_heatmaps(image.copy(), color_maps)
213 | cv2.imshow('blend', blend)
214 | cv2.imshow('preds', pred_image)
215 | cv2.waitKey()
216 |
217 | global cnt
218 | cnt = 0
219 |
220 | def debug_second_stage(images, instances, proposals=None, vis_thresh=0.3,
221 | save_debug=False, debug_show_name=False, image_labels=[],
222 | save_debug_path='output/save_debug/',
223 | bgr=False):
224 | images = _imagelist_to_tensor(images)
225 | if 'COCO' in save_debug_path:
226 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES
227 | cat2name = [x['name'] for x in COCO_CATEGORIES]
228 | else:
229 | from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
230 | cat2name = ['({}){}'.format(x['frequency'], x['name']) \
231 | for x in LVIS_CATEGORIES]
232 | for i in range(len(images)):
233 | image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()
234 | if bgr:
235 | image = image[:, :, ::-1].copy()
236 | if instances[i].has('gt_boxes'):
237 | bboxes = instances[i].gt_boxes.tensor.cpu().numpy()
238 | scores = np.ones(bboxes.shape[0])
239 | cats = instances[i].gt_classes.cpu().numpy()
240 | else:
241 | bboxes = instances[i].pred_boxes.tensor.cpu().numpy()
242 | scores = instances[i].scores.cpu().numpy()
243 | cats = instances[i].pred_classes.cpu().numpy()
244 | for j in range(len(bboxes)):
245 | if scores[j] > vis_thresh:
246 | bbox = bboxes[j]
247 | cl = COLORS[cats[j], 0, 0]
248 | cl = (int(cl[0]), int(cl[1]), int(cl[2]))
249 | cv2.rectangle(
250 | image,
251 | (int(bbox[0]), int(bbox[1])),
252 | (int(bbox[2]), int(bbox[3])),
253 | cl, 2, cv2.LINE_AA)
254 | if debug_show_name:
255 | cat = cats[j]
256 | txt = '{}{:.1f}'.format(
257 | cat2name[cat] if cat > 0 else '',
258 | scores[j])
259 | font = cv2.FONT_HERSHEY_SIMPLEX
260 | cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
261 | cv2.rectangle(
262 | image,
263 | (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
264 | (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
265 | (int(cl[0]), int(cl[1]), int(cl[2])), -1)
266 | cv2.putText(
267 | image, txt, (int(bbox[0]), int(bbox[1] - 2)),
268 | font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA)
269 | if proposals is not None:
270 | proposal_image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy()
271 | if bgr:
272 | proposal_image = proposal_image.copy()
273 | else:
274 | proposal_image = proposal_image[:, :, ::-1].copy()
275 | bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy()
276 | if proposals[i].has('scores'):
277 | scores = proposals[i].scores.detach().cpu().numpy()
278 | else:
279 | scores = proposals[i].objectness_logits.detach().cpu().numpy()
280 | # selected = -1
281 | # if proposals[i].has('image_loss'):
282 | # selected = proposals[i].image_loss.argmin()
283 | if proposals[i].has('selected'):
284 | selected = proposals[i].selected
285 | else:
286 | selected = [-1 for _ in range(len(bboxes))]
287 | for j in range(len(bboxes)):
288 | if scores[j] > vis_thresh or selected[j] >= 0:
289 | bbox = bboxes[j]
290 | cl = (209, 159, 83)
291 | th = 2
292 | if selected[j] >= 0:
293 | cl = (0, 0, 0xa4)
294 | th = 4
295 | cv2.rectangle(
296 | proposal_image,
297 | (int(bbox[0]), int(bbox[1])),
298 | (int(bbox[2]), int(bbox[3])),
299 | cl, th, cv2.LINE_AA)
300 | if selected[j] >= 0 and debug_show_name:
301 | cat = selected[j].item()
302 | txt = '{}'.format(cat2name[cat])
303 | font = cv2.FONT_HERSHEY_SIMPLEX
304 | cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0]
305 | cv2.rectangle(
306 | proposal_image,
307 | (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)),
308 | (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)),
309 | (int(cl[0]), int(cl[1]), int(cl[2])), -1)
310 | cv2.putText(
311 | proposal_image, txt,
312 | (int(bbox[0]), int(bbox[1] - 2)),
313 | font, 0.5, (0, 0, 0), thickness=1,
314 | lineType=cv2.LINE_AA)
315 |
316 | if save_debug:
317 | global cnt
318 | cnt = (cnt + 1) % 5000
319 | if not os.path.exists(save_debug_path):
320 | os.mkdir(save_debug_path)
321 | save_name = '{}/{:05d}.jpg'.format(save_debug_path, cnt)
322 | if i < len(image_labels):
323 | image_label = image_labels[i]
324 | save_name = '{}/{:05d}'.format(save_debug_path, cnt)
325 | for x in image_label:
326 | class_name = cat2name[x]
327 | save_name = save_name + '|{}'.format(class_name)
328 | save_name = save_name + '.jpg'
329 | cv2.imwrite(save_name, proposal_image)
330 | else:
331 | cv2.imshow('image', image)
332 | if proposals is not None:
333 | cv2.imshow('proposals', proposal_image)
334 | cv2.waitKey()
--------------------------------------------------------------------------------
/mmovod/modeling/meta_arch/__pycache__/custom_rcnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/meta_arch/__pycache__/custom_rcnn.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/meta_arch/__pycache__/d2_deformable_detr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/meta_arch/__pycache__/d2_deformable_detr.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/meta_arch/custom_rcnn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import copy
3 | import logging
4 | import numpy as np
5 | from typing import Dict, List, Optional, Tuple
6 | import torch
7 | from torch import nn
8 | import json
9 | from detectron2.utils.events import get_event_storage
10 | from detectron2.config import configurable
11 | from detectron2.structures import ImageList, Instances, Boxes
12 | import detectron2.utils.comm as comm
13 |
14 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
15 | from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
16 | from detectron2.modeling.postprocessing import detector_postprocess
17 | from detectron2.utils.visualizer import Visualizer, _create_text_labels
18 | from detectron2.data.detection_utils import convert_image_to_rgb
19 |
20 | from torch.cuda.amp import autocast
21 | from ..text.text_encoder import build_text_encoder
22 | from ..utils import load_class_freq, get_fed_loss_inds
23 |
24 | @META_ARCH_REGISTRY.register()
25 | class CustomRCNN(GeneralizedRCNN):
26 | '''
27 | Add image labels
28 | '''
29 | @configurable
30 | def __init__(
31 | self,
32 | with_image_labels = False,
33 | dataset_loss_weight = [],
34 | fp16 = False,
35 | sync_caption_batch = False,
36 | roi_head_name = '',
37 | cap_batch_ratio = 4,
38 | with_caption = False,
39 | dynamic_classifier = False,
40 | **kwargs):
41 |         """Extend GeneralizedRCNN with image-label, caption and dynamic-classifier
42 |         training options; see from_config for the corresponding cfg keys."""
43 | self.with_image_labels = with_image_labels
44 | self.dataset_loss_weight = dataset_loss_weight
45 | self.fp16 = fp16
46 | self.with_caption = with_caption
47 | self.sync_caption_batch = sync_caption_batch
48 | self.roi_head_name = roi_head_name
49 | self.cap_batch_ratio = cap_batch_ratio
50 | self.dynamic_classifier = dynamic_classifier
51 | self.return_proposal = False
52 | if self.dynamic_classifier:
53 | self.freq_weight = kwargs.pop('freq_weight')
54 | self.num_classes = kwargs.pop('num_classes')
55 | self.num_sample_cats = kwargs.pop('num_sample_cats')
56 | super().__init__(**kwargs)
57 | assert self.proposal_generator is not None
58 | if self.with_caption:
59 | assert not self.dynamic_classifier
60 | self.text_encoder = build_text_encoder(pretrain=True)
61 | for v in self.text_encoder.parameters():
62 | v.requires_grad = False
63 |
64 |
65 | @classmethod
66 | def from_config(cls, cfg):
67 | ret = super().from_config(cfg)
68 | ret.update({
69 | 'with_image_labels': cfg.WITH_IMAGE_LABELS,
70 | 'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT,
71 | 'fp16': cfg.FP16,
72 | 'with_caption': cfg.MODEL.WITH_CAPTION,
73 | 'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH,
74 | 'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER,
75 | 'roi_head_name': cfg.MODEL.ROI_HEADS.NAME,
76 | 'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO,
77 | })
78 | if ret['dynamic_classifier']:
79 | ret['freq_weight'] = load_class_freq(
80 | cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH,
81 | cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT)
82 | ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES
83 | ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS
84 | return ret
85 |
86 |
87 | def inference(
88 | self,
89 | batched_inputs: Tuple[Dict[str, torch.Tensor]],
90 | detected_instances: Optional[List[Instances]] = None,
91 | do_postprocess: bool = True,
92 | ):
93 | assert not self.training
94 | assert detected_instances is None
95 |
96 | images = self.preprocess_image(batched_inputs)
97 | features = self.backbone(images.tensor)
98 | proposals, _ = self.proposal_generator(images, features, None)
99 | results, _ = self.roi_heads(images, features, proposals)
100 | if do_postprocess:
101 | assert not torch.jit.is_scripting(), \
102 | "Scripting is not supported for postprocess."
103 | return CustomRCNN._postprocess(
104 | results, batched_inputs, images.image_sizes)
105 | else:
106 | return results
107 |
108 |
109 | def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]):
110 | """
111 | Add ann_type
112 | Ignore proposal loss when training with image labels
113 | """
114 | if not self.training:
115 | return self.inference(batched_inputs)
116 |
117 | images = self.preprocess_image(batched_inputs)
118 |
119 | ann_type = 'box'
120 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
121 | if self.with_image_labels:
122 | for inst, x in zip(gt_instances, batched_inputs):
123 | inst._ann_type = x['ann_type']
124 | inst._pos_category_ids = x['pos_category_ids']
125 | ann_types = [x['ann_type'] for x in batched_inputs]
126 | assert len(set(ann_types)) == 1
127 | ann_type = ann_types[0]
128 | if ann_type in ['prop', 'proptag']:
129 | for t in gt_instances:
130 | t.gt_classes *= 0
131 |
132 | if self.fp16: # TODO (zhouxy): improve
133 | with autocast():
134 | features = self.backbone(images.tensor.half())
135 | features = {k: v.float() for k, v in features.items()}
136 | else:
137 | features = self.backbone(images.tensor)
138 |
139 | cls_features, cls_inds, caption_features = None, None, None
140 |
141 | if self.with_caption and 'caption' in ann_type:
142 | inds = [torch.randint(len(x['captions']), (1,))[0].item() \
143 | for x in batched_inputs]
144 | caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)]
145 | caption_features = self.text_encoder(caps).float()
146 | if self.sync_caption_batch:
147 | caption_features = self._sync_caption_features(
148 | caption_features, ann_type, len(batched_inputs))
149 |
150 | if self.dynamic_classifier and ann_type != 'caption':
151 | cls_inds = self._sample_cls_inds(gt_instances, ann_type) # inds, inv_inds
152 | ind_with_bg = cls_inds[0].tolist() + [-1]
153 | cls_features = self.roi_heads.box_predictor[
154 | 0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous()
155 |
156 | classifier_info = cls_features, cls_inds, caption_features
157 | proposals, proposal_losses = self.proposal_generator(
158 | images, features, gt_instances)
159 |
160 | if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']:
161 | proposals, detector_losses = self.roi_heads(
162 | images, features, proposals, gt_instances)
163 | else:
164 | proposals, detector_losses = self.roi_heads(
165 | images, features, proposals, gt_instances,
166 | ann_type=ann_type, classifier_info=classifier_info)
167 |
168 | if self.vis_period > 0:
169 | storage = get_event_storage()
170 | if storage.iter % self.vis_period == 0:
171 | self.visualize_training(batched_inputs, proposals)
172 |
173 | losses = {}
174 | losses.update(detector_losses)
175 | if self.with_image_labels:
176 | if ann_type in ['box', 'prop', 'proptag']:
177 | losses.update(proposal_losses)
178 | else: # ignore proposal loss for non-bbox data
179 | losses.update({k: v * 0 for k, v in proposal_losses.items()})
180 | else:
181 | losses.update(proposal_losses)
182 | if len(self.dataset_loss_weight) > 0:
183 | dataset_sources = [x['dataset_source'] for x in batched_inputs]
184 | assert len(set(dataset_sources)) == 1
185 | dataset_source = dataset_sources[0]
186 | for k in losses:
187 | losses[k] *= self.dataset_loss_weight[dataset_source]
188 |
189 | if self.return_proposal:
190 | return proposals, losses
191 | else:
192 | return losses
193 |
194 |
195 | def _sync_caption_features(self, caption_features, ann_type, BS):
196 | has_caption_feature = (caption_features is not None)
197 | BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS
198 | rank = torch.full(
199 | (BS, 1), comm.get_rank(), dtype=torch.float32,
200 | device=self.device)
201 | if not has_caption_feature:
202 | caption_features = rank.new_zeros((BS, 512))
203 | caption_features = torch.cat([caption_features, rank], dim=1)
204 | global_caption_features = comm.all_gather(caption_features)
205 | caption_features = torch.cat(
206 | [x.to(self.device) for x in global_caption_features], dim=0) \
207 | if has_caption_feature else None # (NB) x (D + 1)
208 | return caption_features
209 |
210 |
211 | def _sample_cls_inds(self, gt_instances, ann_type='box'):
212 | if ann_type == 'box':
213 | gt_classes = torch.cat(
214 | [x.gt_classes for x in gt_instances])
215 | C = len(self.freq_weight)
216 | freq_weight = self.freq_weight
217 | else:
218 | gt_classes = torch.cat(
219 | [torch.tensor(
220 | x._pos_category_ids,
221 | dtype=torch.long, device=x.gt_classes.device) \
222 | for x in gt_instances])
223 | C = self.num_classes
224 | freq_weight = None
225 | assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C)
226 | inds = get_fed_loss_inds(
227 | gt_classes, self.num_sample_cats, C,
228 | weight=freq_weight)
229 | cls_id_map = gt_classes.new_full(
230 | (self.num_classes + 1,), len(inds))
231 | cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device)
232 | return inds, cls_id_map
--------------------------------------------------------------------------------
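A minimal, self-contained sketch (toy class count and labels, not taken from the repository) of the index remapping returned by _sample_cls_inds above: the sampled class ids become the new labels 0..len(inds)-1, and every class that was not sampled, plus background, maps to len(inds).

import torch

num_classes = 20
inds = torch.tensor([2, 7, 11])                  # classes sampled for this batch
cls_id_map = torch.full((num_classes + 1,), len(inds))
cls_id_map[inds] = torch.arange(len(inds))

gt_classes = torch.tensor([7, 7, 2, 19])         # class 19 was not sampled
print(cls_id_map[gt_classes])                    # tensor([1, 1, 0, 3])
--------------------------------------------------------------------------------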
/mmovod/modeling/roi_heads/__pycache__/detic_fast_rcnn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/detic_fast_rcnn.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/roi_heads/__pycache__/detic_roi_heads.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/detic_roi_heads.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/roi_heads/__pycache__/res5_roi_heads.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/res5_roi_heads.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/roi_heads/__pycache__/zero_shot_classifier.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/zero_shot_classifier.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/roi_heads/detic_roi_heads.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import copy
3 | import numpy as np
4 | import json
5 | import math
6 | import torch
7 | from torch import nn
8 | from torch.autograd.function import Function
9 | from typing import Dict, List, Optional, Tuple, Union
10 | from torch.nn import functional as F
11 |
12 | from detectron2.config import configurable
13 | from detectron2.layers import ShapeSpec
14 | from detectron2.layers import batched_nms
15 | from detectron2.structures import Boxes, Instances, pairwise_iou
16 | from detectron2.utils.events import get_event_storage
17 |
18 | from detectron2.modeling.box_regression import Box2BoxTransform
19 | from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
20 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
21 | from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
22 | from detectron2.modeling.roi_heads.box_head import build_box_head
23 | from .detic_fast_rcnn import DeticFastRCNNOutputLayers
24 | from ..debug import debug_second_stage
25 |
26 | from torch.cuda.amp import autocast
27 |
28 | @ROI_HEADS_REGISTRY.register()
29 | class DeticCascadeROIHeads(CascadeROIHeads):
30 | @configurable
31 | def __init__(
32 | self,
33 | *,
34 | mult_proposal_score: bool = False,
35 | with_image_labels: bool = False,
36 | add_image_box: bool = False,
37 | image_box_size: float = 1.0,
38 | ws_num_props: int = 512,
39 | add_feature_to_prop: bool = False,
40 | mask_weight: float = 1.0,
41 | one_class_per_proposal: bool = False,
42 | **kwargs,
43 | ):
44 | super().__init__(**kwargs)
45 | self.mult_proposal_score = mult_proposal_score
46 | self.with_image_labels = with_image_labels
47 | self.add_image_box = add_image_box
48 | self.image_box_size = image_box_size
49 | self.ws_num_props = ws_num_props
50 | self.add_feature_to_prop = add_feature_to_prop
51 | self.mask_weight = mask_weight
52 | self.one_class_per_proposal = one_class_per_proposal
53 |
54 | @classmethod
55 | def from_config(cls, cfg, input_shape):
56 | ret = super().from_config(cfg, input_shape)
57 | ret.update({
58 | 'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE,
59 | 'with_image_labels': cfg.WITH_IMAGE_LABELS,
60 | 'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX,
61 | 'image_box_size': cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE,
62 | 'ws_num_props': cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS,
63 | 'add_feature_to_prop': cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP,
64 | 'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT,
65 | 'one_class_per_proposal': cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL,
66 | })
67 | return ret
68 |
69 |
70 | @classmethod
71 | def _init_box_head(self, cfg, input_shape):
72 | ret = super()._init_box_head(cfg, input_shape)
73 | del ret['box_predictors']
74 | cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS
75 | box_predictors = []
76 | for box_head, bbox_reg_weights in zip(ret['box_heads'], \
77 | cascade_bbox_reg_weights):
78 | box_predictors.append(
79 | DeticFastRCNNOutputLayers(
80 | cfg, box_head.output_shape,
81 | box2box_transform=Box2BoxTransform(weights=bbox_reg_weights)
82 | ))
83 | ret['box_predictors'] = box_predictors
84 | return ret
85 |
86 | def _forward_box(self, features, proposals, targets=None,
87 | ann_type='box', classifier_info=(None,None,None)):
88 | """
89 | Add mult proposal scores at testing
90 | Add ann_type
91 | """
92 | if (not self.training) and self.mult_proposal_score:
93 | if len(proposals) > 0 and proposals[0].has('scores'):
94 | proposal_scores = [p.get('scores') for p in proposals]
95 | else:
96 | proposal_scores = [p.get('objectness_logits') for p in proposals]
97 |
98 | features = [features[f] for f in self.box_in_features]
99 | head_outputs = [] # (predictor, predictions, proposals)
100 | prev_pred_boxes = None
101 | image_sizes = [x.image_size for x in proposals]
102 |
103 | for k in range(self.num_cascade_stages):
104 | if k > 0:
105 | proposals = self._create_proposals_from_boxes(
106 | prev_pred_boxes, image_sizes,
107 | logits=[p.objectness_logits for p in proposals])
108 | if self.training and ann_type in ['box']:
109 | proposals = self._match_and_label_boxes(
110 | proposals, k, targets)
111 | predictions = self._run_stage(features, proposals, k,
112 | classifier_info=classifier_info)
113 | prev_pred_boxes = self.box_predictor[k].predict_boxes(
114 | (predictions[0], predictions[1]), proposals)
115 | head_outputs.append((self.box_predictor[k], predictions, proposals))
116 |
117 | if self.training:
118 | losses = {}
119 | storage = get_event_storage()
120 | for stage, (predictor, predictions, proposals) in enumerate(head_outputs):
121 | with storage.name_scope("stage{}".format(stage)):
122 | if ann_type != 'box':
123 | stage_losses = {}
124 | if ann_type in ['image', 'caption', 'captiontag']:
125 | image_labels = [x._pos_category_ids for x in targets]
126 | # import pdb; pdb.set_trace()
127 | weak_losses = predictor.image_label_losses(
128 | predictions, proposals, image_labels,
129 | classifier_info=classifier_info,
130 | ann_type=ann_type)
131 | stage_losses.update(weak_losses)
132 | else: # supervised
133 | stage_losses = predictor.losses(
134 | (predictions[0], predictions[1]), proposals,
135 | classifier_info=classifier_info)
136 | if self.with_image_labels:
137 | stage_losses['image_loss'] = \
138 | predictions[0].new_zeros([1])[0]
139 | losses.update({k + "_stage{}".format(stage): v \
140 | for k, v in stage_losses.items()})
141 | return losses
142 | else:
143 | # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1)
144 | scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs]
145 | scores = [
146 | sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages)
147 | for scores_per_image in zip(*scores_per_stage)
148 | ]
149 | if self.mult_proposal_score:
150 | scores = [(s * ps[:, None]) ** 0.5 \
151 | for s, ps in zip(scores, proposal_scores)]
152 | if self.one_class_per_proposal:
153 | scores = [s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() for s in scores]
154 | predictor, predictions, proposals = head_outputs[-1]
155 | boxes = predictor.predict_boxes(
156 | (predictions[0], predictions[1]), proposals)
157 | pred_instances, _ = fast_rcnn_inference(
158 | boxes,
159 | scores,
160 | image_sizes,
161 | predictor.test_score_thresh,
162 | predictor.test_nms_thresh,
163 | predictor.test_topk_per_image,
164 | )
165 | return pred_instances
166 |
167 | def forward(self, images, features, proposals, targets=None,
168 | ann_type='box', classifier_info=(None,None,None)):
169 | '''
170 | enable debug and image labels
171 | classifier_info is shared across the batch
172 | '''
173 | if self.training:
174 | if ann_type in ['box', 'prop', 'proptag']:
175 | proposals = self.label_and_sample_proposals(
176 | proposals, targets)
177 | else:
178 | proposals = self.get_top_proposals(proposals)
179 |
180 | losses = self._forward_box(features, proposals, targets, \
181 | ann_type=ann_type, classifier_info=classifier_info)
182 | if ann_type == 'box' and targets[0].has('gt_masks'):
183 | mask_losses = self._forward_mask(features, proposals)
184 | losses.update({k: v * self.mask_weight \
185 | for k, v in mask_losses.items()})
186 | losses.update(self._forward_keypoint(features, proposals))
187 | else:
188 | losses.update(self._get_empty_mask_loss(
189 | features, proposals,
190 | device=proposals[0].objectness_logits.device))
191 | return proposals, losses
192 | else:
193 | pred_instances = self._forward_box(
194 | features, proposals, classifier_info=classifier_info)
195 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
196 | return pred_instances, {}
197 |
198 |
199 | def get_top_proposals(self, proposals):
200 | for i in range(len(proposals)):
201 | proposals[i].proposal_boxes.clip(proposals[i].image_size)
202 | proposals = [p[:self.ws_num_props] for p in proposals]
203 | for i, p in enumerate(proposals):
204 | p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
205 | if self.add_image_box:
206 | proposals[i] = self._add_image_box(p)
207 | return proposals
208 |
209 |
210 | def _add_image_box(self, p):
211 | image_box = Instances(p.image_size)
212 | n = 1
213 | h, w = p.image_size
214 | f = self.image_box_size
215 | image_box.proposal_boxes = Boxes(
216 | p.proposal_boxes.tensor.new_tensor(
217 | [w * (1. - f) / 2.,
218 | h * (1. - f) / 2.,
219 | w * (1. - (1. - f) / 2.),
220 | h * (1. - (1. - f) / 2.)]
221 | ).view(n, 4))
222 | image_box.objectness_logits = p.objectness_logits.new_ones(n)
223 | return Instances.cat([p, image_box])
224 |
225 |
226 | def _get_empty_mask_loss(self, features, proposals, device):
227 | if self.mask_on:
228 | return {'loss_mask': torch.zeros(
229 | (1, ), device=device, dtype=torch.float32)[0]}
230 | else:
231 | return {}
232 |
233 |
234 | def _create_proposals_from_boxes(self, boxes, image_sizes, logits):
235 | """
236 | Add objectness_logits
237 | """
238 | boxes = [Boxes(b.detach()) for b in boxes]
239 | proposals = []
240 | for boxes_per_image, image_size, logit in zip(
241 | boxes, image_sizes, logits):
242 | boxes_per_image.clip(image_size)
243 | if self.training:
244 | inds = boxes_per_image.nonempty()
245 | boxes_per_image = boxes_per_image[inds]
246 | logit = logit[inds]
247 | prop = Instances(image_size)
248 | prop.proposal_boxes = boxes_per_image
249 | prop.objectness_logits = logit
250 | proposals.append(prop)
251 | return proposals
252 |
253 |
254 | def _run_stage(self, features, proposals, stage, \
255 | classifier_info=(None,None,None)):
256 | """
257 | Support classifier_info and add_feature_to_prop
258 | """
259 | pool_boxes = [x.proposal_boxes for x in proposals]
260 | box_features = self.box_pooler(features, pool_boxes)
261 | box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)
262 | box_features = self.box_head[stage](box_features)
263 | if self.add_feature_to_prop:
264 | feats_per_image = box_features.split(
265 | [len(p) for p in proposals], dim=0)
266 | for feat, p in zip(feats_per_image, proposals):
267 | p.feat = feat
268 | return self.box_predictor[stage](
269 | box_features,
270 | classifier_info=classifier_info)
271 |
--------------------------------------------------------------------------------
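A small sketch with toy tensors (shapes and values are illustrative, not from the repository) of the test-time score fusion in DeticCascadeROIHeads._forward_box: per-stage class probabilities are averaged over the cascade stages and, when mult_proposal_score is enabled, geometrically combined with the RPN objectness score.

import torch

num_stages, R, K = 3, 5, 10                          # cascade stages, proposals, foreground classes
scores_per_stage = [torch.rand(R, K + 1).softmax(dim=1) for _ in range(num_stages)]
proposal_scores = torch.rand(R)                      # objectness of each proposal

scores = sum(scores_per_stage) * (1.0 / num_stages)  # average over cascade stages
scores = (scores * proposal_scores[:, None]) ** 0.5  # geometric mean with objectness
print(scores.shape)                                  # torch.Size([5, 11])
--------------------------------------------------------------------------------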
/mmovod/modeling/roi_heads/res5_roi_heads.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import inspect
3 | import logging
4 | import numpy as np
5 | from typing import Dict, List, Optional, Tuple
6 | import torch
7 | from torch import nn
8 |
9 | from detectron2.config import configurable
10 | from detectron2.layers import ShapeSpec, nonzero_tuple
11 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
12 | from detectron2.utils.events import get_event_storage
13 | from detectron2.utils.registry import Registry
14 |
15 | from detectron2.modeling.box_regression import Box2BoxTransform
16 | from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
17 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
18 | from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient
19 | from detectron2.modeling.roi_heads.box_head import build_box_head
20 |
21 | from .detic_fast_rcnn import DeticFastRCNNOutputLayers
22 | from ..debug import debug_second_stage
23 |
24 | from torch.cuda.amp import autocast
25 |
26 | @ROI_HEADS_REGISTRY.register()
27 | class CustomRes5ROIHeads(Res5ROIHeads):
28 | @configurable
29 | def __init__(self, **kwargs):
30 | cfg = kwargs.pop('cfg')
31 | super().__init__(**kwargs)
32 | stage_channel_factor = 2 ** 3
33 | out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor
34 |
35 | self.with_image_labels = cfg.WITH_IMAGE_LABELS
36 | self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS
37 | self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX
38 | self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP
39 | self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE
40 | self.box_predictor = DeticFastRCNNOutputLayers(
41 | cfg, ShapeSpec(channels=out_channels, height=1, width=1)
42 | )
43 |
44 | self.save_debug = cfg.SAVE_DEBUG
45 | self.save_debug_path = cfg.SAVE_DEBUG_PATH
46 | if self.save_debug:
47 | self.debug_show_name = cfg.DEBUG_SHOW_NAME
48 | self.vis_thresh = cfg.VIS_THRESH
49 | self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(
50 | torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1)
51 | self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(
52 | torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1)
53 | self.bgr = (cfg.INPUT.FORMAT == 'BGR')
54 |
55 | @classmethod
56 | def from_config(cls, cfg, input_shape):
57 | ret = super().from_config(cfg, input_shape)
58 | ret['cfg'] = cfg
59 | return ret
60 |
61 | def forward(self, images, features, proposals, targets=None,
62 | ann_type='box', classifier_info=(None,None,None)):
63 | '''
64 | enable debug and image labels
65 | classifier_info is shared across the batch
66 | '''
67 | if not self.save_debug:
68 | del images
69 |
70 | if self.training:
71 | if ann_type in ['box']:
72 | proposals = self.label_and_sample_proposals(
73 | proposals, targets)
74 | else:
75 | proposals = self.get_top_proposals(proposals)
76 |
77 | proposal_boxes = [x.proposal_boxes for x in proposals]
78 | box_features = self._shared_roi_transform(
79 | [features[f] for f in self.in_features], proposal_boxes
80 | )
81 | predictions = self.box_predictor(
82 | box_features.mean(dim=[2, 3]),
83 | classifier_info=classifier_info)
84 |
85 | if self.add_feature_to_prop:
86 | feats_per_image = box_features.mean(dim=[2, 3]).split(
87 | [len(p) for p in proposals], dim=0)
88 | for feat, p in zip(feats_per_image, proposals):
89 | p.feat = feat
90 |
91 | if self.training:
92 | del features
93 | if (ann_type != 'box'):
94 | image_labels = [x._pos_category_ids for x in targets]
95 | losses = self.box_predictor.image_label_losses(
96 | predictions, proposals, image_labels,
97 | classifier_info=classifier_info,
98 | ann_type=ann_type)
99 | else:
100 | losses = self.box_predictor.losses(
101 | (predictions[0], predictions[1]), proposals)
102 | if self.with_image_labels:
103 | assert 'image_loss' not in losses
104 | losses['image_loss'] = predictions[0].new_zeros([1])[0]
105 | if self.save_debug:
106 | denormalizer = lambda x: x * self.pixel_std + self.pixel_mean
107 | if ann_type != 'box':
108 | image_labels = [x._pos_category_ids for x in targets]
109 | else:
110 | image_labels = [[] for x in targets]
111 | debug_second_stage(
112 | [denormalizer(x.clone()) for x in images],
113 | targets, proposals=proposals,
114 | save_debug=self.save_debug,
115 | debug_show_name=self.debug_show_name,
116 | vis_thresh=self.vis_thresh,
117 | image_labels=image_labels,
118 | save_debug_path=self.save_debug_path,
119 | bgr=self.bgr)
120 | return proposals, losses
121 | else:
122 | pred_instances, _ = self.box_predictor.inference(predictions, proposals)
123 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
124 | if self.save_debug:
125 | denormalizer = lambda x: x * self.pixel_std + self.pixel_mean
126 | debug_second_stage(
127 | [denormalizer(x.clone()) for x in images],
128 | pred_instances, proposals=proposals,
129 | save_debug=self.save_debug,
130 | debug_show_name=self.debug_show_name,
131 | vis_thresh=self.vis_thresh,
132 | save_debug_path=self.save_debug_path,
133 | bgr=self.bgr)
134 | return pred_instances, {}
135 |
136 | def get_top_proposals(self, proposals):
137 | for i in range(len(proposals)):
138 | proposals[i].proposal_boxes.clip(proposals[i].image_size)
139 | proposals = [p[:self.ws_num_props] for p in proposals]
140 | for i, p in enumerate(proposals):
141 | p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach()
142 | if self.add_image_box:
143 | proposals[i] = self._add_image_box(p)
144 | return proposals
145 |
146 | def _add_image_box(self, p, use_score=False):
147 | image_box = Instances(p.image_size)
148 | n = 1
149 | h, w = p.image_size
150 | if self.image_box_size < 1.0:
151 | f = self.image_box_size
152 | image_box.proposal_boxes = Boxes(
153 | p.proposal_boxes.tensor.new_tensor(
154 | [w * (1. - f) / 2.,
155 | h * (1. - f) / 2.,
156 | w * (1. - (1. - f) / 2.),
157 | h * (1. - (1. - f) / 2.)]
158 | ).view(n, 4))
159 | else:
160 | image_box.proposal_boxes = Boxes(
161 | p.proposal_boxes.tensor.new_tensor(
162 | [0, 0, w, h]).view(n, 4))
163 | if use_score:
164 | image_box.scores = \
165 | p.objectness_logits.new_ones(n)
166 | image_box.pred_classes = \
167 | p.objectness_logits.new_zeros(n, dtype=torch.long)
168 | image_box.objectness_logits = \
169 | p.objectness_logits.new_ones(n)
170 | else:
171 | image_box.objectness_logits = \
172 | p.objectness_logits.new_ones(n)
173 | return Instances.cat([p, image_box])
--------------------------------------------------------------------------------
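A toy sketch (hypothetical image size, not taken from the repository) of the extra proposal appended by _add_image_box: with image_box_size f < 1.0 it is a centered box covering a fraction f of each image side, otherwise it is the full image.

h, w, f = 480, 640, 0.5
image_box = [w * (1. - f) / 2.,          # x0
             h * (1. - f) / 2.,          # y0
             w * (1. - (1. - f) / 2.),   # x1
             h * (1. - (1. - f) / 2.)]   # y1
print(image_box)                         # [160.0, 120.0, 480.0, 360.0]
--------------------------------------------------------------------------------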
/mmovod/modeling/roi_heads/zero_shot_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from detectron2.config import configurable
7 | from detectron2.layers import ShapeSpec
8 |
9 |
10 | class ZeroShotClassifier(nn.Module):
11 | @configurable
12 | def __init__(
13 | self,
14 | input_shape: ShapeSpec,
15 | *,
16 | num_classes: int,
17 | zs_weight_path: str,
18 | zs_weight_dim: int = 512,
19 | use_bias: float = 0.0,
20 | norm_weight: bool = True,
21 | norm_temperature: float = 50.0,
22 | ):
23 | super().__init__()
24 | if isinstance(input_shape, int): # some backward compatibility
25 | input_shape = ShapeSpec(channels=input_shape)
26 | input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1)
27 | self.norm_weight = norm_weight
28 | self.norm_temperature = norm_temperature
29 |
30 | self.use_bias = use_bias < 0
31 | if self.use_bias:
32 | self.cls_bias = nn.Parameter(torch.ones(1) * use_bias)
33 |
34 | self.linear = nn.Linear(input_size, zs_weight_dim)
35 |
36 | if zs_weight_path == 'rand':
37 | zs_weight = torch.randn((zs_weight_dim, num_classes))
38 | nn.init.normal_(zs_weight, std=0.01)
39 | else:
40 | zs_weight = torch.tensor(
41 | np.load(zs_weight_path),
42 | dtype=torch.float32).permute(1, 0).contiguous() # D x C
43 | zs_weight = torch.cat(
44 | [zs_weight, zs_weight.new_zeros((zs_weight_dim, 1))],
45 | dim=1) # D x (C + 1)
46 |
47 | if self.norm_weight:
48 | zs_weight = F.normalize(zs_weight, p=2, dim=0)
49 |
50 | if zs_weight_path == 'rand':
51 | self.zs_weight = nn.Parameter(zs_weight)
52 | else:
53 | self.register_buffer('zs_weight', zs_weight)
54 |
55 | assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape
56 |
57 | @classmethod
58 | def from_config(cls, cfg, input_shape):
59 | return {
60 | 'input_shape': input_shape,
61 | 'num_classes': cfg.MODEL.ROI_HEADS.NUM_CLASSES,
62 | 'zs_weight_path': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH,
63 | 'zs_weight_dim': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM,
64 | 'use_bias': cfg.MODEL.ROI_BOX_HEAD.USE_BIAS,
65 | 'norm_weight': cfg.MODEL.ROI_BOX_HEAD.NORM_WEIGHT,
66 | 'norm_temperature': cfg.MODEL.ROI_BOX_HEAD.NORM_TEMP,
67 | }
68 |
69 | def forward(self, x, classifier=None):
70 | '''
71 | Inputs:
72 | x: B x D'
73 |             classifier: (optional) C' x D tensor that overrides the stored zs_weight
74 | '''
75 | x = self.linear(x)
76 | if classifier is not None:
77 | zs_weight = classifier.permute(1, 0).contiguous() # D x C'
78 | zs_weight = F.normalize(zs_weight, p=2, dim=0) \
79 | if self.norm_weight else zs_weight
80 | else:
81 | zs_weight = self.zs_weight
82 | if self.norm_weight:
83 | x = self.norm_temperature * F.normalize(x, p=2, dim=1)
84 | x = torch.mm(x, zs_weight)
85 | if self.use_bias:
86 | x = x + self.cls_bias
87 | return x
88 |
--------------------------------------------------------------------------------
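A minimal sketch of the cosine-similarity scoring performed in ZeroShotClassifier.forward, using random tensors in place of the projected RoI features and the CLIP-derived class embeddings (all shapes and the temperature here are illustrative; the last classifier column stands for background).

import torch
import torch.nn.functional as F

B, D, C = 4, 512, 10                     # RoI features, embedding dim, foreground classes
x = torch.randn(B, D)                    # features after the learned projection (self.linear)
zs_weight = F.normalize(torch.randn(D, C + 1), p=2, dim=0)   # D x (C + 1) class embeddings

norm_temperature = 50.0
x = norm_temperature * F.normalize(x, p=2, dim=1)
logits = torch.mm(x, zs_weight)          # B x (C + 1) classification logits
print(logits.shape)                      # torch.Size([4, 11])
--------------------------------------------------------------------------------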
/mmovod/modeling/text/__pycache__/text_encoder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/text/__pycache__/text_encoder.cpython-38.pyc
--------------------------------------------------------------------------------
/mmovod/modeling/text/text_encoder.py:
--------------------------------------------------------------------------------
1 | # This code is modified from https://github.com/openai/CLIP/blob/main/clip/clip.py
2 | # Modified by Xingyi Zhou
3 | # The original code is under MIT license
4 | # Copyright (c) Facebook, Inc. and its affiliates.
5 | from typing import Union, List
6 | from collections import OrderedDict
7 | import torch
8 | from torch import nn
9 | import torch
10 |
11 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
12 |
13 | __all__ = ["tokenize"]
14 |
15 | count = 0
16 |
17 | class LayerNorm(nn.LayerNorm):
18 | """Subclass torch's LayerNorm to handle fp16."""
19 |
20 | def forward(self, x: torch.Tensor):
21 | orig_type = x.dtype
22 | ret = super().forward(x.type(torch.float32))
23 | return ret.type(orig_type)
24 |
25 |
26 | class QuickGELU(nn.Module):
27 | def forward(self, x: torch.Tensor):
28 | return x * torch.sigmoid(1.702 * x)
29 |
30 |
31 | class ResidualAttentionBlock(nn.Module):
32 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
33 | super().__init__()
34 |
35 | self.attn = nn.MultiheadAttention(d_model, n_head)
36 | self.ln_1 = LayerNorm(d_model)
37 | self.mlp = nn.Sequential(OrderedDict([
38 | ("c_fc", nn.Linear(d_model, d_model * 4)),
39 | ("gelu", QuickGELU()),
40 | ("c_proj", nn.Linear(d_model * 4, d_model))
41 | ]))
42 | self.ln_2 = LayerNorm(d_model)
43 | self.attn_mask = attn_mask
44 |
45 | def attention(self, x: torch.Tensor):
46 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
47 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
48 |
49 | def forward(self, x: torch.Tensor):
50 | x = x + self.attention(self.ln_1(x))
51 | x = x + self.mlp(self.ln_2(x))
52 | return x
53 |
54 |
55 | class Transformer(nn.Module):
56 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
57 | super().__init__()
58 | self.width = width
59 | self.layers = layers
60 | self.resblocks = nn.Sequential(
61 | *[ResidualAttentionBlock(width, heads, attn_mask) \
62 | for _ in range(layers)])
63 |
64 | def forward(self, x: torch.Tensor):
65 | return self.resblocks(x)
66 |
67 | class CLIPTEXT(nn.Module):
68 | def __init__(self,
69 | embed_dim=512,
70 | # text
71 | context_length=77,
72 | vocab_size=49408,
73 | transformer_width=512,
74 | transformer_heads=8,
75 | transformer_layers=12
76 | ):
77 | super().__init__()
78 |
79 | self._tokenizer = _Tokenizer()
80 | self.context_length = context_length
81 |
82 | self.transformer = Transformer(
83 | width=transformer_width,
84 | layers=transformer_layers,
85 | heads=transformer_heads,
86 | attn_mask=self.build_attention_mask()
87 | )
88 |
89 | self.vocab_size = vocab_size
90 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
91 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
92 | self.ln_final = LayerNorm(transformer_width)
93 |
94 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
95 | # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
96 |
97 | self.initialize_parameters()
98 |
99 | def initialize_parameters(self):
100 | nn.init.normal_(self.token_embedding.weight, std=0.02)
101 | nn.init.normal_(self.positional_embedding, std=0.01)
102 |
103 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
104 | attn_std = self.transformer.width ** -0.5
105 | fc_std = (2 * self.transformer.width) ** -0.5
106 | for block in self.transformer.resblocks:
107 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
108 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
109 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
110 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
111 |
112 | if self.text_projection is not None:
113 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
114 |
115 | def build_attention_mask(self):
116 |         # lazily create the causal attention mask for the text transformer
117 | # pytorch uses additive attention mask; fill with -inf
118 | mask = torch.empty(self.context_length, self.context_length)
119 | mask.fill_(float("-inf"))
120 | mask.triu_(1) # zero out the lower diagonal
121 | return mask
122 |
123 | @property
124 | def device(self):
125 | return self.text_projection.device
126 |
127 | @property
128 | def dtype(self):
129 | return self.text_projection.dtype
130 |
131 | def tokenize(self,
132 | texts: Union[str, List[str]], \
133 | context_length: int = 77) -> torch.LongTensor:
134 |         """Tokenize a string or a list of strings into a (N, context_length) LongTensor.
135 |         Inputs longer than context_length are randomly cropped rather than rejected."""
136 | if isinstance(texts, str):
137 | texts = [texts]
138 |
139 | sot_token = self._tokenizer.encoder["<|startoftext|>"]
140 | eot_token = self._tokenizer.encoder["<|endoftext|>"]
141 | all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
142 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
143 |
144 | for i, tokens in enumerate(all_tokens):
145 | if len(tokens) > context_length:
146 | st = torch.randint(
147 | len(tokens) - context_length + 1, (1,))[0].item()
148 | tokens = tokens[st: st + context_length]
149 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
150 | result[i, :len(tokens)] = torch.tensor(tokens)
151 |
152 | return result
153 |
154 | def encode_text(self, text):
155 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
156 | x = x + self.positional_embedding.type(self.dtype)
157 | x = x.permute(1, 0, 2) # NLD -> LND
158 | x = self.transformer(x)
159 | x = x.permute(1, 0, 2) # LND -> NLD
160 | x = self.ln_final(x).type(self.dtype)
161 | # take features from the eot embedding (eot_token is the highest number in each sequence)
162 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
163 | return x
164 |
165 | def forward(self, captions):
166 | '''
167 | captions: list of strings
168 | '''
169 | text = self.tokenize(captions).to(self.device) # B x L x D
170 | features = self.encode_text(text) # B x D
171 | return features
172 |
173 |
174 | def build_text_encoder(pretrain=True):
175 | text_encoder = CLIPTEXT()
176 | if pretrain:
177 | import clip
178 | pretrained_model, _ = clip.load("ViT-B/32", device='cpu')
179 | state_dict = pretrained_model.state_dict()
180 | to_delete_keys = ["logit_scale", "input_resolution", \
181 | "context_length", "vocab_size"] + \
182 | [k for k in state_dict.keys() if k.startswith('visual.')]
183 | for k in to_delete_keys:
184 | if k in state_dict:
185 | del state_dict[k]
186 | print('Loading pretrained CLIP')
187 | text_encoder.load_state_dict(state_dict)
188 | # import pdb; pdb.set_trace()
189 | return text_encoder
--------------------------------------------------------------------------------
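A short usage sketch of the text encoder above, assuming the clip package from requirements.txt is installed (build_text_encoder downloads CLIP ViT-B/32 and copies over its text weights); the example category names are illustrative.

import torch
from mmovod.modeling.text.text_encoder import build_text_encoder

text_encoder = build_text_encoder(pretrain=True).eval()
with torch.no_grad():
    feats = text_encoder(['a cat', 'a red double-decker bus'])
print(feats.shape)                       # torch.Size([2, 512]), one embedding per string
--------------------------------------------------------------------------------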
/mmovod/modeling/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import json
4 | import numpy as np
5 | from torch.nn import functional as F
6 |
7 | def load_class_freq(
8 | path='datasets/metadata/lvis_v1_train_cat_info.json', freq_weight=1.0):
9 | cat_info = json.load(open(path, 'r'))
10 | cat_info = torch.tensor(
11 | [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])])
12 | freq_weight = cat_info.float() ** freq_weight
13 | return freq_weight
14 |
15 |
16 | def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None):
17 | appeared = torch.unique(gt_classes) # C'
18 | prob = appeared.new_ones(C + 1).float()
19 | prob[-1] = 0
20 | if len(appeared) < num_sample_cats:
21 | if weight is not None:
22 | prob[:C] = weight.float().clone()
23 | prob[appeared] = 0
24 | more_appeared = torch.multinomial(
25 | prob, num_sample_cats - len(appeared),
26 | replacement=False)
27 | appeared = torch.cat([appeared, more_appeared])
28 | return appeared
29 |
30 |
31 |
32 | def reset_cls_test(model, cls_path, num_classes):
33 | model.roi_heads.num_classes = num_classes
34 | if type(cls_path) == str:
35 | print('Resetting zs_weight', cls_path)
36 | zs_weight = torch.tensor(
37 | np.load(cls_path),
38 | dtype=torch.float32).permute(1, 0).contiguous() # D x C
39 | else:
40 | zs_weight = cls_path
41 | zs_weight = torch.cat(
42 | [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))],
43 | dim=1) # D x (C + 1)
44 | if model.roi_heads.box_predictor[0].cls_score.norm_weight:
45 | zs_weight = F.normalize(zs_weight, p=2, dim=0)
46 | zs_weight = zs_weight.to(model.device)
47 | for k in range(len(model.roi_heads.box_predictor)):
48 | del model.roi_heads.box_predictor[k].cls_score.zs_weight
49 | model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight
--------------------------------------------------------------------------------
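A quick sketch of the federated-loss class sampling in get_fed_loss_inds above: classes that appear in the batch are always kept, and the remaining budget is filled by sampling other classes in proportion to weight (a uniform weight is used here instead of the LVIS image-count prior from load_class_freq; the labels are made up).

import torch
from mmovod.modeling.utils import get_fed_loss_inds

C = 1203                                         # number of LVIS categories
gt_classes = torch.tensor([3, 3, 17, 998])       # labels present in the batch
inds = get_fed_loss_inds(gt_classes, num_sample_cats=50, C=C, weight=torch.ones(C))
print(inds.shape)                                # torch.Size([50]); always contains 3, 17, 998
--------------------------------------------------------------------------------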
/mmovod/predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import atexit
3 | import bisect
4 | import multiprocessing as mp
5 | from collections import deque
6 | import cv2
7 | import torch
8 |
9 | from detectron2.data import MetadataCatalog
10 | from detectron2.engine.defaults import DefaultPredictor
11 | from detectron2.utils.video_visualizer import VideoVisualizer
12 | from detectron2.utils.visualizer import ColorMode, Visualizer
13 |
14 | from .modeling.utils import reset_cls_test
15 |
16 |
17 | def get_clip_embeddings(vocabulary, prompt='a '):
18 |     from .modeling.text.text_encoder import build_text_encoder
19 | text_encoder = build_text_encoder(pretrain=True)
20 | text_encoder.eval()
21 | texts = [prompt + x for x in vocabulary]
22 | emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu()
23 | return emb
24 |
25 | BUILDIN_CLASSIFIER = {
26 | 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy',
27 | 'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy',
28 | 'openimages': 'datasets/metadata/oid_clip_a+cname.npy',
29 | 'coco': 'datasets/metadata/coco_clip_a+cname.npy',
30 | }
31 |
32 | BUILDIN_METADATA_PATH = {
33 | 'lvis': 'lvis_v1_val',
34 | 'objects365': 'objects365_v2_val',
35 | 'openimages': 'oid_val_expanded',
36 | 'coco': 'coco_2017_val',
37 | }
38 |
39 | class VisualizationDemo(object):
40 | def __init__(self, cfg, args,
41 | instance_mode=ColorMode.IMAGE, parallel=False):
42 | """
43 | Args:
44 | cfg (CfgNode):
45 | instance_mode (ColorMode):
46 | parallel (bool): whether to run the model in different processes from visualization.
47 | Useful since the visualization logic can be slow.
48 | """
49 | if args.vocabulary == 'custom':
50 | self.metadata = MetadataCatalog.get("__unused")
51 | self.metadata.thing_classes = args.custom_vocabulary.split(',')
52 | classifier = get_clip_embeddings(self.metadata.thing_classes)
53 | else:
54 | self.metadata = MetadataCatalog.get(
55 | BUILDIN_METADATA_PATH[args.vocabulary])
56 | classifier = BUILDIN_CLASSIFIER[args.vocabulary]
57 |
58 | num_classes = len(self.metadata.thing_classes)
59 | self.cpu_device = torch.device("cpu")
60 | self.instance_mode = instance_mode
61 |
62 | self.parallel = parallel
63 | if parallel:
64 | num_gpu = torch.cuda.device_count()
65 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu)
66 | else:
67 | self.predictor = DefaultPredictor(cfg)
68 | reset_cls_test(self.predictor.model, classifier, num_classes)
69 |
70 | def run_on_image(self, image):
71 | """
72 | Args:
73 | image (np.ndarray): an image of shape (H, W, C) (in BGR order).
74 | This is the format used by OpenCV.
75 |
76 | Returns:
77 | predictions (dict): the output of the model.
78 | vis_output (VisImage): the visualized image output.
79 | """
80 | vis_output = None
81 | predictions = self.predictor(image)
82 | # Convert image from OpenCV BGR format to Matplotlib RGB format.
83 | image = image[:, :, ::-1]
84 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode)
85 | if "panoptic_seg" in predictions:
86 | panoptic_seg, segments_info = predictions["panoptic_seg"]
87 | vis_output = visualizer.draw_panoptic_seg_predictions(
88 | panoptic_seg.to(self.cpu_device), segments_info
89 | )
90 | else:
91 | if "sem_seg" in predictions:
92 | vis_output = visualizer.draw_sem_seg(
93 | predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
94 | )
95 | if "instances" in predictions:
96 | instances = predictions["instances"].to(self.cpu_device)
97 | vis_output = visualizer.draw_instance_predictions(predictions=instances)
98 |
99 | return predictions, vis_output
100 |
101 | def _frame_from_video(self, video):
102 | while video.isOpened():
103 | success, frame = video.read()
104 | if success:
105 | yield frame
106 | else:
107 | break
108 |
109 | def run_on_video(self, video):
110 | """
111 | Visualizes predictions on frames of the input video.
112 |
113 | Args:
114 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be
115 | either a webcam or a video file.
116 |
117 | Yields:
118 | ndarray: BGR visualizations of each video frame.
119 | """
120 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode)
121 |
122 | def process_predictions(frame, predictions):
123 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
124 | if "panoptic_seg" in predictions:
125 | panoptic_seg, segments_info = predictions["panoptic_seg"]
126 | vis_frame = video_visualizer.draw_panoptic_seg_predictions(
127 | frame, panoptic_seg.to(self.cpu_device), segments_info
128 | )
129 | elif "instances" in predictions:
130 | predictions = predictions["instances"].to(self.cpu_device)
131 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions)
132 | elif "sem_seg" in predictions:
133 | vis_frame = video_visualizer.draw_sem_seg(
134 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device)
135 | )
136 |
137 | # Converts Matplotlib RGB format to OpenCV BGR format
138 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR)
139 | return vis_frame
140 |
141 | frame_gen = self._frame_from_video(video)
142 | if self.parallel:
143 | buffer_size = self.predictor.default_buffer_size
144 |
145 | frame_data = deque()
146 |
147 | for cnt, frame in enumerate(frame_gen):
148 | frame_data.append(frame)
149 | self.predictor.put(frame)
150 |
151 | if cnt >= buffer_size:
152 | frame = frame_data.popleft()
153 | predictions = self.predictor.get()
154 | yield process_predictions(frame, predictions)
155 |
156 | while len(frame_data):
157 | frame = frame_data.popleft()
158 | predictions = self.predictor.get()
159 | yield process_predictions(frame, predictions)
160 | else:
161 | for frame in frame_gen:
162 | yield process_predictions(frame, self.predictor(frame))
163 |
164 |
165 | class AsyncPredictor:
166 | """
167 | A predictor that runs the model asynchronously, possibly on >1 GPUs.
168 |     Because rendering the visualization takes a considerable amount of time,
169 | this helps improve throughput a little bit when rendering videos.
170 | """
171 |
172 | class _StopToken:
173 | pass
174 |
175 | class _PredictWorker(mp.Process):
176 | def __init__(self, cfg, task_queue, result_queue):
177 | self.cfg = cfg
178 | self.task_queue = task_queue
179 | self.result_queue = result_queue
180 | super().__init__()
181 |
182 | def run(self):
183 | predictor = DefaultPredictor(self.cfg)
184 |
185 | while True:
186 | task = self.task_queue.get()
187 | if isinstance(task, AsyncPredictor._StopToken):
188 | break
189 | idx, data = task
190 | result = predictor(data)
191 | self.result_queue.put((idx, result))
192 |
193 | def __init__(self, cfg, num_gpus: int = 1):
194 | """
195 | Args:
196 | cfg (CfgNode):
197 | num_gpus (int): if 0, will run on CPU
198 | """
199 | num_workers = max(num_gpus, 1)
200 | self.task_queue = mp.Queue(maxsize=num_workers * 3)
201 | self.result_queue = mp.Queue(maxsize=num_workers * 3)
202 | self.procs = []
203 | for gpuid in range(max(num_gpus, 1)):
204 | cfg = cfg.clone()
205 | cfg.defrost()
206 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu"
207 | self.procs.append(
208 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue)
209 | )
210 |
211 | self.put_idx = 0
212 | self.get_idx = 0
213 | self.result_rank = []
214 | self.result_data = []
215 |
216 | for p in self.procs:
217 | p.start()
218 | atexit.register(self.shutdown)
219 |
220 | def put(self, image):
221 | self.put_idx += 1
222 | self.task_queue.put((self.put_idx, image))
223 |
224 | def get(self):
225 | self.get_idx += 1 # the index needed for this request
226 | if len(self.result_rank) and self.result_rank[0] == self.get_idx:
227 | res = self.result_data[0]
228 | del self.result_data[0], self.result_rank[0]
229 | return res
230 |
231 | while True:
232 | # make sure the results are returned in the correct order
233 | idx, res = self.result_queue.get()
234 | if idx == self.get_idx:
235 | return res
236 | insert = bisect.bisect(self.result_rank, idx)
237 | self.result_rank.insert(insert, idx)
238 | self.result_data.insert(insert, res)
239 |
240 | def __len__(self):
241 | return self.put_idx - self.get_idx
242 |
243 | def __call__(self, image):
244 | self.put(image)
245 | return self.get()
246 |
247 | def shutdown(self):
248 | for _ in self.procs:
249 | self.task_queue.put(AsyncPredictor._StopToken())
250 |
251 | @property
252 | def default_buffer_size(self):
253 | return len(self.procs) * 5
254 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | mss
3 | timm==0.5.4
4 | dataclasses
5 | ftfy
6 | regex
7 | fasttext
8 | scikit-learn
9 | lvis
10 | nltk
11 | openai
12 | numpy==1.23.3
13 | git+https://github.com/openai/CLIP.git
14 |
--------------------------------------------------------------------------------
/tools/collate_exemplar_dict.py:
--------------------------------------------------------------------------------
1 | import time
2 | import json
3 | import glob
4 | import copy
5 | import pickle
6 | import os
7 | import argparse
8 |
9 | from collections import defaultdict
10 | from multiprocessing import Pool
11 |
12 | import numpy as np
13 |
14 | from nltk.corpus import wordnet as wn
15 | from lvis import LVIS
16 | from tqdm.auto import tqdm
17 |
18 |
19 | def get_code(syn):
20 | return syn.pos() + str(syn.offset()).zfill(8)
21 |
22 |
23 | def get_lemma_names(syn):
24 | return [lemma.name() for lemma in syn.lemmas()]
25 |
26 |
27 | def get_args():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument(
30 | "--lvis-dir",
31 | type=str,
32 | default="datasets/lvis",
33 | )
34 | parser.add_argument(
35 | "--imagenet-dir",
36 | type=str,
37 | default="datasets/imagenet",
38 | )
39 | parser.add_argument(
40 | "--visual-genome-dir",
41 | type=str,
42 | default="datasets/VisualGenome",
43 | )
44 | parser.add_argument(
45 | "--output-path",
46 | type=str,
47 | default="datasets/metadata/exemplar_dict.json",
48 | )
49 | args = parser.parse_args()
50 | return args
51 |
52 |
53 | def main(args):
54 | lvis_train_dataset = LVIS(os.path.join(args.lvis_dir, "lvis_v1_train.json"))
55 |     # In LVIS the 'stop_sign' category is not a WordNet synset, so we replace it with 'street_sign', a near cousin
56 | street_sign_synset = wn.synset("street_sign.n.01")
57 | lvis_synsets = [
58 | wn.synset(c['synset']) if "stop_sign" not in c['name'] else street_sign_synset
59 | for c in lvis_train_dataset.cats.values()
60 | ]
61 | synset2catid = {v['synset']: v['id'] for v in lvis_train_dataset.cats.values()}
62 | catid2synset = {v['id']: v['synset'] for v in lvis_train_dataset.cats.values()}
63 |
64 |     # Let's start by collecting images from ImageNet.
65 |     # We use all of ImageNet-21k, with the tree hierarchy defined in the dataset preparation file.
66 | synsets_pos_off_paths = (
67 | glob.glob(os.path.join(args.imagenet_dir, "train/*"))
68 | + glob.glob(os.path.join(args.imagenet_dir, "imagenet21k_small_classes/*"))
69 | )
70 | # import pdb; pdb.set_trace()
71 | synsets_pos_off2path = {os.path.basename(d): d for d in synsets_pos_off_paths}
72 | synsets_pos_off = [os.path.basename(d) for d in synsets_pos_off_paths]
73 | available_imagenet_synsets = [wn.synset_from_pos_and_offset(a[0], int(a[1:])) for a in synsets_pos_off]
74 | direct_match_imagenet_synsets = [v for v in available_imagenet_synsets if v.name() in synset2catid.keys()]
75 | assert len(direct_match_imagenet_synsets) == 997
76 | exemplar_dict_imagenet = defaultdict(list)
77 |
78 | for v in tqdm(direct_match_imagenet_synsets, total=len(direct_match_imagenet_synsets)):
79 | code = get_code(v)
80 | filenames = glob.glob(os.path.join(synsets_pos_off2path[code], "*.JPEG"))
81 | anns = []
82 | for filename in filenames:
83 | # need the last three parts of the path
84 | ann = {
85 | "file_name": "/".join(filename.split("/")[-3:]),
86 | "category_id": synset2catid[v.name()],
87 | "dataset": "imagenet21k",
88 | "synset": v.name(),
89 | }
90 | anns.append(ann)
91 | exemplar_dict_imagenet[v.name()].extend(anns)
92 |
93 |     # Now let's collect LVIS annotations with box area > 32*32
94 | exemplar_dict_lvis = defaultdict(list)
95 | lvis_anns = lvis_train_dataset.load_anns(
96 | lvis_train_dataset.get_ann_ids(
97 | cat_ids=lvis_train_dataset.get_cat_ids(),
98 | area_rng=[32.0**2, float('inf')]
99 | ))
100 |
101 | for ann in tqdm(lvis_anns):
102 | img = lvis_train_dataset.load_imgs([ann['image_id']])[0]
103 | ann['dataset'] = "lvis"
104 | ann['file_name'] = "/".join(img['coco_url'].split("/")[-2:])
105 | ann['synset'] = catid2synset[ann['category_id']]
106 | exemplar_dict_lvis[catid2synset[ann['category_id']]].append(ann)
107 |
108 | # get keys from both dictionaries
109 | keys = list(set(exemplar_dict_imagenet.keys()).union(set(exemplar_dict_lvis.keys())))
110 | exemplar_dict_combined_two = defaultdict(list)
111 | for key in keys:
112 | exemplar_dict_combined_two[key] = exemplar_dict_imagenet[key] + exemplar_dict_lvis[key]
113 |
114 | # let's find the lacking synsets
115 | lacking_synsets = [k.name() for k in lvis_synsets if len(exemplar_dict_combined_two[k.name()]) < 40]
116 |     print(f"After collecting exemplars from ImageNet and LVIS, there are still {len(lacking_synsets)} synsets without at least 40 exemplars")
117 |
118 | # Now let's collect images from Visual Genome
119 | exemplar_dict_vg = defaultdict(list)
120 | vg_objects_path = os.path.join(args.visual_genome_dir, "objects.json")
121 | with open(vg_objects_path, "r") as f:
122 | visual_genome_objects = json.load(f)
123 |
124 | vg_images_path = os.path.join(args.visual_genome_dir, "image_data.json")
125 | with open(vg_images_path, "r") as f:
126 | visual_genome_images = json.load(f)
127 | visual_genome_iid2path = {v['image_id']: "/".join(v['url'].split("/")[-2:]) for v in visual_genome_images}
128 |
129 | synsets2boxes = defaultdict(list)
130 | for i, img in tqdm(enumerate(visual_genome_objects), total=len(visual_genome_objects)):
131 | for j, obj in enumerate(img['objects']):
132 | if len(obj['synsets']) != 1:
133 | continue
134 | synsets2boxes[obj['synsets'][0]].append((i, j, obj['w'] * obj['h']))
135 |
136 |     # We only use Visual Genome for synsets with fewer than 40 exemplars
137 | for k in tqdm(lacking_synsets, total=len(lacking_synsets)):
138 | visual_genome_ids = synsets2boxes[k]
139 | anns = []
140 | for a in visual_genome_ids:
141 | if a[-1] < 32**2:
142 | continue
143 | img_objects = visual_genome_objects[a[0]]
144 | iid = img_objects['image_id']
145 | ann = {
146 | 'image_id': iid,
147 | 'dataset': 'visual_genome',
148 | 'file_name': visual_genome_iid2path[iid],
149 | 'category_id': synset2catid[k],
150 | 'synset': k,
151 | }
152 | ann.update(img_objects['objects'][a[1]])
153 | assert ann['synsets'][0] == k
154 | anns.append(ann)
155 | exemplar_dict_vg[k] = anns
156 |
157 | # Now let's combine all the dictionaries
158 | exemplar_dict_combined_three = defaultdict(list)
159 | for syn in synset2catid.keys():
160 | exemplar_dict_combined_three[syn].extend(
161 | exemplar_dict_lvis[syn]
162 | + exemplar_dict_imagenet[syn]
163 | + exemplar_dict_vg[syn]
164 | )
165 |
166 |     # At this point there should be at least ten exemplars for 1160 out of 1203 synsets.
167 |     # As described in the appendix of the paper, some other ImageNet synsets are suitable,
168 |     # and in some cases we use close cousins.
169 |
170 | manual_synsets = {
171 | "anklet.n.03": ["anklet.n.02"],
172 | "beach_ball.n.01": ["volleyball.n.02"],
173 | "bible.n.01": ["book.n.11"],
174 | "black_flag.n.01": ["flag.n.01"],
175 | "bob.n.05": ["spinner.n.03"],
176 | "bowl.n.08": ["pipe_smoker.n.01"],
177 | "brooch.n.01": ["pectoral.n.02", "bling.n.01"],
178 | "card.n.02": ["business_card.n.01", "library_card.n.01"],
179 | "checkbook.n.01": ["daybook.n.02"],
180 | "coil.n.05": ["coil_spring.n.01"],
181 | "coloring_material.n.01": ["crayon.n.01"],
182 | "crab.n.05": ["shellfish.n.01", "lobster.n.01"],
183 | "cube.n.05": ["die.n.01"],
184 | "cufflink.n.01": ["bling.n.01"],
185 | "dishwasher_detergent.n.01": ["laundry_detergent.n.01", "cleansing_agent.n.01"],
186 | "diving_board.n.01": ["springboard.n.01"],
187 | "dollar.n.02": ["money.n.01", "paper_money.n.01"],
188 | "eel.n.01": ["electric_eel.n.01"],
189 | "escargot.n.01": ["snail.n.01"],
190 | "gargoyle.n.02": ["statue.n.01"],
191 | "gem.n.02": ["crystal.n.01", "bling.n.01"],
192 | "grits.n.01": ["congee.n.01"],
193 | "hardback.n.01": ["book.n.07"],
194 | "jewel.n.01": ["bling.n.01"],
195 | "keycard.n.01": ["magnetic_stripe.n.01"],
196 | "lamb_chop.n.01": ["porkchop.n.01", "rib.n.03"],
197 | "mail_slot.n.01": ["maildrop.n.01", "mailbox.n.01"],
198 | "milestone.n.01": ["cairn.n.01"],
199 | "pad.n.03": ["handstamp.n.01"],
200 | "paperback_book.n.01": ["book.n.07"],
201 | "paperweight.n.01": ["letter_opener.n.01"],
202 | "pennant.n.02": ["bunting.n.01"],
203 | "penny.n.02": ["coin.n.01"],
204 | "plume.n.02": ["headdress.n.01"],
205 | "poker.n.01": ["fire_tongs.n.01"],
206 | "rag_doll.n.01": ["doll.n.01"],
207 | "road_map.n.02": ["map.n.01"],
208 | "scarecrow.n.01": ["creche.n.02"],
209 | "sparkler.n.02": ["firework.n.01"],
210 | "sugarcane.n.01": ["cane_sugar.n.02"],
211 | "water_pistol.n.01": ["pistol.n.01"],
212 | "wedding_ring.n.01": ["ring.n.01"],
213 | "windsock.n.01": ["weathervane.n.01"],
214 | }
215 |
216 | manual_exemplar_dict = defaultdict(list)
217 |
218 | remaining_lacking_synsets = [(k, len(v)) for k, v in exemplar_dict_combined_three.items() if len(v) < 10]
219 | assert len(remaining_lacking_synsets) == len(manual_synsets) == 43
220 |
221 | for k, v in tqdm(manual_synsets.items(), total=len(remaining_lacking_synsets)):
222 | for manual_synset in v:
223 | code = get_code(wn.synset(manual_synset))
224 | if code not in synsets_pos_off2path:
225 | continue
226 |
227 | filenames = glob.glob(os.path.join(synsets_pos_off2path[code], "*.JPEG"))
228 | anns = []
229 | for filename in filenames:
230 | ann = {
231 | "file_name": "/".join(filename.split("/")[-3:]),
232 | "category_id": synset2catid[k],
233 | "dataset": "imagenet21k",
234 | "synset": manual_synset,
235 | }
236 | anns.append(ann)
237 | # if k in remaining_lacking_synsets:
238 | # print(k, manual_synset, len(anns))
239 | manual_exemplar_dict[k].extend(anns)
240 |
241 | for k, v in tqdm(manual_synsets.items(), total=len(remaining_lacking_synsets)):
242 | for manual_synset in v:
243 | if manual_synset not in synsets2boxes.keys():
244 | continue
245 | visual_genome_ids = synsets2boxes[manual_synset]
246 | anns = []
247 | for a in visual_genome_ids:
248 | if a[-1] < 32**2:
249 | continue
250 | img_objects = visual_genome_objects[a[0]]
251 | iid = img_objects['image_id']
252 | ann = {
253 | 'image_id': iid,
254 | 'dataset': 'visual_genome',
255 | 'file_name': visual_genome_iid2path[iid],
256 | 'category_id': synset2catid[k],
257 | 'synset': manual_synset,
258 | }
259 | ann.update(img_objects['objects'][a[1]])
260 | # assert ann['synsets'][0] == k
261 | anns.append(ann)
262 | manual_exemplar_dict[k].extend(anns)
263 |
264 | # combine manual_exemplar_dict with exemplar_dict_combined_three
265 | for k, v in manual_exemplar_dict.items():
266 | exemplar_dict_combined_three[k].extend(v)
267 |
268 | assert min([len(v) for k, v in exemplar_dict_combined_three.items()]) >= 10, "some synsets have less than 10 exemplars"
269 | with open(args.output_path, "w") as f:
270 | json.dump(exemplar_dict_combined_three, f)
271 |
272 |
273 | if __name__ == "__main__":
274 | args = get_args()
275 | main(args)
276 |
--------------------------------------------------------------------------------
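A tiny sketch of how the script above maps a WordNet synset to its ImageNet-21k folder name (part-of-speech letter plus zero-padded offset); it assumes nltk is installed with the wordnet corpus downloaded, and the example synset is arbitrary.

from nltk.corpus import wordnet as wn

def get_code(syn):
    # ImageNet-21k folders are named by wnid: pos letter + 8-digit synset offset
    return syn.pos() + str(syn.offset()).zfill(8)

print(get_code(wn.synset('dog.n.01')))   # an ImageNet-style wnid such as 'n02084071'
--------------------------------------------------------------------------------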
/tools/convert-thirdparty-pretrained-model-to-d2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3 | import argparse
4 | import pickle
5 | import torch
6 |
7 | """
8 | Usage:
9 |
10 | cd DETIC_ROOT/models/
11 | wget https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth
12 | python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path resnet50_miil_21k.pth
13 |
14 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth
15 | python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path swin_base_patch4_window7_224_22k.pth
16 |
17 | """
18 |
19 |
20 | if __name__ == "__main__":
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('--path', default='')
23 | args = parser.parse_args()
24 |
25 | print('Loading', args.path)
26 | model = torch.load(args.path, map_location="cpu")
27 | # import pdb; pdb.set_trace()
28 | if 'model' in model:
29 | model = model['model']
30 | if 'state_dict' in model:
31 | model = model['state_dict']
32 | ret = {
33 | "model": model,
34 | "__author__": "third_party",
35 | "matching_heuristics": True
36 | }
37 | out_path = args.path.replace('.pth', '.pkl')
38 | print('Saving to', out_path)
39 | pickle.dump(ret, open(out_path, "wb"))
40 |
--------------------------------------------------------------------------------
/tools/create_imagenetlvis_json.py:
--------------------------------------------------------------------------------
1 | # edited from script in Detic from Facebook, Inc. and its affiliates.
2 | import argparse
3 | import json
4 | import os
5 | from nltk.corpus import wordnet as wn
6 | from detectron2.data.detection_utils import read_image
7 |
8 |
9 | def get_code(syn):
10 | return syn.pos() + str(syn.offset()).zfill(8)
11 |
12 |
13 | if __name__ == '__main__':
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--imagenet-path', default='datasets/imagenet/imagenet21k_P')
16 | parser.add_argument('--lvis-meta-path', default='datasets/lvis/lvis_v1_val.json')
17 | parser.add_argument('--out-path', default='datasets/imagenet/annotations/imagenet_lvis_image_info.json')
18 | args = parser.parse_args()
19 |
20 | print('Loading LVIS meta')
21 | data = json.load(open(args.lvis_meta_path, 'r'))
22 | print('Done')
23 | synset2cat = {x['synset']: x for x in data['categories']}
24 | count = 0
25 | images = []
26 | image_counts = {}
27 | synset2folders = {}
28 | for synset in synset2cat:
29 | if synset == "stop_sign.n.01":
30 | synset2folders["stop_sign.n.01"] = []
31 | continue
32 | code = get_code(wn.synset(synset))
33 | folders = [
34 | os.path.join(args.imagenet_path, folder, code) for folder in ["train", "val", "imagenet21k_small_classes"]
35 | ]
36 | synset2folders[synset] = [f for f in folders if os.path.exists(f)]
37 |
38 | for synset, synset_folders in synset2folders.items():
39 | if len(synset_folders) == 0:
40 | continue
41 | cat = synset2cat[synset]
42 | cat_id = cat['id']
43 | cat_name = cat['name']
44 | cat_images = []
45 | files = []
46 | for folder in synset_folders:
47 | folder_files = os.listdir(folder)
48 | for file in folder_files:
49 | count = count + 1
50 |             # file_name only needs the last two parts of the path
51 | # import pdb; pdb.set_trace()
52 | file_name = '{}/{}/{}'.format(*(folder.split("/")[-2:]), file)
53 | assert os.path.join(folder, file) == os.path.join(args.imagenet_path, file_name)
54 | img = read_image(os.path.join(args.imagenet_path, file_name))
55 | h, w = img.shape[:2]
56 | image = {
57 | 'id': count,
58 | 'file_name': file_name,
59 | 'pos_category_ids': [cat_id],
60 | 'width': w,
61 | 'height': h
62 | }
63 | cat_images.append(image)
64 | images.extend(cat_images)
65 | image_counts[cat_id] = len(cat_images)
66 | print(cat_id, cat_name, len(cat_images))
67 | print('# Images', len(images))
68 | for x in data['categories']:
69 | x['image_count'] = image_counts[x['id']] if x['id'] in image_counts else 0
70 | out = {'categories': data['categories'], 'images': images, 'annotations': []}
71 | print('Writing to', args.out_path)
72 | json.dump(out, open(args.out_path, 'w'))
73 |
--------------------------------------------------------------------------------
/tools/dump_clip_features.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import argparse
3 | import json
4 | import torch
5 | import numpy as np
6 | import itertools
7 | from nltk.corpus import wordnet
8 | import sys
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--ann', default='datasets/lvis/lvis_v1_val.json')
13 | parser.add_argument('--out_path', default='')
14 | parser.add_argument('--prompt', default='a')
15 | parser.add_argument('--model', default='clip')
16 | parser.add_argument('--clip_model', default="ViT-B/32")
17 | parser.add_argument('--fix_space', action='store_true')
18 | parser.add_argument('--use_underscore', action='store_true')
19 | parser.add_argument('--avg_synonyms', action='store_true')
20 | parser.add_argument('--use_wn_name', action='store_true')
21 | args = parser.parse_args()
22 |
23 | print('Loading', args.ann)
24 | data = json.load(open(args.ann, 'r'))
25 | cat_names = [x['name'] for x in \
26 | sorted(data['categories'], key=lambda x: x['id'])]
27 | if 'synonyms' in data['categories'][0]:
28 | if args.use_wn_name:
29 | synonyms = [
30 | [xx.name() for xx in wordnet.synset(x['synset']).lemmas()] \
31 | if x['synset'] != 'stop_sign.n.01' else ['stop_sign'] \
32 | for x in sorted(data['categories'], key=lambda x: x['id'])]
33 | else:
34 | synonyms = [x['synonyms'] for x in \
35 | sorted(data['categories'], key=lambda x: x['id'])]
36 | else:
37 | synonyms = []
38 | if args.fix_space:
39 | cat_names = [x.replace('_', ' ') for x in cat_names]
40 | if args.use_underscore:
41 | cat_names = [x.strip().replace('/ ', '/').replace(' ', '_') for x in cat_names]
42 | print('cat_names', cat_names)
43 | device = "cuda" if torch.cuda.is_available() else "cpu"
44 |
45 | if args.prompt == 'a':
46 | sentences = ['a ' + x for x in cat_names]
47 | sentences_synonyms = [['a ' + xx for xx in x] for x in synonyms]
48 |     elif args.prompt == 'none':
49 | sentences = [x for x in cat_names]
50 | sentences_synonyms = [[xx for xx in x] for x in synonyms]
51 | elif args.prompt == 'photo':
52 | sentences = ['a photo of a {}'.format(x) for x in cat_names]
53 | sentences_synonyms = [['a photo of a {}'.format(xx) for xx in x] \
54 | for x in synonyms]
55 | elif args.prompt == 'scene':
56 | sentences = ['a photo of a {} in the scene'.format(x) for x in cat_names]
57 | sentences_synonyms = [['a photo of a {} in the scene'.format(xx) for xx in x] \
58 | for x in synonyms]
59 |
60 | print('sentences_synonyms', len(sentences_synonyms), \
61 | sum(len(x) for x in sentences_synonyms))
62 | if args.model == 'clip':
63 | import clip
64 | print('Loading CLIP')
65 | model, preprocess = clip.load(args.clip_model, device=device)
66 | if args.avg_synonyms:
67 | sentences = list(itertools.chain.from_iterable(sentences_synonyms))
68 | print('flattened_sentences', len(sentences))
69 | text = clip.tokenize(sentences).to(device)
70 | with torch.no_grad():
71 | if len(text) > 10000:
72 | text_features = torch.cat([
73 | model.encode_text(text[:len(text) // 2]),
74 | model.encode_text(text[len(text) // 2:])],
75 | dim=0)
76 | else:
77 | text_features = model.encode_text(text)
78 | print('text_features.shape', text_features.shape)
79 | if args.avg_synonyms:
80 | synonyms_per_cat = [len(x) for x in sentences_synonyms]
81 | text_features = text_features.split(synonyms_per_cat, dim=0)
82 | text_features = [x.mean(dim=0) for x in text_features]
83 | text_features = torch.stack(text_features, dim=0)
84 | print('after stack', text_features.shape)
85 | text_features = text_features.cpu().numpy()
86 | elif args.model in ['bert', 'roberta']:
87 | from transformers import AutoTokenizer, AutoModel
88 | if args.model == 'bert':
89 | model_name = 'bert-large-uncased'
90 | if args.model == 'roberta':
91 | model_name = 'roberta-large'
92 | tokenizer = AutoTokenizer.from_pretrained(model_name)
93 | model = AutoModel.from_pretrained(model_name)
94 | model.eval()
95 | if args.avg_synonyms:
96 | sentences = list(itertools.chain.from_iterable(sentences_synonyms))
97 | print('flattened_sentences', len(sentences))
98 | inputs = tokenizer(sentences, padding=True, return_tensors="pt")
99 | with torch.no_grad():
100 | model_outputs = model(**inputs)
101 | outputs = model_outputs.pooler_output
102 | text_features = outputs.detach().cpu()
103 | if args.avg_synonyms:
104 | synonyms_per_cat = [len(x) for x in sentences_synonyms]
105 | text_features = text_features.split(synonyms_per_cat, dim=0)
106 | text_features = [x.mean(dim=0) for x in text_features]
107 | text_features = torch.stack(text_features, dim=0)
108 | print('after stack', text_features.shape)
109 | text_features = text_features.numpy()
110 | print('text_features.shape', text_features.shape)
111 | else:
112 | assert 0, args.model
113 | if args.out_path != '':
114 |         print('saving to', args.out_path)
115 | np.save(open(args.out_path, 'wb'), text_features)
116 |
117 |
--------------------------------------------------------------------------------
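
Note: the `.npy` written above holds one CLIP text embedding per LVIS category, ordered by category id, and is typically the file the configs point the zero-shot classifier at. A hedged sketch for inspecting it; the path below matches the shipped `lvis_v1_clip_a+cname.npy`, your own `--out_path` may differ:

import numpy as np

feats = np.load('datasets/metadata/lvis_v1_clip_a+cname.npy')
print(feats.shape, feats.dtype)  # expected (1203, 512) for LVIS v1 with ViT-B/32
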
/tools/dump_clip_features_lvis_sentences.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import torch
4 | import numpy as np
5 | from tqdm.auto import tqdm
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument(
10 | "--descriptions-path",
11 | type=str,
12 | default="datasets/metadata/lvis_gpt3_text-davinci-002_descriptions_author.json"
13 | )
14 | parser.add_argument(
15 | "--ann-path",
16 | type=str,
17 | default="datasets/lvis/lvis_v1_val.json"
18 | )
19 | parser.add_argument(
20 | "--out-path",
21 | type=str,
22 | default="datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy"
23 | )
24 | parser.add_argument('--model', default='clip')
25 | parser.add_argument('--clip_model', default="ViT-B/32")
26 |
27 | args = parser.parse_args()
28 |
29 | print("Loading descriptions from: {}".format(args.descriptions_path))
30 | with open(args.descriptions_path, 'r') as f:
31 | descriptions = json.load(f)
32 |
33 | print("Loading annotations from: {}".format(args.ann_path))
34 | with open(args.ann_path, 'r') as f:
35 | ann_data = json.load(f)
36 |
37 | lvis_cats = ann_data['categories']
38 | lvis_cats = [(c['id'], c['synonyms'][0].replace("_", " ")) for c in lvis_cats]
39 | sentences_per_cat = [descriptions[c[1]] for c in lvis_cats]
40 |
41 | device = "cuda" if torch.cuda.is_available() else "cpu"
42 | print(
43 | "Total number of sentences for {} classes: {}".format(
44 | len(sentences_per_cat), sum(len(x) for x in sentences_per_cat))
45 | )
46 |
47 | if args.model == 'clip':
48 | import clip
49 | print('Loading CLIP')
50 | model, preprocess = clip.load(args.clip_model, device=device)
51 | model.eval()
52 | all_text_features = []
53 | for cat_sentences in tqdm(sentences_per_cat):
54 | text = clip.tokenize(cat_sentences, truncate=True)
55 | with torch.no_grad():
56 | if len(text) > 10000:
57 | split_text = text.split(128)
58 | split_features = []
59 | for t in tqdm(split_text, total=len(split_text)):
60 | split_features.append(model.encode_text(t.to(device)).cpu())
61 | text_features = torch.cat(split_features, dim=0)
62 | # text_features = torch.cat([
63 | # model.encode_text(t) for t in text.split(128)
64 | # ], dim=0)
65 | else:
66 | text_features = model.encode_text(text.to(device))
67 | all_text_features.append(text_features.mean(dim=0))
68 | all_text_features = torch.stack(all_text_features)
69 | print("Output text features shape: ", all_text_features.shape)
70 |         text_features = all_text_features.cpu().numpy()
71 |
72 | else:
73 |         assert 0, "Model {} is not supported; only 'clip' is supported".format(args.model)
74 | if args.out_path != '':
75 |         print('saving to', args.out_path)
76 | np.save(open(args.out_path, 'wb'), text_features)
77 |
--------------------------------------------------------------------------------
/tools/generate_descriptions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import json
4 |
5 | import openai
6 | from openai.error import RateLimitError
7 | from lvis import LVIS
8 |
9 |
10 | API_KEY = "YOUR_API_KEY"
11 | openai.api_key = API_KEY
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument(
16 | "--ann-path",
17 | type=str,
18 | default="datasets/lvis/lvis_v1_val.json"
19 | )
20 | parser.add_argument(
21 | "--output-path",
22 | type=str,
23 | default="datasets/metadata/lvis_gpt3_{}_descriptions_own.json"
24 | )
25 | parser.add_argument(
26 | "--openai-model",
27 | type=str,
28 | default="text-davinci-002"
29 |     )
30 |     return parser.parse_args()
31 |
32 | def main(args):
33 | lvis_gt = LVIS(args.ann_path)
34 | categories = sorted(lvis_gt.cats.values(), key=lambda x: x["id"])
35 |
36 | category_list = [c['synonyms'][0].replace('_', ' ') for c in categories]
37 | all_responses = {}
38 | vowel_list = ['a', 'e', 'i', 'o', 'u']
39 |
40 | for i, category in enumerate(category_list):
41 | if category[0] in vowel_list:
42 | article = 'an'
43 | else:
44 | article = 'a'
45 | prompts = []
46 | prompts.append("Describe what " + article + " " + category + " looks like.")
47 |
48 | all_result = []
49 | # call openai api taking into account rate limits
50 | for curr_prompt in prompts:
51 | try:
52 | response = openai.Completion.create(
53 |                     model=args.openai_model,
54 | prompt=curr_prompt,
55 | temperature=0.99,
56 | max_tokens=50,
57 | n=10,
58 | stop="."
59 | )
60 | except RateLimitError:
61 | print("Hit rate limit. Waiting 15 seconds.")
62 | time.sleep(15)
63 | response = openai.Completion.create(
64 |                     model=args.openai_model,
65 |                     prompt=curr_prompt,
66 |                     temperature=0.99,
67 | max_tokens=50,
68 | n=10,
69 | stop="."
70 | )
71 |
72 | time.sleep(0.15)
73 |
74 | for r in range(len(response["choices"])):
75 | result = response["choices"][r]["text"]
76 | all_result.append(result.replace("\n\n", "") + ".")
77 | all_responses[category] = all_result
78 |
79 | output_path = args.output_path.format(args.openai_model)
80 | with open(output_path, 'w') as f:
81 | json.dump(all_responses, f, indent=4)
82 |
83 |
84 | if __name__ == "__main__":
85 | args = get_args()
86 | main(args)
87 |
--------------------------------------------------------------------------------
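
Note: the output JSON maps each category's display name (its first LVIS synonym with underscores replaced by spaces) to a list of generated sentences, which is the format `dump_clip_features_lvis_sentences.py` reads. A hedged sketch of loading it back; the example key is illustrative:

import json

descs = json.load(open('datasets/metadata/lvis_gpt3_text-davinci-002_descriptions_author.json'))
name = next(iter(descs))       # e.g. 'aerosol can'
print(name, len(descs[name]))  # category name and its number of description sentences
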
/tools/get_exemplars_tta.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | import clip
5 | import json
6 | from PIL import Image
7 | import sys
8 | import argparse
9 | # from lvis import LVIS
10 | import os
11 | # from pprint import pprint
12 | import random
13 | import numpy as np
14 | from tqdm.auto import tqdm
15 | import torch
16 | import torchvision.transforms as tvt
17 | from torch.utils.data import Dataset, DataLoader
18 |
19 | try:
20 | from torchvision.transforms import InterpolationMode
21 | BICUBIC = InterpolationMode.BICUBIC
22 | except ImportError:
23 | BICUBIC = Image.BICUBIC
24 |
25 |
26 | # PATHS = {
27 | # "imagenet21k": "",
28 | # "visual_genome": "/scratch/local/hdd/prannay/datasets/VisualGenome/",
29 | # "lvis": "/scratch/local/hdd/prannay/datasets/coco/",
30 | # }
31 |
32 |
33 | def _convert_image_to_rgb(image: Image.Image):
34 | return image.convert("RGB")
35 |
36 |
37 | def get_crop(img, bb, context=0.0, square=True):
38 | # print(bb)
39 | x1, y1, w, h = bb
40 | W, H = img.size
41 | y, x = y1 + h / 2.0, x1 + w / 2.0
42 | h, w = h * (1. + context), w * (1. + context)
43 | if square:
44 | w = max(w, h)
45 | h = max(w, h)
46 | # print(x, y, w, h)
47 | x1, x2 = x - w / 2.0, x + w / 2.0
48 | y1, y2 = y - h / 2.0, y + h / 2.0
49 | # print([x1, y1, x2, y2])
50 | x1, x2 = max(0, x1), min(W, x2)
51 | y1, y2 = max(0, y1), min(H, y2)
52 | # print([x1, y1, x2, y2])
53 | bb_new = [int(c) for c in [x1, y1, x2, y2]]
54 | # print(bb_new)
55 | crop = img.crop(bb_new)
56 | return crop
57 |
58 |
59 | def run_crop(d, paths, context=0.4, square=True):
60 | dataset = d['dataset']
61 | file_name = os.path.join(paths[dataset], d['file_name'])
62 | # with open(file_name, "rb") as f:
63 | img = Image.open(file_name)
64 | if dataset == "imagenet21k":
65 | bb = [0, 0, 0, 0]
66 | return img
67 | elif dataset == "lvis":
68 | bb = [
69 | int(c)
70 | for c in [
71 | d['bbox'][0] // 1,
72 | d['bbox'][1] // 1,
73 | d['bbox'][2] // 1 + 1,
74 | d['bbox'][3] // 1 + 1
75 | ]
76 | ]
77 | elif dataset == "visual_genome":
78 | bb = [int(c) for c in [d['x'], d['y'], d['w'], d['h']]]
79 | return get_crop(img, bb, context=context, square=square)
80 |
81 |
82 | def get_args():
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument(
85 | "--ann-path",
86 | type=str,
87 | default="datasets/metadata/lvis_image_exemplar_dict_K-005_own.json"
88 | )
89 | parser.add_argument(
90 | "--output-path",
91 | type=str,
92 | default="datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy"
93 | )
94 | parser.add_argument(
95 | "--lvis-img-dir",
96 | type=str,
97 | default="datasets/coco/"
98 | )
99 | parser.add_argument(
100 | "--imagenet-img-dir",
101 | type=str,
102 | default="datasets/imagenet"
103 | )
104 | parser.add_argument(
105 | "--visual-genome-img-dir",
106 | type=str,
107 | default="datasets/VisualGenome/"
108 | )
109 | parser.add_argument("--num-augs", type=int, default=5)
110 | parser.add_argument("--nw", type=int, default=8)
111 |
112 | args = parser.parse_args()
113 | return args
114 |
115 |
116 | def main(args):
117 |
118 | anns_path = args.ann_path
119 | num_augs = args.num_augs
120 |
121 | model, transform = clip.load("ViT-B/32", device="cpu")
122 | del model.transformer
123 |
124 | model = model.to("cuda:0")
125 |
126 | run(anns_path, num_augs, model, transform, args)
127 |
128 |
129 | class CropDataset(Dataset):
130 | def __init__(
131 | self,
132 | exemplar_dict,
133 | num_augs,
134 | transform,
135 | transform2,
136 | args,
137 | ):
138 | self.exemplar_dict = exemplar_dict
139 | self.transform = transform
140 | self.transform2 = transform2
141 | self.num_augs = num_augs
142 | self.paths = {
143 | "imagenet21k": args.imagenet_img_dir,
144 | "visual_genome": args.visual_genome_img_dir,
145 | "lvis": args.lvis_img_dir,
146 | }
147 |
148 | def __len__(self):
149 | return len(self.exemplar_dict)
150 |
151 | def __getitem__(self, idx):
152 | chosen_anns = self.exemplar_dict[idx]
153 | crops = [run_crop(ann, self.paths) for ann in chosen_anns]
154 | # add the tta in here somewhere
155 | crops = [
156 | self.transform(self.transform2(crop))
157 | for crop in crops for _ in range(self.num_augs)
158 | ]
159 | return torch.stack(crops)
160 |
161 |
162 | def run(anns_path, num_augs, model, transform, args):
163 | random.seed(100000 + num_augs)
164 | torch.manual_seed(100000 + num_augs)
165 | s = 0.25
166 | color_jitter = tvt.ColorJitter(
167 | 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
168 | )
169 | transform2 = tvt.Compose([
170 | tvt.RandomResizedCrop(size=224 * 4, scale=(0.8, 1.0), interpolation=BICUBIC),
171 | tvt.RandomHorizontalFlip(), # with 0.5 probability
172 | _convert_image_to_rgb,
173 | tvt.RandomApply([color_jitter], p=0.8),
174 | ])
175 | with open(anns_path, "r") as fp:
176 | exemplar_dict = json.load(fp)
177 | dataset = CropDataset(exemplar_dict, num_augs, transform, transform2, args)
178 | dataloader = DataLoader(
179 | dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=args.nw)
180 | feats = []
181 | # chosen_anns_all = []
182 | for crops in tqdm(dataloader, total=len(dataloader)):
183 | # rng = np.random.default_rng(seed)
184 | # synset = catid2synset[cat_id]
185 | # trial_anns = exemplar_dict[synset]
186 | # probs = [a['area'] for a in trial_anns]
187 | # probs = np.array(probs) / sum(probs)
188 | # chosen_anns = rng.choice(trial_anns, size=K, p=probs, replace=len(trial_anns) < K)
189 | # if len(trial_anns) < K:
190 | # chosen_anns = rng.choice(trial_anns, size=K, p=[a['area'] for a in trial_anns], replace=False)
191 | # else:
192 | # chosen_anns = rng.sample(trial_anns, k=K, counts=[a['area'] for a in trial_anns
193 | # crops = [run_crop(ann) for ann in chosen_anns]
194 | # # add the tta in here somewhere
195 | # crops = [transform(transform2(crop)) for crop in crops for _ in range(num_augs)]
196 | with torch.no_grad():
197 | image_embeddings = model.encode_image(crops[0].to("cuda:0"))
198 | # print(image_embeddings.size())
199 | feats.append(image_embeddings.cpu())
200 | # chosen_anns_all.append(chosen_anns.tolist())
201 | # crops.append([run_crop(ann) for ann in chosen_anns])
202 |
203 | feats_all = torch.stack(feats)
204 |
205 | save_basename = args.output_path
206 |
207 | np.save(save_basename, feats_all.mean(dim=1).numpy())
208 |
209 |
210 | if __name__ == "__main__":
211 | args = get_args()
212 | main(args)
213 |
--------------------------------------------------------------------------------
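
Note: the script above saves one averaged CLIP image embedding per LVIS category (K exemplar crops x `--num-augs` augmentations, mean-pooled over dim=1 before saving). A hedged sketch for checking the result, assuming the default `--output-path`:

import numpy as np

feats = np.load('datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy')
print(feats.shape)  # expected (1203, 512) for LVIS v1 with ViT-B/32
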
/tools/get_lvis_cat_info.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import argparse
3 | import json
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--ann", default='datasets/lvis/lvis_v1_train.json')
8 | parser.add_argument("--add_freq", action='store_true')
9 | parser.add_argument("--r_thresh", type=int, default=10)
10 | parser.add_argument("--c_thresh", type=int, default=100)
11 | parser.add_argument("--bypass", action='store_true')
12 | args = parser.parse_args()
13 |
14 | print('Loading', args.ann)
15 | data = json.load(open(args.ann, 'r'))
16 | cats = data['categories']
17 | image_count = {x['id']: set() for x in cats}
18 | ann_count = {x['id']: 0 for x in cats}
19 | if args.bypass:
20 | for x in data['images']:
21 | for y in x['pos_category_ids']:
22 | image_count[y].add(x['id'])
23 | ann_count[y] += 1
24 | else:
25 | for x in data['annotations']:
26 | image_count[x['category_id']].add(x['image_id'])
27 | ann_count[x['category_id']] += 1
28 | num_freqs = {x: 0 for x in ['r', 'f', 'c']}
29 | for x in cats:
30 | x['image_count'] = len(image_count[x['id']])
31 | x['instance_count'] = ann_count[x['id']]
32 | if args.add_freq:
33 | freq = 'f'
34 | if x['image_count'] < args.c_thresh:
35 | freq = 'c'
36 | if x['image_count'] < args.r_thresh:
37 | freq = 'r'
38 | x['frequency'] = freq
39 | num_freqs[freq] += 1
40 | print(cats)
41 | image_counts = sorted([x['image_count'] for x in cats])
42 | # print('image count', image_counts)
43 | # import pdb; pdb.set_trace()
44 | if args.add_freq:
45 | for x in ['r', 'c', 'f']:
46 | print(x, num_freqs[x])
47 | out = cats # {'categories': cats}
48 | out_path = args.ann[:-5] + '_cat_info.json'
49 | print('Saving to', out_path)
50 | json.dump(out, open(out_path, 'w'))
51 |
52 |
--------------------------------------------------------------------------------
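
Note: with `--add_freq`, categories are binned by image count using the thresholds above. A toy restatement of that rule, using the script's default thresholds:

def frequency(image_count, r_thresh=10, c_thresh=100):
    # rare (< r_thresh images), common (< c_thresh), frequent otherwise
    if image_count < r_thresh:
        return 'r'
    if image_count < c_thresh:
        return 'c'
    return 'f'

print([frequency(n) for n in (3, 40, 500)])  # ['r', 'c', 'f']
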
/tools/norm_feat_sum_norm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import numpy as np
3 |
4 | def get_args():
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument(
7 | "--feat1-path",
8 | type=str,
9 | default="datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy"
10 | )
11 | parser.add_argument(
12 | "--feat2-path",
13 | type=str,
14 | default="datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy"
15 | )
16 | parser.add_argument(
17 | "--out-path",
18 | type=str,
19 | default="datasets/metadata/lvis_multi-modal_avg_K-005_own.npy"
20 | )
21 | return parser.parse_args()
22 |
23 |
24 | def main(args):
25 |
26 | feat1 = np.load(args.feat1_path)
27 | feat2 = np.load(args.feat2_path)
28 | # l2 normalize each
29 | feat1 = feat1 / np.linalg.norm(feat1, axis=1, keepdims=True)
30 | feat2 = feat2 / np.linalg.norm(feat2, axis=1, keepdims=True)
31 | # take sum
32 | feat = feat1 + feat2
33 | # l2 normalize again
34 | feat = feat / np.linalg.norm(feat, axis=1, keepdims=True)
35 | np.save(args.out_path, feat)
36 |
37 |
38 | if __name__ == '__main__':
39 | args = get_args()
40 | main(args)
41 |
--------------------------------------------------------------------------------
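
Note: the fusion above L2-normalizes each modality, sums them, and renormalizes, so both modalities contribute equally regardless of their original scales. A tiny numeric illustration (the feature values are made up):

import numpy as np

text = np.array([[3.0, 4.0]])   # stand-in text feature
image = np.array([[0.0, 2.0]])  # stand-in image-exemplar feature
t = text / np.linalg.norm(text, axis=1, keepdims=True)    # [[0.6, 0.8]]
i = image / np.linalg.norm(image, axis=1, keepdims=True)  # [[0.0, 1.0]]
fused = t + i
fused = fused / np.linalg.norm(fused, axis=1, keepdims=True)
print(fused)  # unit-norm multi-modal vector, approximately [[0.316, 0.949]]
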
/tools/remove_lvis_rare.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import argparse
3 | import json
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--ann', default='datasets/lvis/lvis_v1_train.json')
8 | args = parser.parse_args()
9 |
10 | print('Loading', args.ann)
11 | data = json.load(open(args.ann, 'r'))
12 | catid2freq = {x['id']: x['frequency'] for x in data['categories']}
13 | print('ori #anns', len(data['annotations']))
14 | exclude = ['r']
15 | data['annotations'] = [x for x in data['annotations'] \
16 | if catid2freq[x['category_id']] not in exclude]
17 | print('filtered #anns', len(data['annotations']))
18 | out_path = args.ann[:-5] + '_norare.json'
19 | print('Saving to', out_path)
20 | json.dump(data, open(out_path, 'w'))
21 |
--------------------------------------------------------------------------------
/tools/sample_exemplars.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | from lvis import LVIS
4 | from tqdm.auto import tqdm
5 | import numpy as np
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument(
11 | "--exemplar-dict-path",
12 | type=str,
13 | default="datasets/metadata/exemplar_dict.json"
14 | )
15 | parser.add_argument(
16 | "--lvis-ann-path",
17 | type=str,
18 | default="datasets/lvis/lvis_v1_val.json"
19 | )
20 | parser.add_argument(
21 | "--out-path",
22 | type=str,
23 | default="datasets/metadata/lvis_image_exemplar_dict_K-005_own.json"
24 | )
25 | parser.add_argument(
26 | "-K",
27 | type=int,
28 | default=5
29 | )
30 | parser.add_argument(
31 | "--seed",
32 | type=int,
33 | default=42
34 | )
35 | args = parser.parse_args()
36 | return args
37 |
38 |
39 | def run(catid2synset, exemplar_dict, K, seed, out_path):
40 | chosen_anns_all = []
41 | for cat_id in tqdm(sorted(catid2synset.keys()), total=len(catid2synset)):
42 | rng = np.random.default_rng(seed)
43 | synset = catid2synset[cat_id]
44 | trial_anns = exemplar_dict[synset]
45 | probs = [a['area'] for a in trial_anns]
46 | probs = np.array(probs) / sum(probs)
47 | chosen_anns = rng.choice(trial_anns, size=K, p=probs, replace=False)
48 | chosen_anns_all.append(chosen_anns.tolist())
49 | with open(out_path, "w") as f:
50 | json.dump(chosen_anns_all, f)
51 |
52 |
53 | def main(args):
54 | lvis_cats = LVIS(args.lvis_ann_path).cats
55 | catid2synset = {v['id']: v['synset'] for v in lvis_cats.values()}
56 | with open(args.exemplar_dict_path, 'r') as f:
57 | exemplar_dict = json.load(f)
58 | for k, v in exemplar_dict.items():
59 | for ann in v:
60 | dataset = ann['dataset']
61 | if dataset == "imagenet21k":
62 | ann['area'] = 100. * 100.
63 | elif dataset == "lvis":
64 | ann['area'] = float(ann['bbox'][2] * ann['bbox'][3])
65 | elif dataset == "visual_genome":
66 | ann['area'] = float(ann['w'] * ann['h'])
67 |
68 | run(catid2synset, exemplar_dict, args.K, args.seed, args.out_path)
69 |
70 |
71 | if __name__ == "__main__":
72 | args = get_args()
73 | main(args)
74 |
--------------------------------------------------------------------------------
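
Note: exemplars are drawn per category without replacement, with probability proportional to annotation area. A toy sketch of that sampling step (the annotations here are invented):

import numpy as np

rng = np.random.default_rng(42)
anns = [{'id': 0, 'area': 10.0}, {'id': 1, 'area': 90.0}, {'id': 2, 'area': 50.0}]
probs = np.array([a['area'] for a in anns]) / sum(a['area'] for a in anns)
chosen = rng.choice(anns, size=2, p=probs, replace=False)
print([a['id'] for a in chosen])  # larger-area annotations are more likely to be chosen
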
/tools/unzip_imagenet_lvis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import os
3 | import argparse
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--src_path', default='datasets/imagenet/ImageNet-21K/')
8 | parser.add_argument('--dst_path', default='datasets/imagenet/ImageNet-LVIS/')
9 | parser.add_argument('--data_path', default='datasets/imagenet_lvis_wnid.txt')
10 | args = parser.parse_args()
11 |
12 | f = open(args.data_path)
13 | for i, line in enumerate(f):
14 | cmd = 'mkdir {x} && tar -xf {src}/{l}.tar -C {x}'.format(
15 | src=args.src_path,
16 | l=line.strip(),
17 | x=args.dst_path + '/' + line.strip())
18 | print(i, cmd)
19 | os.system(cmd)
20 |
--------------------------------------------------------------------------------
/train_net_auto.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import os
4 | import sys
5 | from collections import OrderedDict
6 | import torch
7 | from torch.nn.parallel import DistributedDataParallel
8 | import time
9 | import datetime
10 | from typing import Any
11 |
12 | from fvcore.common.timer import Timer
13 | from iopath.common.file_io import PathManager
14 | import detectron2.utils.comm as comm
15 | from detectron2.checkpoint import (
16 | DetectionCheckpointer, PeriodicCheckpointer, Checkpointer
17 | )
18 | from detectron2.config import get_cfg
19 | from detectron2.data import (
20 | MetadataCatalog,
21 | build_detection_test_loader,
22 | )
23 | from detectron2.engine import default_argument_parser, default_setup, launch
24 |
25 | from detectron2.evaluation import (
26 | inference_on_dataset,
27 | print_csv_format,
28 | LVISEvaluator,
29 | COCOEvaluator,
30 | )
31 | from detectron2.modeling import build_model
32 | from detectron2.solver import build_lr_scheduler, build_optimizer
33 | from detectron2.utils.events import (
34 | CommonMetricPrinter,
35 | EventStorage,
36 | JSONWriter,
37 | TensorboardXWriter,
38 | )
39 | from detectron2.data.dataset_mapper import DatasetMapper
40 | from detectron2.data.build import build_detection_train_loader
41 | from detectron2.utils.logger import setup_logger
42 | from torch.cuda.amp import GradScaler
43 |
44 | sys.path.insert(0, 'third_party/CenterNet2/')
45 | from centernet.config import add_centernet_config
46 |
47 | from mmovod.config import add_mmovod_config
48 | from mmovod.data.custom_build_augmentation import build_custom_augmentation
49 | from mmovod.data.custom_dataset_dataloader import build_custom_train_loader
50 | from mmovod.data.custom_dataset_mapper import CustomDatasetMapper
51 | from mmovod.custom_solver import build_custom_optimizer
52 | from mmovod.modeling.utils import reset_cls_test
53 |
54 |
55 | logger = logging.getLogger("detectron2")
56 |
57 |
58 | class LatestCheckpointer:
59 | """
60 |     Save a rolling "latest" checkpoint. When `.step(iteration)` is called, it
61 |     will execute `checkpointer.save` on the given checkpointer whenever the
62 |     iteration is a multiple of `period`, overwriting "{file_prefix}_latest".
63 |
64 | Attributes:
65 | checkpointer (Checkpointer): the underlying checkpointer object
66 | """
67 |
68 | def __init__(
69 | self,
70 | checkpointer: Checkpointer,
71 | period: int,
72 | file_prefix: str = "model",
73 | ) -> None:
74 | """
75 |         Args:
76 |             checkpointer: the checkpointer object used to save checkpoints.
77 |             period (int): the period (in iterations) at which the rolling
78 |                 checkpoint is saved. Unlike `PeriodicCheckpointer`, this
79 |                 class always overwrites a single "{file_prefix}_latest"
80 |                 file, so there is no `max_iter` or `max_to_keep` handling
81 |                 and no final checkpoint is written.
82 |             file_prefix (str): the prefix of the checkpoint's filename.
83 | """
84 | self.checkpointer = checkpointer
85 | self.period = int(period)
86 | self.path_manager: PathManager = checkpointer.path_manager
87 | self.file_prefix = file_prefix
88 |
89 | def step(self, iteration: int, **kwargs: Any) -> None:
90 | """
91 | Perform the appropriate action at the given iteration.
92 |
93 | Args:
94 | iteration (int): the current iteration, ranged in [0, max_iter-1].
95 | kwargs (Any): extra data to save, same as in
96 | :meth:`Checkpointer.save`.
97 | """
98 | iteration = int(iteration)
99 | additional_state = {"iteration": iteration}
100 | additional_state.update(kwargs)
101 |
102 | if (iteration + 1) % self.period == 0:
103 | self.checkpointer.save(
104 | f"{self.file_prefix}_latest", **additional_state
105 | )
106 |
107 | def save(self, name: str, **kwargs: Any) -> None:
108 | """
109 | Same argument as :meth:`Checkpointer.save`.
110 | Use this method to manually save checkpoints outside the schedule.
111 |
112 | Args:
113 | name (str): file name.
114 | kwargs (Any): extra data to save, same as in
115 | :meth:`Checkpointer.save`.
116 | """
117 | self.checkpointer.save(name, **kwargs)
118 |
119 |
120 | def do_test(cfg, model):
121 | results = OrderedDict()
122 | for d, dataset_name in enumerate(cfg.DATASETS.TEST):
123 | if cfg.MODEL.RESET_CLS_TESTS:
124 | reset_cls_test(
125 | model,
126 | cfg.MODEL.TEST_CLASSIFIERS[d],
127 | cfg.MODEL.TEST_NUM_CLASSES[d])
128 | mapper = None if cfg.INPUT.TEST_INPUT_TYPE == 'default' \
129 | else DatasetMapper(
130 | cfg, False, augmentations=build_custom_augmentation(cfg, False))
131 | data_loader = build_detection_test_loader(cfg, dataset_name, mapper=mapper)
132 | output_folder = os.path.join(
133 | cfg.OUTPUT_DIR, "inference_{}".format(dataset_name))
134 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
135 |
136 | if evaluator_type == "lvis" or cfg.GEN_PSEDO_LABELS:
137 | evaluator = LVISEvaluator(dataset_name, cfg, True, output_folder)
138 | elif evaluator_type == 'coco':
139 | evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder)
140 | else:
141 | assert 0, evaluator_type
142 |
143 | results[dataset_name] = inference_on_dataset(
144 | model, data_loader, evaluator)
145 | if comm.is_main_process():
146 | logger.info("Evaluation results for {} in csv format:".format(
147 | dataset_name))
148 | print_csv_format(results[dataset_name])
149 | if len(results) == 1:
150 | results = list(results.values())[0]
151 | return results
152 |
153 |
154 | def do_train(cfg, model, resume=False):
155 | model.train()
156 | if cfg.SOLVER.USE_CUSTOM_SOLVER:
157 | optimizer = build_custom_optimizer(cfg, model)
158 | else:
159 | assert cfg.SOLVER.OPTIMIZER == 'SGD'
160 | assert cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != 'full_model'
161 | assert cfg.SOLVER.BACKBONE_MULTIPLIER == 1.
162 | optimizer = build_optimizer(cfg, model)
163 | scheduler = build_lr_scheduler(cfg, optimizer)
164 | logger.info("Following parameters will be trained:")
165 | for n, p in model.named_parameters():
166 | if p.requires_grad:
167 | logger.info("{}".format(n))
168 |
169 | checkpointer = DetectionCheckpointer(
170 | model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
171 | )
172 |
173 | start_iter = checkpointer.resume_or_load(
174 | cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1
175 | if not resume:
176 | start_iter = 0
177 | max_iter = cfg.SOLVER.MAX_ITER if cfg.SOLVER.TRAIN_ITER < 0 else cfg.SOLVER.TRAIN_ITER
178 |
179 | periodic_checkpointer = PeriodicCheckpointer(
180 | checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
181 | )
182 |
183 | latest_checkpointer = LatestCheckpointer(
184 | checkpointer, 15000,)
185 |
186 | writers = (
187 | [
188 | CommonMetricPrinter(max_iter),
189 | JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")),
190 | TensorboardXWriter(cfg.OUTPUT_DIR),
191 | ]
192 | if comm.is_main_process()
193 | else []
194 | )
195 |
196 | use_custom_mapper = cfg.WITH_IMAGE_LABELS
197 | MapperClass = CustomDatasetMapper if use_custom_mapper else DatasetMapper
198 | mapper = MapperClass(cfg, True) if cfg.INPUT.CUSTOM_AUG == '' else \
199 | MapperClass(cfg, True, augmentations=build_custom_augmentation(cfg, True))
200 | if cfg.DATALOADER.SAMPLER_TRAIN in ['TrainingSampler', 'RepeatFactorTrainingSampler']:
201 | data_loader = build_detection_train_loader(cfg, mapper=mapper)
202 | else:
203 | data_loader = build_custom_train_loader(cfg, mapper=mapper)
204 |
205 | if cfg.FP16:
206 | scaler = GradScaler()
207 |
208 | logger.info("Starting training from iteration {}".format(start_iter))
209 | with EventStorage(start_iter) as storage:
210 | step_timer = Timer()
211 | data_timer = Timer()
212 | start_time = time.perf_counter()
213 | for data, iteration in zip(data_loader, range(start_iter, max_iter)):
214 | data_time = data_timer.seconds()
215 | storage.put_scalars(data_time=data_time)
216 | step_timer.reset()
217 | iteration = iteration + 1
218 | storage.step()
219 | loss_dict = model(data)
220 |
221 | losses = sum(
222 | loss for k, loss in loss_dict.items())
223 | assert torch.isfinite(losses).all(), loss_dict
224 |
225 | loss_dict_reduced = {
226 | k: v.item()
227 | for k, v in comm.reduce_dict(loss_dict).items()
228 | }
229 | losses_reduced = sum(loss for loss in loss_dict_reduced.values())
230 | if comm.is_main_process():
231 | storage.put_scalars(
232 | total_loss=losses_reduced, **loss_dict_reduced)
233 |
234 | optimizer.zero_grad()
235 | if cfg.FP16:
236 | scaler.scale(losses).backward()
237 | scaler.step(optimizer)
238 | scaler.update()
239 | else:
240 | losses.backward()
241 | optimizer.step()
242 |
243 | storage.put_scalar(
244 | "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
245 |
246 | step_time = step_timer.seconds()
247 | storage.put_scalars(time=step_time)
248 | data_timer.reset()
249 | scheduler.step()
250 |
251 | if (cfg.TEST.EVAL_PERIOD > 0
252 | and iteration % cfg.TEST.EVAL_PERIOD == 0
253 | and iteration != max_iter):
254 | do_test(cfg, model)
255 | comm.synchronize()
256 |
257 | if (iteration - start_iter > 5
258 | and (iteration % 20 == 0 or iteration == max_iter)):
259 | for writer in writers:
260 | writer.write()
261 | latest_checkpointer.step(iteration)
262 | periodic_checkpointer.step(iteration)
263 |
264 | total_time = time.perf_counter() - start_time
265 | logger.info(
266 | "Total training time: {}".format(
267 | str(datetime.timedelta(seconds=int(total_time)))))
268 |
269 |
270 | def setup(args):
271 | """
272 | Create configs and perform basic setups.
273 | """
274 | cfg = get_cfg()
275 | add_centernet_config(cfg)
276 | add_mmovod_config(cfg)
277 | cfg.merge_from_file(args.config_file)
278 | cfg.merge_from_list(args.opts)
279 | if '/auto' in cfg.OUTPUT_DIR:
280 | if "configs/" in args.config_file:
281 | new_sub_folder = args.config_file.replace("configs", "")[:-5]
282 | new_output_dir = cfg.OUTPUT_DIR.replace("/auto", new_sub_folder)
283 | cfg.OUTPUT_DIR = new_output_dir
284 | else:
285 | file_name = os.path.basename(args.config_file)[:-5]
286 | cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace('/auto', '/{}'.format(file_name))
287 | print(cfg.OUTPUT_DIR)
288 | cfg.freeze()
289 | default_setup(cfg, args)
290 | setup_logger(
291 | output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="mmovod")
292 | return cfg
293 |
294 |
295 | def main(args):
296 | cfg = setup(args)
297 |
298 | model = build_model(cfg)
299 | logger.info("Model:\n{}".format(model))
300 | if args.eval_only:
301 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
302 | cfg.MODEL.WEIGHTS, resume=args.resume
303 | )
304 |
305 | return do_test(cfg, model)
306 |
307 | distributed = comm.get_world_size() > 1
308 | if distributed:
309 | model = DistributedDataParallel(
310 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
311 | find_unused_parameters=cfg.FIND_UNUSED_PARAM
312 | )
313 |
314 | do_train(cfg, model, resume=args.resume)
315 | return do_test(cfg, model)
316 |
317 |
318 | if __name__ == "__main__":
319 |     parser = default_argument_parser()
320 |     args = parser.parse_args()
321 | if args.num_machines == 1:
322 | args.dist_url = 'tcp://127.0.0.1:{}'.format(
323 | torch.randint(11111, 60000, (1,))[0].item())
324 | else:
325 | if args.dist_url == 'host':
326 | args.dist_url = 'tcp://{}:12345'.format(
327 | os.environ['SLURM_JOB_NODELIST'])
328 | elif not args.dist_url.startswith('tcp'):
329 | tmp = os.popen(
330 | 'echo $(scontrol show job {} | grep BatchHost)'.format(
331 | args.dist_url)
332 | ).read()
333 | tmp = tmp[tmp.find('=') + 1: -1]
334 | args.dist_url = 'tcp://{}:12345'.format(tmp)
335 | print("Command Line Args:", args)
336 | launch(
337 | main,
338 | args.num_gpus,
339 | num_machines=args.num_machines,
340 | machine_rank=args.machine_rank,
341 | dist_url=args.dist_url,
342 | args=(args,),
343 | )
344 |
--------------------------------------------------------------------------------
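
Note: when `OUTPUT_DIR` contains '/auto', `setup()` substitutes the config file name into the output directory so each config gets its own folder. A small illustration of that substitution; the base directory here is hypothetical, the real value comes from the yaml config:

config_file = 'configs/lvis-base_r50_4x_clip_gpt3_descriptions.yaml'
output_dir = 'output/mm-ovod/auto'                        # hypothetical OUTPUT_DIR from the config
new_sub_folder = config_file.replace('configs', '')[:-5]  # '/lvis-base_r50_4x_clip_gpt3_descriptions'
print(output_dir.replace('/auto', new_sub_folder))
# -> 'output/mm-ovod/lvis-base_r50_4x_clip_gpt3_descriptions'
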