├── .gitmodules ├── README.md ├── configs ├── base_r50_4x_clip.yaml ├── lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.yaml ├── lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.yaml ├── lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.yaml ├── lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.yaml ├── lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.yaml ├── lvis-base_r50_4x_clip_gpt3_descriptions.yaml ├── lvis-base_r50_4x_clip_image_exemplars_agg.yaml ├── lvis-base_r50_4x_clip_image_exemplars_avg.yaml ├── lvis-base_r50_4x_clip_multi_modal_agg.yaml └── lvis-base_r50_4x_clip_multi_modal_avg.yaml ├── datasets ├── README.md └── metadata │ ├── lvis_gpt3_text-davinci-002_descriptions_author.json │ ├── lvis_gpt3_text-davinci-002_features_author.npy │ ├── lvis_image_exemplar_dict_K-005_author.json │ ├── lvis_image_exemplar_features_agg_K-005_author.npy │ ├── lvis_image_exemplar_features_avg_K-005_author.npy │ ├── lvis_multi-modal_agg_K-005_author.npy │ ├── lvis_multi-modal_avg_K-005_author.npy │ ├── lvis_v1_clip_a+cname.npy │ └── lvis_v1_train_cat_info.json ├── docs ├── INSTALL.md ├── MODEL_ZOO.md └── teaser.jpg ├── mmovod ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── config.cpython-38.pyc │ └── custom_solver.cpython-38.pyc ├── config.py ├── custom_solver.py ├── data │ ├── __pycache__ │ │ ├── custom_build_augmentation.cpython-38.pyc │ │ ├── custom_dataset_dataloader.cpython-38.pyc │ │ ├── custom_dataset_mapper.cpython-38.pyc │ │ └── tar_dataset.cpython-38.pyc │ ├── custom_build_augmentation.py │ ├── custom_dataset_dataloader.py │ ├── custom_dataset_mapper.py │ ├── datasets │ │ ├── __pycache__ │ │ │ ├── cc.cpython-38.pyc │ │ │ ├── coco_zeroshot.cpython-38.pyc │ │ │ ├── imagenet.cpython-38.pyc │ │ │ ├── lvis_22k_categories.cpython-38.pyc │ │ │ ├── lvis_v1.cpython-38.pyc │ │ │ ├── objects365.cpython-38.pyc │ │ │ ├── oid.cpython-38.pyc │ │ │ └── register_oid.cpython-38.pyc │ │ ├── imagenet.py │ │ └── lvis_v1.py │ └── transforms │ │ ├── __pycache__ │ │ ├── custom_augmentation_impl.cpython-38.pyc │ │ └── custom_transform.cpython-38.pyc │ │ ├── custom_augmentation_impl.py │ │ └── custom_transform.py ├── modeling │ ├── __pycache__ │ │ ├── debug.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── backbone │ │ ├── __pycache__ │ │ │ ├── swintransformer.cpython-38.pyc │ │ │ └── timm.cpython-38.pyc │ │ ├── swintransformer.py │ │ └── timm.py │ ├── debug.py │ ├── meta_arch │ │ ├── __pycache__ │ │ │ ├── custom_rcnn.cpython-38.pyc │ │ │ └── d2_deformable_detr.cpython-38.pyc │ │ └── custom_rcnn.py │ ├── roi_heads │ │ ├── __pycache__ │ │ │ ├── detic_fast_rcnn.cpython-38.pyc │ │ │ ├── detic_roi_heads.cpython-38.pyc │ │ │ ├── res5_roi_heads.cpython-38.pyc │ │ │ └── zero_shot_classifier.cpython-38.pyc │ │ ├── detic_fast_rcnn.py │ │ ├── detic_roi_heads.py │ │ ├── res5_roi_heads.py │ │ └── zero_shot_classifier.py │ ├── text │ │ ├── __pycache__ │ │ │ └── text_encoder.cpython-38.pyc │ │ └── text_encoder.py │ └── utils.py └── predictor.py ├── requirements.txt ├── tools ├── collate_exemplar_dict.py ├── convert-thirdparty-pretrained-model-to-d2.py ├── create_imagenetlvis_json.py ├── dump_clip_features.py ├── dump_clip_features_lvis_sentences.py ├── generate_descriptions.py ├── generate_vision_classifier_agg.py ├── get_exemplars_tta.py ├── get_lvis_cat_info.py ├── norm_feat_sum_norm.py ├── remove_lvis_rare.py ├── sample_exemplars.py └── unzip_imagenet_lvis.py └── train_net_auto.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule 
"third_party/CenterNet2"] 2 | path = third_party/CenterNet2 3 | url = https://github.com/xingyizhou/CenterNet2.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-Modal Classifiers for Open-Vocabulary Object Detection 2 | 3 |

4 | 5 | > [**Multi-Modal Classifiers for Open Vocabulary Object Detection**](https://arxiv.org/abs/2306.05493), 6 | > Prannay Kaul, Weidi Xie, Andrew Zisserman 7 | > *ICML 2023 ([arXiv 2201.02605](https://arxiv.org/abs/2306.05493))* 8 | 9 | 10 | ## Updates 11 | 12 | - **June 2023** Code and checkpoints for LVIS models in the main paper are released. Training code for visual aggregator to follow soon. 13 | 14 | 15 | ## Installation 16 | 17 | See [installation instructions](docs/INSTALL.md). 18 | 19 | 20 | ## Benchmark evaluation and training 21 | 22 | Please first [prepare datasets](datasets/README.md), then check our [MODEL ZOO](docs/MODEL_ZOO.md) to reproduce results in our paper. 23 | 24 | 25 | ## License 26 | 27 | See [Detic](https://github.com/facebookresearch/Detic). Our code is based on this repository. 28 | 29 | ## Citation 30 | 31 | If you find this project useful for your research, please use the following BibTeX entry. 32 | 33 | @inproceedings{Kaul2023, 34 | title={Multi-Modal Classifiers for Open-Vocabulary Object Detection}, 35 | author={Kaul, Prannay and Xie, Weidi and Zisserman, Andrew}, 36 | booktitle={ICML}, 37 | year={2023} 38 | } 39 | -------------------------------------------------------------------------------- /configs/base_r50_4x_clip.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "CustomRCNN" 3 | MASK_ON: True 4 | PROPOSAL_GENERATOR: 5 | NAME: "CenterNet" 6 | WEIGHTS: "checkpoints/resnet50_miil_21k.pkl" 7 | BACKBONE: 8 | NAME: build_p67_timm_fpn_backbone 9 | TIMM: 10 | BASE_NAME: resnet50_in21k 11 | FPN: 12 | IN_FEATURES: ["layer3", "layer4", "layer5"] 13 | PIXEL_MEAN: [123.675, 116.280, 103.530] 14 | PIXEL_STD: [58.395, 57.12, 57.375] 15 | ROI_HEADS: 16 | NAME: DeticCascadeROIHeads 17 | IN_FEATURES: ["p3", "p4", "p5"] 18 | IOU_THRESHOLDS: [0.6] 19 | NUM_CLASSES: 1203 20 | SCORE_THRESH_TEST: 0.02 21 | NMS_THRESH_TEST: 0.5 22 | ROI_BOX_CASCADE_HEAD: 23 | IOUS: [0.6, 0.7, 0.8] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | CLS_AGNOSTIC_BBOX_REG: True 29 | MULT_PROPOSAL_SCORE: True 30 | USE_SIGMOID_CE: True 31 | USE_FED_LOSS: True 32 | ROI_MASK_HEAD: 33 | NAME: "MaskRCNNConvUpsampleHead" 34 | NUM_CONV: 4 35 | POOLER_RESOLUTION: 14 36 | CLS_AGNOSTIC_MASK: True 37 | CENTERNET: 38 | NUM_CLASSES: 1203 39 | REG_WEIGHT: 1. 
40 | NOT_NORM_REG: True 41 | ONLY_PROPOSAL: True 42 | WITH_AGN_HM: True 43 | INFERENCE_TH: 0.0001 44 | PRE_NMS_TOPK_TRAIN: 4000 45 | POST_NMS_TOPK_TRAIN: 2000 46 | PRE_NMS_TOPK_TEST: 1000 47 | POST_NMS_TOPK_TEST: 256 48 | NMS_TH_TRAIN: 0.9 49 | NMS_TH_TEST: 0.9 50 | POS_WEIGHT: 0.5 51 | NEG_WEIGHT: 0.5 52 | IGNORE_HIGH_FP: 0.85 53 | DATASETS: 54 | TRAIN: ("lvis_v1_train",) 55 | TEST: ("lvis_v1_val",) 56 | DATALOADER: 57 | SAMPLER_TRAIN: "RepeatFactorTrainingSampler" 58 | REPEAT_THRESHOLD: 0.001 59 | NUM_WORKERS: 8 60 | TEST: 61 | DETECTIONS_PER_IMAGE: 300 62 | SOLVER: 63 | LR_SCHEDULER_NAME: "WarmupCosineLR" 64 | CHECKPOINT_PERIOD: 30000 65 | WARMUP_ITERS: 10000 66 | WARMUP_FACTOR: 0.0001 67 | USE_CUSTOM_SOLVER: True 68 | OPTIMIZER: "ADAMW" 69 | MAX_ITER: 90000 70 | IMS_PER_BATCH: 64 71 | BASE_LR: 0.0002 72 | CLIP_GRADIENTS: 73 | ENABLED: True 74 | INPUT: 75 | FORMAT: RGB 76 | CUSTOM_AUG: EfficientDetResizeCrop 77 | TRAIN_SIZE: 640 78 | OUTPUT_DIR: "./output/mm-ovod/auto" 79 | EVAL_PROPOSAL_AR: False 80 | VERSION: 2 81 | FP16: True 82 | -------------------------------------------------------------------------------- /configs/lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | IMAGE_LABEL_LOSS: 'max_size' 6 | USE_BIAS: -2.0 7 | ZEROSHOT_WEIGHT_PATH: "datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy" 8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_gpt3_descriptions/model_final.pth" 9 | SOLVER: 10 | MAX_ITER: 90000 11 | IMS_PER_BATCH: 64 12 | BASE_LR: 0.0002 13 | WARMUP_ITERS: 1000 14 | WARMUP_FACTOR: 0.001 15 | DATASETS: 16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") 17 | DATALOADER: 18 | SAMPLER_TRAIN: "MultiDatasetSampler" 19 | DATASET_RATIO: [1, 4] 20 | USE_DIFF_BS_SIZE: True 21 | DATASET_BS: [16, 64] 22 | DATASET_INPUT_SIZE: [640, 320] 23 | USE_RFS: [True, False] 24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] 25 | FILTER_EMPTY_ANNOTATIONS: False 26 | MULTI_DATASET_GROUPING: True 27 | DATASET_ANN: ['box', 'image'] 28 | NUM_WORKERS: 8 29 | WITH_IMAGE_LABELS: True 30 | -------------------------------------------------------------------------------- /configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | IMAGE_LABEL_LOSS: 'max_size' 6 | USE_BIAS: -2.0 7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy' 8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_agg/model_final.pth" 9 | SOLVER: 10 | MAX_ITER: 90000 11 | IMS_PER_BATCH: 64 12 | BASE_LR: 0.0002 13 | WARMUP_ITERS: 1000 14 | WARMUP_FACTOR: 0.001 15 | DATASETS: 16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") 17 | DATALOADER: 18 | SAMPLER_TRAIN: "MultiDatasetSampler" 19 | DATASET_RATIO: [1, 4] 20 | USE_DIFF_BS_SIZE: True 21 | DATASET_BS: [16, 64] 22 | DATASET_INPUT_SIZE: [640, 320] 23 | USE_RFS: [True, False] 24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] 25 | FILTER_EMPTY_ANNOTATIONS: False 26 | MULTI_DATASET_GROUPING: True 27 | DATASET_ANN: ['box', 'image'] 28 | NUM_WORKERS: 8 29 | WITH_IMAGE_LABELS: True 30 | -------------------------------------------------------------------------------- /configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | IMAGE_LABEL_LOSS: 'max_size' 6 | USE_BIAS: -2.0 7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy' 8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_avg/model_final.pth" 9 | SOLVER: 10 | MAX_ITER: 90000 11 | IMS_PER_BATCH: 64 12 | BASE_LR: 0.0002 13 | WARMUP_ITERS: 1000 14 | WARMUP_FACTOR: 0.001 15 | DATASETS: 16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") 17 | DATALOADER: 18 | SAMPLER_TRAIN: "MultiDatasetSampler" 19 | DATASET_RATIO: [1, 4] 20 | USE_DIFF_BS_SIZE: True 21 | DATASET_BS: [16, 64] 22 | DATASET_INPUT_SIZE: [640, 320] 23 | USE_RFS: [True, False] 24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] 25 | FILTER_EMPTY_ANNOTATIONS: False 26 | MULTI_DATASET_GROUPING: True 27 | DATASET_ANN: ['box', 'image'] 28 | NUM_WORKERS: 8 29 | WITH_IMAGE_LABELS: True 30 | -------------------------------------------------------------------------------- /configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | IMAGE_LABEL_LOSS: 'max_size' 6 | USE_BIAS: -2.0 7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_agg_K-005_author.npy' 8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_multi_modal_agg/model_final.pth" 9 | SOLVER: 10 | MAX_ITER: 90000 11 | IMS_PER_BATCH: 64 12 | BASE_LR: 0.0002 13 | WARMUP_ITERS: 1000 14 | WARMUP_FACTOR: 0.001 15 | DATASETS: 16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") 17 | DATALOADER: 18 | SAMPLER_TRAIN: "MultiDatasetSampler" 19 | DATASET_RATIO: [1, 4] 20 | USE_DIFF_BS_SIZE: True 21 | DATASET_BS: [16, 64] 22 | DATASET_INPUT_SIZE: [640, 320] 23 | USE_RFS: [True, False] 24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] 25 | FILTER_EMPTY_ANNOTATIONS: False 26 | MULTI_DATASET_GROUPING: True 27 | DATASET_ANN: ['box', 'image'] 28 | NUM_WORKERS: 8 29 | WITH_IMAGE_LABELS: True 30 | -------------------------------------------------------------------------------- /configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | IMAGE_LABEL_LOSS: 'max_size' 6 | USE_BIAS: -2.0 7 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_avg_K-005_author.npy' 8 | WEIGHTS: "output/mm-ovod/lvis-base_r50_4x_clip_multi_modal_avg/model_final.pth" 9 | SOLVER: 10 | MAX_ITER: 90000 11 | IMS_PER_BATCH: 64 12 | BASE_LR: 0.0002 13 | WARMUP_ITERS: 1000 14 | WARMUP_FACTOR: 0.001 15 | DATASETS: 16 | TRAIN: ("lvis_v1_train_norare","imagenet_lvis_v1") 17 | DATALOADER: 18 | SAMPLER_TRAIN: "MultiDatasetSampler" 19 | DATASET_RATIO: [1, 4] 20 | USE_DIFF_BS_SIZE: True 21 | DATASET_BS: [16, 64] 22 | DATASET_INPUT_SIZE: [640, 320] 23 | USE_RFS: [True, False] 24 | DATASET_INPUT_SCALE: [[0.1, 2.0], [0.5, 1.5]] 25 | FILTER_EMPTY_ANNOTATIONS: False 26 | MULTI_DATASET_GROUPING: True 27 | DATASET_ANN: ['box', 'image'] 28 | NUM_WORKERS: 8 29 | WITH_IMAGE_LABELS: True 30 | -------------------------------------------------------------------------------- /configs/lvis-base_r50_4x_clip_gpt3_descriptions.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 
"base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | ZEROSHOT_WEIGHT_PATH: "datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy" 6 | USE_BIAS: -2.0 7 | DATASETS: 8 | TRAIN: ("lvis_v1_train_norare",) 9 | -------------------------------------------------------------------------------- /configs/lvis-base_r50_4x_clip_image_exemplars_agg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy' 6 | USE_BIAS: -2.0 7 | DATASETS: 8 | TRAIN: ("lvis_v1_train_norare",) 9 | -------------------------------------------------------------------------------- /configs/lvis-base_r50_4x_clip_image_exemplars_avg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy' 6 | USE_BIAS: -2.0 7 | DATASETS: 8 | TRAIN: ("lvis_v1_train_norare",) 9 | -------------------------------------------------------------------------------- /configs/lvis-base_r50_4x_clip_multi_modal_agg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_agg_K-005_author.npy' 6 | USE_BIAS: -2.0 7 | DATASETS: 8 | TRAIN: ("lvis_v1_train_norare",) 9 | -------------------------------------------------------------------------------- /configs/lvis-base_r50_4x_clip_multi_modal_avg.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "base_r50_4x_clip.yaml" 2 | MODEL: 3 | ROI_BOX_HEAD: 4 | USE_ZEROSHOT_CLS: True 5 | ZEROSHOT_WEIGHT_PATH: 'datasets/metadata/lvis_multi-modal_avg_K-005_author.npy' 6 | USE_BIAS: -2.0 7 | DATASETS: 8 | TRAIN: ("lvis_v1_train_norare",) 9 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Prepare datasets for MM-OVOD (borrows and edits from Detic) 2 | 3 | The basic training of our model uses [LVIS](https://www.lvisdataset.org/) (which uses [COCO](https://cocodataset.org/) images) and [ImageNet-21K](https://www.image-net.org/download.php). 4 | 5 | Before starting processing, please download the (selected) datasets from the official websites and place or sim-link them under `${mm-ovod_ROOT}/datasets/` with details shown below. 6 | 7 | ``` 8 | ${mm-ovod_ROOT}/datasets/ 9 | metadata/ 10 | lvis/ 11 | coco/ 12 | imagenet/ 13 | VisualGenome/ 14 | ``` 15 | `metadata/` is our preprocessed meta-data (included in the repo). See the below [section](#Metadata) for details. 16 | Please follow the following instruction to pre-process individual datasets. 
17 | 18 | ### COCO and LVIS 19 | 20 | First, download COCO images and LVIS data place them in the following way: 21 | 22 | ``` 23 | lvis/ 24 | lvis_v1_train.json 25 | lvis_v1_val.json 26 | coco/ 27 | train2017/ 28 | val2017/ 29 | ``` 30 | 31 | Next, prepare the open-vocabulary LVIS training set using 32 | 33 | ``` 34 | python tools/remove_lvis_rare.py --ann datasets/lvis/lvis_v1_train.json 35 | ``` 36 | 37 | This will generate `datasets/lvis/lvis_v1_train_norare.json`. 38 | 39 | ### ImageNet-21K 40 | 41 | The imagenet folder should look like the below, after following the data-processing 42 | [script](https://github.com/Alibaba-MIIL/ImageNet21K/blob/main/dataset_preprocessing/processing_script.sh) from ImageNet-21K Pretraining for the Masses ensuring to use the FALL 2011 version. 43 | After this script has run, please rename folders to give the below structure: 44 | ``` 45 | imagenet/ 46 | imagenet21k_P/ 47 | train/ 48 | n00005787/ 49 | n00005787_*.JPEG 50 | n00006484/ 51 | n00006484_*.JPEG 52 | ... 53 | val/ 54 | n00005787/ 55 | n00005787_*.JPEG 56 | n00006484/ 57 | n00006484_*.JPEG 58 | ... 59 | imagenet21k_small_classes/ 60 | n00004475/ 61 | n00004475_*.JPEG 62 | n00006024/ 63 | n00006024_*.JPEG 64 | ... 65 | ``` 66 | 67 | The subset of ImageNet that overlaps with LVIS (IN-L in the paper) will be created from this directory 68 | structure. 69 | 70 | ~~~ 71 | cd ${mm-ovod_ROOT}/datasets/ 72 | mkdir imagenet/annotations 73 | python tools/create_imagenetlvis_json.py --imagenet-path datasets/imagenet/imagenet21k_P --out-path datasets/imagenet/annotations/imagenet_lvis_image_info.json 74 | ~~~ 75 | This creates `datasets/imagenet/annotations/imagenet_lvis_image_info.json`. 76 | 77 | 78 | ### VisualGenome 79 | 80 | Some of our image exemplars are sourced from VisualGenome and so download the dataset ensuring the following 81 | files are present with the below structure: 82 | ``` 83 | VisualGenome/ 84 | VG_100K/ 85 | *.jpg 86 | VG_100K_2/ 87 | *.jpg 88 | objects.json 89 | image_data.json 90 | ``` 91 | 92 | ### Metadata 93 | 94 | ``` 95 | metadata/ 96 | lvis_v1_train_cat_info.json 97 | lvis_gpt3_text-davinci-002_descriptions_author.json 98 | lvis_gpt3_text-davinci-002_features_author.npy 99 | lvis_image_exemplar_dict_K-005_author.json 100 | lvis_image_exemplar_features_agg_K-005_author.npy 101 | lvis_image_exemplar_features_avg_K-005_author.npy 102 | lvis_multi-modal_agg_K-005_author.npy 103 | lvis_multi-modal_avg_K-005_author.npy 104 | lvis_v1_clip_a+cname.npy 105 | ``` 106 | 107 | `lvis_v1_train_cat_info.json` is used by the Federated loss. 108 | This is created by 109 | ~~~ 110 | python tools/get_lvis_cat_info.py --ann datasets/lvis/lvis_v1_train.json 111 | ~~~ 112 | 113 | `lvis_gpt3_text-davinci-002_descriptions_author.json` are the descriptions for each LVIS class 114 | we found using the (now deprecated) text-davinci-002 model from OpenAI. 115 | 116 | Users may create their own descriptions by: 117 | ~~~ 118 | python tools/generate_descriptions.py --ann-path datasets/lvis/lvis_v1_val.json --openai-model text-davinci-003 119 | ~~~ 120 | which will create a file called `lvis_gpt3_text-davinci-003_descriptions_own.json`. 121 | Be sure to include your own OpenAI API key at the top of `tools/generate_descriptions.py`. 122 | 123 | `lvis_gpt3_text-davinci-002_features_author.npy` is the CLIP embeddings for each class in the LVIS 124 | dataset using the descriptions we generate using GPT-3. 
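As a quick sanity check, the classifier file can be loaded with NumPy; it should contain one CLIP embedding of width `ZEROSHOT_WEIGHT_DIM` (512) per LVIS category (a minimal sketch, the exact array orientation is left unchecked):

```python
import numpy as np

feats = np.load("datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy")
print(feats.shape, feats.dtype)  # expect a 512-wide embedding per LVIS category
```

The same file can be regenerated from the descriptions above with: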
125 | ~~~ 126 | python tools/dump_clip_features_lvis_sentences.py --descriptions-path datasets/metadata/lvis_gpt3_text-davinci-002_descriptions_author.json --ann-path datasets/lvis/lvis_v1_val.json --out-path datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy 127 | ~~~ 128 | 129 | `lvis_image_exemplar_dict_K-005_author.json` is the dictionary of image exemplars for each LVIS class 130 | used in the paper and produce our results. 131 | One can create their own as follows: 132 | ~~~ 133 | python tools/sample_exemplars.py --lvis-ann-path datasets/lvis/lvis_v1_val.json --exemplar-dict-path datasets/metadata/exemplar_dict.json -K 5 --out-path datasets/metadata/lvis_image_exemplar_dict_K-005_own.json 134 | ~~~ 135 | 136 | `lvis_image_exemplar_features_agg_K-005_author.npy` is the CLIP embeddings for each class in the LVIS dataset 137 | when using image examplars AND our trained visual aggregator for combining multiple exemplars. 138 | One can create their own using our trained visual aggregator as follows (see [INSTALL.md](../../docs/INSTALL.md) for downloading 139 | visual aggregator weights): 140 | ~~~ 141 | python tools/generate_vision_classifier_agg.py --exemplar-list datasets/metadata/lvis_image_exemplar_dict_K-005_own.json --out-path datasets/metadata/lvis_image_exemplar_features_agg_K-005_own.npy --num-augs 5 --tta --load-path checkpoints/visual_aggregator/visual_aggregator_ckpt_4_transformer.pth 142 | ~~~ 143 | 144 | `lvis_image_exemplar_features_avg_K-005_author.npy` is the CLIP embeddings for each class in the LVIS dataset 145 | when using image examplars AND averaging the CLIP embeddings of multiple exemplars (not using our trained 146 | aggregator). 147 | One can create their own for example: 148 | ~~~ 149 | python tools/get_exemplars_tta.py --ann-path /users/prannay/mm-ovod/datasets/metadata/lvis_image_exemplar_dict_K-005_own.json --output-path datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy --num-augs 5 150 | ~~~ 151 | 152 | `lvis_multi-modal_agg_K-005_author.npy` is the CLIP embeddings for each class in the LVIS dataset 153 | when using image examplars AND descriptions AND our trained visual aggregator for combining multiple exemplars. 154 | One can create their own for example: 155 | ~~~ 156 | python tools/norm_feat_sum_norm.py --feat1-path datasets/metadata/lvis_gpt3_text-davinci-002_features_own.npy --feat2-path datasets/metadata/lvis_image_exemplar_features_agg_K-005_own.npy --out-path datasets/metadata/lvis_multi-modal_agg_K-005_own.npy 157 | ~~~ 158 | 159 | `lvis_multi-modal_avg_K-005_author.npy` is the CLIP embeddings for each class in the LVIS dataset 160 | when using image examplars AND descriptions AND averaging the CLIP embeddings of multiple exemplars (not using our trained aggregator). 
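Judging by its name, `tools/norm_feat_sum_norm.py` fuses the two classifiers by L2-normalising each one, summing them and re-normalising the result. A rough NumPy sketch of that operation (an assumption based on the script name, not a copy of its implementation):

```python
import numpy as np

def fuse_classifiers(text_feats, image_feats):
    # L2-normalise each classifier, sum, then re-normalise the fused classifier.
    t = text_feats / np.linalg.norm(text_feats, axis=-1, keepdims=True)
    v = image_feats / np.linalg.norm(image_feats, axis=-1, keepdims=True)
    fused = t + v
    return fused / np.linalg.norm(fused, axis=-1, keepdims=True)
```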
161 | One can create their own for example: 162 | ~~~ 163 | python tools/norm_feat_sum_norm.py --feat1-path datasets/metadata/lvis_gpt3_text-davinci-002_features_own.npy --feat2-path datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy --out-path datasets/metadata/lvis_multi-modal_avg_K-005_own.npy 164 | ~~~ 165 | 166 | `lvis_clip_a+cname.npy` is the pre-computed CLIP embeddings for each class in the LVIS dataset (from Detic) 167 | They are created by: 168 | ~~~ 169 | python tools/dump_clip_features.py --ann datasets/lvis/lvis_v1_val.json --out_path metadata/lvis_v1_clip_a+cname.npy 170 | ~~~ 171 | 172 | ### Collating Image Exemplars 173 | 174 | We provide the exact image exemplars used (5 per LVIS category) in our results in the metadata folder defined 175 | above. 176 | However, if you wish to create your own, one first needs to create a full dictionary of exemplars for each 177 | LVIS category. 178 | This is done by: 179 | ~~~ 180 | python tools/collate_exemplar_dict.py --lvis-dir datasets/lvis --imagenet-dir datasets/imagenet --visual-genome-dir datasets/VisualGenome --output-path datasets/metadata/exemplar_dict.json 181 | ~~~ 182 | This will create `datasets/metadata/exemplar_dict.json` which is a dictionary of exemplars with 183 | at least 10 exemplars per LVIS category. -------------------------------------------------------------------------------- /datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy -------------------------------------------------------------------------------- /datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_image_exemplar_features_agg_K-005_author.npy -------------------------------------------------------------------------------- /datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_image_exemplar_features_avg_K-005_author.npy -------------------------------------------------------------------------------- /datasets/metadata/lvis_multi-modal_agg_K-005_author.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_multi-modal_agg_K-005_author.npy -------------------------------------------------------------------------------- /datasets/metadata/lvis_multi-modal_avg_K-005_author.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_multi-modal_avg_K-005_author.npy -------------------------------------------------------------------------------- /datasets/metadata/lvis_v1_clip_a+cname.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/datasets/metadata/lvis_v1_clip_a+cname.npy -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ### Requirements 4 | - Linux or macOS with Python ≥ 3.6 5 | - PyTorch ≥ 1.8. 6 | Install them together at [pytorch.org](https://pytorch.org) to make sure of this. Note, please check 7 | PyTorch version matches that is required by Detectron2. 8 | - Detectron2: follow [Detectron2 installation instructions](https://detectron2.readthedocs.io/tutorials/install.html). 9 | 10 | 11 | ### Author conda environment setup 12 | ```bash 13 | conda create --name mm-ovod python=3.8 -y 14 | conda activate mm-ovod 15 | conda install pytorch torchvision=0.9.2 torchaudio cudatoolkit=11.1 -c pytorch-lts -c nvidia 16 | 17 | # under your working directory 18 | git clone git@github.com:facebookresearch/detectron2.git 19 | cd detectron2 20 | git checkout 2b98c273b240b54d2d0ee6853dc331c4f2ca87b9 21 | pip install -e . 22 | 23 | cd .. 24 | git clone https://github.com/prannaykaul/mm-ovod.git --recurse-submodules 25 | cd mm-ovod 26 | pip install -r requirements.txt 27 | pip uninstall pillow 28 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 29 | ``` 30 | 31 | Our project (like Detic) use a submodule: [CenterNet2](https://github.com/xingyizhou/CenterNet2.git). If you forget to add `--recurse-submodules`, do `git submodule init` and then `git submodule update`. 32 | 33 | 34 | ### Downloading pre-trained ResNet-50 backbone 35 | We use the ResNet-50 backbone pre-trained on ImageNet-21k-P from [here]( 36 | https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth). Please download it from the previous link, place it in the `${mm-ovod_ROOT}/checkpoints` folder and use the following command to convert it for use with detectron2: 37 | ```bash 38 | cd ${mm-ovod_ROOT} 39 | mkdir checkpoints 40 | cd checkpoints 41 | wget https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth 42 | python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path resnet50_miil_21k.pth 43 | ``` 44 | 45 | ### Downloading pre-trained visual aggregator 46 | The pretrained model for the visual aggregator is required if one wants to use their own image exemplars to produce a vison-based 47 | classifier. The model can be downloaded from [here](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/visual_aggregator_ckpt_4_transformer.pth.tar) and should be placed in the `${mm-ovod_ROOT}/checkpoints` folder. 48 | ```bash 49 | cd ${mm-ovod_ROOT} 50 | mkdir checkpoints 51 | cd checkpoints 52 | wget https://robots.ox.ac.uk/~prannay/public_models/mm-ovod/visual_aggregator_ckpt_4_transformer.pth.tar 53 | tar -xf visual_aggregator_ckpt_4_transformer.pth.tar 54 | rm visual_aggregator_ckpt_4_transformer.pth.tar 55 | ``` 56 | -------------------------------------------------------------------------------- /docs/MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Multi-Modal Open-Vocabulary Object Detection Model Zoo 2 | 3 | ## Introduction 4 | 5 | This file documents a collection of models reported in our paper. 6 | Training in all cases is done with 4 32GB V100 GPUs. 
7 | 8 | #### How to Read the Tables 9 | 10 | The "Name" column contains a link to the config file. 11 | To train a model, run 12 | 13 | ``` 14 | python train_net_auto.py --num-gpus 4 --config-file /path/to/config/name.yaml 15 | ``` 16 | 17 | To evaluate a model with a trained/ pretrained model, run 18 | 19 | ``` 20 | python train_net_auto.py --num-gpus 4 --config-file /path/to/config/name.yaml --eval-only MODEL.WEIGHTS /path/to/weight.pth 21 | ``` 22 | 23 | 24 | ## Open-vocabulary LVIS 25 | 26 | | Name | APr | mAP | Weights | 27 | |------------------------------------------------------------------------------------------------------------------------|:----:|:----:|------------------------------------------------------------------| 28 | | [lvis-base_r50_4x_clip_gpt3_descriptions](../configs/lvis-base_r50_4x_clip_gpt3_descriptions.yaml) | 19.3 | 30.3 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_gpt3_descriptions.pth.tar) | 29 | | [lvis-base_r50_4x_clip_image_exemplars_avg](../configs/lvis-base_r50_4x_clip_image_exemplars_avg.yaml) | 14.8 | 28.8 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_avg.pth.tar) | 30 | | [lvis-base_r50_4x_clip_image_exemplars_agg](../configs/lvis-base_r50_4x_clip_image_exemplars_agg.yaml) | 18.3 | 29.2 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_image_exemplars_agg.pth.tar) | 31 | | [lvis-base_r50_4x_clip_multi_modal_avg](../configs/lvis-base_r50_4x_clip_multi_modal_avg.yaml) | 20.7 | 30.5 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_multi_modal_avg.pth.tar) | 32 | | [lvis-base_r50_4x_clip_multi_modal_agg](../configs/lvis-base_r50_4x_clip_multi_modal_agg.yaml) | 19.2 | 30.6 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_r50_4x_clip_multi_modal_agg.pth.tar) | 33 | | [lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions](../configs/lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.yaml) | 25.8 | 32.6 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_gpt3_descriptions.pth.tar) | 34 | | [lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg](../configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.yaml) | 21.6 | 31.3 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_avg.pth.tar) | 35 | | [lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg](../configs/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.yaml) | 23.8 | 31.3 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_image_exemplars_agg.pth.tar) | 36 | | [lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg](../configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.yaml) | 26.5 | 32.8 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_multi_modal_avg.pth.tar) | 37 | | [lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg](../configs/lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.yaml) | 27.3 | 33.1 | [model](https://www.robots.ox.ac.uk/~prannay/public_models/mm-ovod/lvis-base_in-l_r50_4x_4x_clip_multi_modal_agg.pth.tar) | 38 | 39 | 40 | #### Note 41 | 42 | - The open-vocabulary LVIS setup is LVIS without rare class annotations in training. We evaluate rare classes as novel classes in testing. 43 | 44 | - All models use [CLIP](https://github.com/openai/CLIP) embeddings as classifiers. 
This makes the box-supervised models have non-zero mAP on novel classes. 45 | 46 | - The models with `in-l` use the overlap classes between ImageNet-21K and LVIS as image-labeled data. 47 | 48 | - The models which are trained on `in-l` require the corresponding models _without_ `in-l` (indicated by MODEL.WEIGHTS in the config files). Please train or download the model without `in-l` and place them under `${mm-ovod_ROOT}/output/..` before training the model using `in-l` (check the config file). 49 | 50 | -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/docs/teaser.jpg -------------------------------------------------------------------------------- /mmovod/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | from .modeling.meta_arch import custom_rcnn 3 | from .modeling.roi_heads import detic_roi_heads 4 | from .modeling.roi_heads import res5_roi_heads 5 | from .modeling.backbone import swintransformer 6 | from .modeling.backbone import timm 7 | 8 | 9 | from .data.datasets import lvis_v1 10 | from .data.datasets import imagenet 11 | # from .data.datasets import objects365 12 | -------------------------------------------------------------------------------- /mmovod/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/__pycache__/config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/__pycache__/config.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/__pycache__/custom_solver.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/__pycache__/custom_solver.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
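# This module defines add_mmovod_config(cfg), which extends the default
# detectron2 config with the MM-OVOD-specific keys used throughout this repo:
# open-vocabulary (zero-shot) classifier options for the box head, federated
# and image-label loss settings, Swin/timm backbone options, multi-dataset
# dataloader keys, and the custom solver/input/debug flags.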
2 | from detectron2.config import CfgNode as CN 3 | 4 | 5 | def add_mmovod_config(cfg): 6 | _C = cfg 7 | 8 | _C.WITH_IMAGE_LABELS = False # Turn on co-training with classification data 9 | 10 | # Open-vocabulary classifier 11 | _C.MODEL.ROI_BOX_HEAD.USE_ZEROSHOT_CLS = False # Use fixed classifier for open-vocabulary detection 12 | _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = 'datasets/metadata/lvis_v1_clip_a+cname.npy' 13 | _C.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM = 512 14 | _C.MODEL.ROI_BOX_HEAD.NORM_WEIGHT = True 15 | _C.MODEL.ROI_BOX_HEAD.NORM_TEMP = 50.0 16 | _C.MODEL.ROI_BOX_HEAD.IGNORE_ZERO_CATS = False 17 | _C.MODEL.ROI_BOX_HEAD.USE_BIAS = 0.0 # >= 0: not use 18 | 19 | _C.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False # CenterNet2 20 | _C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False 21 | _C.MODEL.ROI_BOX_HEAD.PRIOR_PROB = 0.01 22 | _C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False # Federated Loss 23 | _C.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = \ 24 | 'datasets/metadata/lvis_v1_train_cat_info.json' 25 | _C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CAT = 50 26 | _C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT = 0.5 27 | 28 | # Classification data configs 29 | _C.MODEL.ROI_BOX_HEAD.IMAGE_LABEL_LOSS = 'max_size' # max, softmax, sum 30 | _C.MODEL.ROI_BOX_HEAD.IMAGE_LOSS_WEIGHT = 0.1 31 | _C.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE = 1.0 32 | _C.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX = False # Used for image-box loss and caption loss 33 | _C.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS = 128 # num proposals for image-labeled data 34 | _C.MODEL.ROI_BOX_HEAD.WITH_SOFTMAX_PROP = False # Used for WSDDN 35 | _C.MODEL.ROI_BOX_HEAD.CAPTION_WEIGHT = 1.0 # Caption loss weight 36 | _C.MODEL.ROI_BOX_HEAD.NEG_CAP_WEIGHT = 0.125 # Caption loss hyper-parameter 37 | _C.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP = False # Used for WSDDN 38 | _C.MODEL.ROI_BOX_HEAD.SOFTMAX_WEAK_LOSS = False # Used when USE_SIGMOID_CE is False 39 | 40 | _C.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0 41 | _C.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False # For demo only 42 | 43 | # Caption losses 44 | _C.MODEL.CAP_BATCH_RATIO = 4 # Ratio between detection data and caption data 45 | _C.MODEL.WITH_CAPTION = False 46 | _C.MODEL.SYNC_CAPTION_BATCH = False # synchronize across GPUs to enlarge # "classes" 47 | 48 | # dynamic class sampling when training with 21K classes 49 | _C.MODEL.DYNAMIC_CLASSIFIER = False 50 | _C.MODEL.NUM_SAMPLE_CATS = 50 51 | 52 | # Different classifiers in testing, used in cross-dataset evaluation 53 | _C.MODEL.RESET_CLS_TESTS = False 54 | _C.MODEL.TEST_CLASSIFIERS = [] 55 | _C.MODEL.TEST_NUM_CLASSES = [] 56 | 57 | # Backbones 58 | _C.MODEL.SWIN = CN() 59 | _C.MODEL.SWIN.SIZE = 'T' # 'T', 'S', 'B' 60 | _C.MODEL.SWIN.USE_CHECKPOINT = False 61 | _C.MODEL.SWIN.OUT_FEATURES = (1, 2, 3) # FPN stride 8 - 32 62 | 63 | _C.MODEL.TIMM = CN() 64 | _C.MODEL.TIMM.BASE_NAME = 'resnet50' 65 | _C.MODEL.TIMM.OUT_LEVELS = (3, 4, 5) 66 | _C.MODEL.TIMM.NORM = 'FrozenBN' 67 | _C.MODEL.TIMM.FREEZE_AT = 0 68 | _C.MODEL.TIMM.PRETRAINED = False 69 | _C.MODEL.DATASET_LOSS_WEIGHT = [] 70 | 71 | # Multi-dataset dataloader 72 | _C.DATALOADER.DATASET_RATIO = [1, 1] # sample ratio 73 | _C.DATALOADER.USE_RFS = [False, False] 74 | _C.DATALOADER.MULTI_DATASET_GROUPING = False # Always true when multi-dataset is enabled 75 | _C.DATALOADER.DATASET_ANN = ['box', 'box'] # Annotation type of each dataset 76 | _C.DATALOADER.USE_DIFF_BS_SIZE = False # Use different batchsize for each dataset 77 | _C.DATALOADER.DATASET_BS = [8, 32] # Used when USE_DIFF_BS_SIZE is on PER GPU!!!! 
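    # NOTE: DATASET_BS values are per GPU (see DIFFMDAspectRatioGroupedDataset in
    # data/custom_dataset_dataloader.py), whereas SOLVER.IMS_PER_BATCH is the total
    # across GPUs; e.g. the released IN-L configs set DATASET_BS: [16, 64] for
    # training on 4 GPUs.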
78 | _C.DATALOADER.DATASET_INPUT_SIZE = [896, 384] # Used when USE_DIFF_BS_SIZE is on 79 | _C.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.5, 1.5)] # Used when USE_DIFF_BS_SIZE is on 80 | _C.DATALOADER.DATASET_MIN_SIZES = [(640, 800), (320, 400)] # Used when USE_DIFF_BS_SIZE is on 81 | _C.DATALOADER.DATASET_MAX_SIZES = [1333, 667] # Used when USE_DIFF_BS_SIZE is on 82 | _C.DATALOADER.USE_TAR_DATASET = False # for ImageNet-21K, directly reading from unziped files 83 | _C.DATALOADER.TARFILE_PATH = 'datasets/imagenet/metadata-22k/tar_files.npy' 84 | _C.DATALOADER.TAR_INDEX_DIR = 'datasets/imagenet/metadata-22k/tarindex_npy' 85 | 86 | _C.SOLVER.USE_CUSTOM_SOLVER = False 87 | _C.SOLVER.OPTIMIZER = 'SGD' 88 | 89 | _C.INPUT.CUSTOM_AUG = '' 90 | _C.INPUT.TRAIN_SIZE = 640 91 | _C.INPUT.TEST_SIZE = 640 92 | _C.INPUT.SCALE_RANGE = (0.1, 2.) 93 | # 'default' for fixed short/ long edge, 'square' for max size=INPUT.SIZE 94 | _C.INPUT.TEST_INPUT_TYPE = 'default' 95 | 96 | _C.FIND_UNUSED_PARAM = True 97 | _C.EVAL_PRED_AR = False 98 | _C.EVAL_PROPOSAL_AR = False 99 | _C.EVAL_CAT_SPEC_AR = False 100 | _C.IS_DEBUG = False 101 | _C.QUICK_DEBUG = False 102 | _C.FP16 = False 103 | _C.EVAL_AP_FIX = False 104 | _C.GEN_PSEDO_LABELS = False 105 | _C.SAVE_DEBUG_PATH = 'output/save_debug/' 106 | -------------------------------------------------------------------------------- /mmovod/custom_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from enum import Enum 3 | import itertools 4 | from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union 5 | import torch 6 | 7 | from detectron2.config import CfgNode 8 | 9 | from detectron2.solver.build import maybe_add_gradient_clipping 10 | 11 | def match_name_keywords(n, name_keywords): 12 | out = False 13 | for b in name_keywords: 14 | if b in n: 15 | out = True 16 | break 17 | return out 18 | 19 | def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer: 20 | """ 21 | Build an optimizer from config. 
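    Compared to detectron2's default optimizer builder, this version places every
    trainable parameter in its own param group with lr = SOLVER.BASE_LR (weight
    decay is attached per group only for non-AdamW optimizers, and passed to the
    optimizer constructor for ADAMW), supports SGD and ADAMW via SOLVER.OPTIMIZER,
    and optionally wraps the optimizer class with full-model gradient clipping
    when SOLVER.CLIP_GRADIENTS.CLIP_TYPE is "full_model".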
22 | """ 23 | params: List[Dict[str, Any]] = [] 24 | memo: Set[torch.nn.parameter.Parameter] = set() 25 | 26 | optimizer_type = cfg.SOLVER.OPTIMIZER 27 | for key, value in model.named_parameters(recurse=True): 28 | if not value.requires_grad: 29 | continue 30 | # Avoid duplicating parameters 31 | if value in memo: 32 | continue 33 | memo.add(value) 34 | lr = cfg.SOLVER.BASE_LR 35 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 36 | param = {"params": [value], "lr": lr} 37 | if optimizer_type != 'ADAMW': 38 | param['weight_decay'] = weight_decay 39 | params += [param] 40 | 41 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 42 | # detectron2 doesn't have full model gradient clipping now 43 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 44 | enable = ( 45 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 46 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 47 | and clip_norm_val > 0.0 48 | ) 49 | 50 | class FullModelGradientClippingOptimizer(optim): 51 | def step(self, closure=None): 52 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 53 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 54 | super().step(closure=closure) 55 | 56 | return FullModelGradientClippingOptimizer if enable else optim 57 | 58 | if optimizer_type == 'SGD': 59 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 60 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, 61 | nesterov=cfg.SOLVER.NESTEROV 62 | ) 63 | elif optimizer_type == 'ADAMW': 64 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 65 | params, cfg.SOLVER.BASE_LR, 66 | weight_decay=cfg.SOLVER.WEIGHT_DECAY 67 | ) 68 | else: 69 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 70 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 71 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 72 | return optimizer 73 | -------------------------------------------------------------------------------- /mmovod/data/__pycache__/custom_build_augmentation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/custom_build_augmentation.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/__pycache__/custom_dataset_dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/custom_dataset_dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/__pycache__/custom_dataset_mapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/custom_dataset_mapper.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/__pycache__/tar_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/__pycache__/tar_dataset.cpython-38.pyc -------------------------------------------------------------------------------- 
/mmovod/data/custom_build_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import numpy as np 4 | import pycocotools.mask as mask_util 5 | import torch 6 | from fvcore.common.file_io import PathManager 7 | from PIL import Image 8 | 9 | 10 | from detectron2.data import transforms as T 11 | from .transforms.custom_augmentation_impl import EfficientDetResizeCrop 12 | 13 | def build_custom_augmentation(cfg, is_train, scale=None, size=None, \ 14 | min_size=None, max_size=None): 15 | """ 16 | Create a list of default :class:`Augmentation` from config. 17 | Now it includes resizing and flipping. 18 | 19 | Returns: 20 | list[Augmentation] 21 | """ 22 | if cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge': 23 | if is_train: 24 | min_size = cfg.INPUT.MIN_SIZE_TRAIN if min_size is None else min_size 25 | max_size = cfg.INPUT.MAX_SIZE_TRAIN if max_size is None else max_size 26 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 27 | else: 28 | min_size = cfg.INPUT.MIN_SIZE_TEST 29 | max_size = cfg.INPUT.MAX_SIZE_TEST 30 | sample_style = "choice" 31 | augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] 32 | elif cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop': 33 | if is_train: 34 | scale = cfg.INPUT.SCALE_RANGE if scale is None else scale 35 | size = cfg.INPUT.TRAIN_SIZE if size is None else size 36 | else: 37 | scale = (1, 1) 38 | size = cfg.INPUT.TEST_SIZE 39 | augmentation = [EfficientDetResizeCrop(size, scale)] 40 | else: 41 | assert 0, cfg.INPUT.CUSTOM_AUG 42 | 43 | if is_train: 44 | augmentation.append(T.RandomFlip()) 45 | return augmentation 46 | 47 | 48 | build_custom_transform_gen = build_custom_augmentation 49 | """ 50 | Alias for backward-compatibility. 51 | """ -------------------------------------------------------------------------------- /mmovod/data/custom_dataset_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
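# This module implements the multi-dataset training pipeline used for joint
# LVIS (box-labelled) + ImageNet (image-labelled) training: each dataset dict is
# tagged with its dataset_source, MultiDatasetSampler mixes the datasets
# according to DATALOADER.DATASET_RATIO (with optional per-dataset repeat-factor
# sampling), and the (DIFF)MDAspectRatioGroupedDataset wrappers batch each source
# separately, optionally with a different batch size per dataset.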
2 | # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License) 3 | import copy 4 | import logging 5 | import numpy as np 6 | import operator 7 | import torch 8 | import torch.utils.data 9 | import json 10 | from detectron2.utils.comm import get_world_size 11 | from detectron2.utils.logger import _log_api_usage, log_first_n 12 | 13 | from detectron2.config import configurable 14 | from detectron2.data import samplers 15 | from torch.utils.data.sampler import BatchSampler, Sampler 16 | from detectron2.data.common import DatasetFromList, MapDataset 17 | from detectron2.data.dataset_mapper import DatasetMapper 18 | from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader 19 | from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler 20 | from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram 21 | from detectron2.data.build import filter_images_with_only_crowd_annotations 22 | from detectron2.data.build import filter_images_with_few_keypoints 23 | from detectron2.data.build import check_metadata_consistency 24 | from detectron2.data.catalog import MetadataCatalog, DatasetCatalog 25 | from detectron2.utils import comm 26 | import itertools 27 | import math 28 | from collections import defaultdict 29 | from typing import Optional 30 | 31 | 32 | def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): 33 | sampler_name = cfg.DATALOADER.SAMPLER_TRAIN 34 | if 'MultiDataset' in sampler_name: 35 | dataset_dicts = get_detection_dataset_dicts_with_source( 36 | cfg.DATASETS.TRAIN, 37 | filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, 38 | min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE 39 | if cfg.MODEL.KEYPOINT_ON else 0, 40 | proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, 41 | ann_types=cfg.DATALOADER.DATASET_ANN, 42 | ) 43 | else: 44 | dataset_dicts = get_detection_dataset_dicts( 45 | cfg.DATASETS.TRAIN, 46 | filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, 47 | min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE 48 | if cfg.MODEL.KEYPOINT_ON else 0, 49 | proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, 50 | ) 51 | 52 | if mapper is None: 53 | mapper = DatasetMapper(cfg, True) 54 | 55 | if sampler is not None: 56 | pass 57 | elif sampler_name == "TrainingSampler": 58 | sampler = TrainingSampler(len(dataset)) 59 | elif sampler_name == "MultiDatasetSampler": 60 | sampler = MultiDatasetSampler( 61 | dataset_dicts, 62 | dataset_ratio = cfg.DATALOADER.DATASET_RATIO, 63 | use_rfs = cfg.DATALOADER.USE_RFS, 64 | dataset_ann = cfg.DATALOADER.DATASET_ANN, 65 | repeat_threshold = cfg.DATALOADER.REPEAT_THRESHOLD, 66 | ) 67 | elif sampler_name == "RepeatFactorTrainingSampler": 68 | repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( 69 | dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD 70 | ) 71 | sampler = RepeatFactorTrainingSampler(repeat_factors) 72 | else: 73 | raise ValueError("Unknown training sampler: {}".format(sampler_name)) 74 | 75 | return { 76 | "dataset": dataset_dicts, 77 | "sampler": sampler, 78 | "mapper": mapper, 79 | "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, 80 | "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, 81 | "num_workers": cfg.DATALOADER.NUM_WORKERS, 82 | 'multi_dataset_grouping': 
cfg.DATALOADER.MULTI_DATASET_GROUPING, 83 | 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE, 84 | 'dataset_bs': cfg.DATALOADER.DATASET_BS, 85 | 'num_datasets': len(cfg.DATASETS.TRAIN) 86 | } 87 | 88 | 89 | @configurable(from_config=_custom_train_loader_from_config) 90 | def build_custom_train_loader( 91 | dataset, *, mapper, sampler, 92 | total_batch_size=16, 93 | aspect_ratio_grouping=True, 94 | num_workers=0, 95 | num_datasets=1, 96 | multi_dataset_grouping=False, 97 | use_diff_bs_size=False, 98 | dataset_bs=[] 99 | ): 100 | """ 101 | Modified from detectron2.data.build.build_custom_train_loader, but supports 102 | different samplers 103 | """ 104 | if isinstance(dataset, list): 105 | dataset = DatasetFromList(dataset, copy=False) 106 | if mapper is not None: 107 | dataset = MapDataset(dataset, mapper) 108 | if sampler is None: 109 | sampler = TrainingSampler(len(dataset)) 110 | assert isinstance(sampler, torch.utils.data.sampler.Sampler) 111 | if multi_dataset_grouping: 112 | return build_multi_dataset_batch_data_loader( 113 | use_diff_bs_size, 114 | dataset_bs, 115 | dataset, 116 | sampler, 117 | total_batch_size, 118 | num_datasets=num_datasets, 119 | num_workers=num_workers, 120 | ) 121 | else: 122 | return build_batch_data_loader( 123 | dataset, 124 | sampler, 125 | total_batch_size, 126 | aspect_ratio_grouping=aspect_ratio_grouping, 127 | num_workers=num_workers, 128 | ) 129 | 130 | 131 | def build_multi_dataset_batch_data_loader( 132 | use_diff_bs_size, dataset_bs, 133 | dataset, sampler, total_batch_size, num_datasets, num_workers=0 134 | ): 135 | """ 136 | """ 137 | world_size = get_world_size() 138 | assert ( 139 | total_batch_size > 0 and total_batch_size % world_size == 0 140 | ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( 141 | total_batch_size, world_size 142 | ) 143 | 144 | batch_size = total_batch_size // world_size 145 | data_loader = torch.utils.data.DataLoader( 146 | dataset, 147 | sampler=sampler, 148 | num_workers=num_workers, 149 | batch_sampler=None, 150 | collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements 151 | worker_init_fn=worker_init_reset_seed, 152 | ) # yield individual mapped dict 153 | if use_diff_bs_size: 154 | return DIFFMDAspectRatioGroupedDataset( 155 | data_loader, dataset_bs, num_datasets) 156 | else: 157 | return MDAspectRatioGroupedDataset( 158 | data_loader, batch_size, num_datasets) 159 | 160 | 161 | def filter_images_with_only_crowd_annotations_detic(dataset_dicts): 162 | """ 163 | Filter out images with none annotations or only crowd annotations 164 | (i.e., images without non-crowd annotations). 165 | A common training-time preprocessing on COCO dataset. 166 | 167 | Args: 168 | dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. 169 | 170 | Returns: 171 | list[dict]: the same format, but filtered. 172 | """ 173 | num_before = len(dataset_dicts) 174 | 175 | def valid(anns, ann_type): 176 | if ann_type != "box": 177 | return True 178 | for ann in anns: 179 | if ann.get("iscrowd", 0) == 0: 180 | return True 181 | return False 182 | 183 | dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"], x['ann_type'])] 184 | num_after = len(dataset_dicts) 185 | logger = logging.getLogger(__name__) 186 | logger.info( 187 | "Removed {} images with no usable annotations. 
{} images left.".format( 188 | num_before - num_after, num_after 189 | ) 190 | ) 191 | return dataset_dicts 192 | 193 | 194 | def get_detection_dataset_dicts_with_source( 195 | dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None, ann_types=[], 196 | ): 197 | assert len(dataset_names) 198 | dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] 199 | for dataset_name, dicts in zip(dataset_names, dataset_dicts): 200 | assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) 201 | 202 | for source_id, (dataset_name, dicts) in enumerate(zip( 203 | dataset_names, dataset_dicts)): 204 | assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) 205 | for d in dicts: 206 | d['dataset_source'] = source_id 207 | d['ann_type'] = ann_types[source_id] 208 | 209 | if "annotations" in dicts[0]: 210 | try: 211 | class_names = MetadataCatalog.get(dataset_name).thing_classes 212 | check_metadata_consistency("thing_classes", dataset_name) 213 | print_instances_class_histogram(dicts, class_names) 214 | except AttributeError: # class names are not available for this dataset 215 | pass 216 | 217 | assert proposal_files is None 218 | 219 | dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) 220 | 221 | has_instances = "annotations" in dataset_dicts[0] 222 | if filter_empty and has_instances: 223 | dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) 224 | if min_keypoints > 0 and has_instances: 225 | dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) 226 | 227 | return dataset_dicts 228 | 229 | 230 | class MultiDatasetSampler(Sampler): 231 | def __init__( 232 | self, 233 | dataset_dicts, 234 | dataset_ratio, 235 | use_rfs, 236 | dataset_ann, 237 | repeat_threshold=0.001, 238 | seed: Optional[int] = None, 239 | ): 240 | """ 241 | """ 242 | sizes = [0 for _ in range(len(dataset_ratio))] 243 | for d in dataset_dicts: 244 | sizes[d['dataset_source']] += 1 245 | print('dataset sizes', sizes) 246 | self.sizes = sizes 247 | assert len(dataset_ratio) == len(sizes), \ 248 | 'length of dataset ratio {} should be equal to number if dataset {}'.format( 249 | len(dataset_ratio), len(sizes) 250 | ) 251 | if seed is None: 252 | seed = comm.shared_random_seed() 253 | self._seed = int(seed) 254 | self._rank = comm.get_rank() 255 | self._world_size = comm.get_world_size() 256 | 257 | self.dataset_ids = torch.tensor( 258 | [d['dataset_source'] for d in dataset_dicts], dtype=torch.long) 259 | 260 | dataset_weight = ( 261 | [torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) 262 | for i, (r, s) in enumerate(zip(dataset_ratio, sizes))] 263 | ) 264 | dataset_weight = torch.cat(dataset_weight) 265 | 266 | rfs_factors = [] 267 | st = 0 268 | for i, s in enumerate(sizes): 269 | if use_rfs[i]: 270 | if dataset_ann[i] == 'box': 271 | rfs_func = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency 272 | else: 273 | rfs_func = repeat_factors_from_tag_frequency 274 | rfs_factor = rfs_func( 275 | dataset_dicts[st: st + s], 276 | repeat_thresh=repeat_threshold) 277 | rfs_factor = rfs_factor * (s / rfs_factor.sum()) 278 | else: 279 | rfs_factor = torch.ones(s) 280 | rfs_factors.append(rfs_factor) 281 | st = st + s 282 | rfs_factors = torch.cat(rfs_factors) 283 | 284 | self.weights = dataset_weight * rfs_factors 285 | self.sample_epoch_size = len(self.weights) 286 | 287 | def __iter__(self): 288 | start = self._rank 289 | yield from itertools.islice( 290 | self._infinite_indices(), start, None, 
self._world_size) 291 | 292 | def _infinite_indices(self): 293 | g = torch.Generator() 294 | g.manual_seed(self._seed) 295 | while True: 296 | ids = torch.multinomial( 297 | self.weights, self.sample_epoch_size, generator=g, 298 | replacement=True) 299 | nums = ([(self.dataset_ids[ids] == i).sum().int().item() 300 | for i in range(len(self.sizes))]) 301 | yield from ids 302 | 303 | 304 | class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): 305 | def __init__(self, dataset, batch_size, num_datasets): 306 | """ 307 | """ 308 | self.dataset = dataset 309 | self.batch_size = batch_size 310 | self._buckets = [[] for _ in range(2 * num_datasets)] 311 | 312 | def __iter__(self): 313 | for d in self.dataset: 314 | w, h = d["width"], d["height"] 315 | aspect_ratio_bucket_id = 0 if w > h else 1 316 | bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id 317 | bucket = self._buckets[bucket_id] 318 | bucket.append(d) 319 | if len(bucket) == self.batch_size: 320 | yield bucket[:] 321 | del bucket[:] 322 | 323 | 324 | class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): 325 | def __init__(self, dataset, batch_sizes, num_datasets): 326 | """ 327 | """ 328 | self.dataset = dataset 329 | self.batch_sizes = batch_sizes 330 | self._buckets = [[] for _ in range(2 * num_datasets)] 331 | 332 | def __iter__(self): 333 | for d in self.dataset: 334 | w, h = d["width"], d["height"] 335 | aspect_ratio_bucket_id = 0 if w > h else 1 336 | bucket_id = d['dataset_source'] * 2 + aspect_ratio_bucket_id 337 | bucket = self._buckets[bucket_id] 338 | bucket.append(d) 339 | if len(bucket) == self.batch_sizes[d['dataset_source']]: 340 | yield bucket[:] 341 | del bucket[:] 342 | 343 | 344 | def repeat_factors_from_tag_frequency(dataset_dicts, repeat_thresh): 345 | """ 346 | """ 347 | category_freq = defaultdict(int) 348 | for dataset_dict in dataset_dicts: 349 | cat_ids = dataset_dict['pos_category_ids'] 350 | for cat_id in cat_ids: 351 | category_freq[cat_id] += 1 352 | num_images = len(dataset_dicts) 353 | for k, v in category_freq.items(): 354 | category_freq[k] = v / num_images 355 | 356 | category_rep = { 357 | cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) 358 | for cat_id, cat_freq in category_freq.items() 359 | } 360 | 361 | rep_factors = [] 362 | for dataset_dict in dataset_dicts: 363 | cat_ids = dataset_dict['pos_category_ids'] 364 | rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0) 365 | rep_factors.append(rep_factor) 366 | 367 | return torch.tensor(rep_factors, dtype=torch.float32) -------------------------------------------------------------------------------- /mmovod/data/custom_dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import copy 3 | import logging 4 | import numpy as np 5 | from typing import List, Optional, Union 6 | import torch 7 | import pycocotools.mask as mask_util 8 | 9 | from detectron2.config import configurable 10 | 11 | from detectron2.data import detection_utils as utils 12 | from detectron2.data.detection_utils import transform_keypoint_annotations 13 | from detectron2.data import transforms as T 14 | from detectron2.data.dataset_mapper import DatasetMapper 15 | from detectron2.structures import Boxes, BoxMode, Instances 16 | from detectron2.structures import Keypoints, PolygonMasks, BitMasks 17 | from fvcore.transforms.transform import TransformList 18 | from .custom_build_augmentation import build_custom_augmentation 19 | 20 | __all__ = ["CustomDatasetMapper"] 21 | 22 | class CustomDatasetMapper(DatasetMapper): 23 | @configurable 24 | def __init__(self, is_train: bool, 25 | with_ann_type=False, 26 | dataset_ann=[], 27 | use_diff_bs_size=False, 28 | dataset_augs=[], 29 | is_debug=False, 30 | use_tar_dataset=False, 31 | tarfile_path='', 32 | tar_index_dir='', 33 | **kwargs): 34 | """ 35 | add image labels 36 | """ 37 | self.with_ann_type = with_ann_type 38 | self.dataset_ann = dataset_ann 39 | self.use_diff_bs_size = use_diff_bs_size 40 | if self.use_diff_bs_size and is_train: 41 | self.dataset_augs = [T.AugmentationList(x) for x in dataset_augs] 42 | self.is_debug = is_debug 43 | self.use_tar_dataset = use_tar_dataset 44 | if self.use_tar_dataset: 45 | raise NotImplementedError('Using tar dataset not supported') 46 | super().__init__(is_train, **kwargs) 47 | 48 | 49 | @classmethod 50 | def from_config(cls, cfg, is_train: bool = True): 51 | ret = super().from_config(cfg, is_train) 52 | ret.update({ 53 | 'with_ann_type': cfg.WITH_IMAGE_LABELS, 54 | 'dataset_ann': cfg.DATALOADER.DATASET_ANN, 55 | 'use_diff_bs_size': cfg.DATALOADER.USE_DIFF_BS_SIZE, 56 | 'is_debug': cfg.IS_DEBUG, 57 | 'use_tar_dataset': cfg.DATALOADER.USE_TAR_DATASET, 58 | 'tarfile_path': cfg.DATALOADER.TARFILE_PATH, 59 | 'tar_index_dir': cfg.DATALOADER.TAR_INDEX_DIR, 60 | }) 61 | if ret['use_diff_bs_size'] and is_train: 62 | if cfg.INPUT.CUSTOM_AUG == 'EfficientDetResizeCrop': 63 | dataset_scales = cfg.DATALOADER.DATASET_INPUT_SCALE 64 | dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE 65 | ret['dataset_augs'] = [ 66 | build_custom_augmentation(cfg, True, scale, size) \ 67 | for scale, size in zip(dataset_scales, dataset_sizes)] 68 | else: 69 | assert cfg.INPUT.CUSTOM_AUG == 'ResizeShortestEdge' 70 | min_sizes = cfg.DATALOADER.DATASET_MIN_SIZES 71 | max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES 72 | ret['dataset_augs'] = [ 73 | build_custom_augmentation( 74 | cfg, True, min_size=mi, max_size=ma) \ 75 | for mi, ma in zip(min_sizes, max_sizes)] 76 | else: 77 | ret['dataset_augs'] = [] 78 | 79 | return ret 80 | 81 | def __call__(self, dataset_dict): 82 | """ 83 | include image labels 84 | """ 85 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 86 | # USER: Write your own image loading if it's not from a file 87 | if 'file_name' in dataset_dict: 88 | ori_image = utils.read_image( 89 | dataset_dict["file_name"], format=self.image_format) 90 | else: 91 | ori_image, _, _ = self.tar_dataset[dataset_dict["tar_index"]] 92 | ori_image = utils._apply_exif_orientation(ori_image) 93 | ori_image = utils.convert_PIL_to_numpy(ori_image, self.image_format) 94 | utils.check_image_size(dataset_dict, ori_image) 95 | 96 | # USER: Remove if you don't do semantic/panoptic segmentation. 
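# NOTE: a minimal, illustrative sketch of the per-dataset augmentations that
# `from_config` above builds when DATALOADER.USE_DIFF_BS_SIZE is enabled
# (the key names are the ones read above; the values here are examples, not
# the repository defaults):
#
#   cfg.INPUT.CUSTOM_AUG = 'EfficientDetResizeCrop'
#   cfg.DATALOADER.DATASET_INPUT_SCALE = [(0.1, 2.0), (0.5, 1.5)]
#   cfg.DATALOADER.DATASET_INPUT_SIZE = [896, 448]
#
# This yields one T.AugmentationList per training source; further below,
# __call__ selects the list via dataset_dict['dataset_source'], so detection
# images and image-labelled images can be resized differently in one job.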
97 | if "sem_seg_file_name" in dataset_dict: 98 | sem_seg_gt = utils.read_image( 99 | dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 100 | else: 101 | sem_seg_gt = None 102 | 103 | if self.is_debug: 104 | dataset_dict['dataset_source'] = 0 105 | 106 | not_full_labeled = 'dataset_source' in dataset_dict and \ 107 | self.with_ann_type and \ 108 | self.dataset_ann[dataset_dict['dataset_source']] != 'box' 109 | 110 | aug_input = T.AugInput(copy.deepcopy(ori_image), sem_seg=sem_seg_gt) 111 | if self.use_diff_bs_size and self.is_train: 112 | transforms = \ 113 | self.dataset_augs[dataset_dict['dataset_source']](aug_input) 114 | else: 115 | transforms = self.augmentations(aug_input) 116 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 117 | 118 | image_shape = image.shape[:2] # h, w 119 | dataset_dict["image"] = torch.as_tensor( 120 | np.ascontiguousarray(image.transpose(2, 0, 1))) 121 | 122 | if sem_seg_gt is not None: 123 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 124 | 125 | # USER: Remove if you don't use pre-computed proposals. 126 | # Most users would not need this feature. 127 | if self.proposal_topk is not None: 128 | utils.transform_proposals( 129 | dataset_dict, image_shape, transforms, 130 | proposal_topk=self.proposal_topk 131 | ) 132 | 133 | if not self.is_train: 134 | # USER: Modify this if you want to keep them for some reason. 135 | dataset_dict.pop("annotations", None) 136 | dataset_dict.pop("sem_seg_file_name", None) 137 | return dataset_dict 138 | 139 | if "annotations" in dataset_dict: 140 | # USER: Modify this if you want to keep them for some reason. 141 | for anno in dataset_dict["annotations"]: 142 | if not self.use_instance_mask: 143 | anno.pop("segmentation", None) 144 | if not self.use_keypoint: 145 | anno.pop("keypoints", None) 146 | 147 | # USER: Implement additional transformations if you have other types of data 148 | all_annos = [ 149 | (utils.transform_instance_annotations( 150 | obj, transforms, image_shape, 151 | keypoint_hflip_indices=self.keypoint_hflip_indices, 152 | ), obj.get("iscrowd", 0)) 153 | for obj in dataset_dict.pop("annotations") 154 | ] 155 | annos = [ann[0] for ann in all_annos if ann[1] == 0] 156 | instances = utils.annotations_to_instances( 157 | annos, image_shape, mask_format=self.instance_mask_format 158 | ) 159 | 160 | del all_annos 161 | if self.recompute_boxes: 162 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 163 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 164 | if self.with_ann_type: 165 | dataset_dict["pos_category_ids"] = dataset_dict.get( 166 | 'pos_category_ids', []) 167 | dataset_dict["ann_type"] = \ 168 | self.dataset_ann[dataset_dict['dataset_source']] 169 | if self.is_debug and (('pos_category_ids' not in dataset_dict) or \ 170 | (dataset_dict['pos_category_ids'] == [])): 171 | dataset_dict['pos_category_ids'] = [x for x in sorted(set( 172 | dataset_dict['instances'].gt_classes.tolist() 173 | ))] 174 | return dataset_dict 175 | -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/cc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/cc.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/coco_zeroshot.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/coco_zeroshot.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/lvis_22k_categories.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/lvis_22k_categories.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/lvis_v1.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/lvis_v1.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/objects365.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/objects365.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/oid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/oid.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/__pycache__/register_oid.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/datasets/__pycache__/register_oid.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
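# NOTE: this module registers the ImageNet-LVIS split that provides the
# image-labelled training data used alongside LVIS.  A minimal usage sketch,
# assuming the json/image paths registered below exist under
# $DETECTRON2_DATASETS:
#
#   from detectron2.data import DatasetCatalog
#   dicts = DatasetCatalog.get("imagenet_lvis_v1")
#   # each record typically carries image-level 'pos_category_ids' rather
#   # than exhaustive box annotations.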
2 | # import logging 3 | import os 4 | 5 | from detectron2.data import DatasetCatalog, MetadataCatalog 6 | from detectron2.data.datasets.lvis import get_lvis_instances_meta 7 | from .lvis_v1 import custom_load_lvis_json 8 | 9 | 10 | def custom_register_imagenet_instances(name, metadata, json_file, image_root): 11 | """ 12 | """ 13 | DatasetCatalog.register(name, lambda: custom_load_lvis_json( 14 | json_file, image_root, name)) 15 | MetadataCatalog.get(name).set( 16 | json_file=json_file, image_root=image_root, 17 | evaluator_type="imagenet", **metadata 18 | ) 19 | 20 | 21 | _CUSTOM_SPLITS_IMAGENET = { 22 | "imagenet_lvis_v1": ( 23 | "imagenet/imagenet21k_P/", 24 | "imagenet/annotations/imagenet_lvis_image_info.json", 25 | ), 26 | } 27 | 28 | 29 | for key, (image_root, json_file) in _CUSTOM_SPLITS_IMAGENET.items(): 30 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) 31 | custom_register_imagenet_instances( 32 | key, 33 | get_lvis_instances_meta('lvis_v1'), 34 | os.path.join(_root, json_file) if "://" not in json_file else json_file, 35 | os.path.join(_root, image_root), 36 | ) 37 | -------------------------------------------------------------------------------- /mmovod/data/datasets/lvis_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import logging 3 | import os 4 | 5 | from fvcore.common.timer import Timer 6 | from detectron2.structures import BoxMode 7 | from fvcore.common.file_io import PathManager 8 | from detectron2.data import DatasetCatalog, MetadataCatalog 9 | from detectron2.data.datasets.lvis import get_lvis_instances_meta 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | __all__ = ["custom_load_lvis_json", "custom_register_lvis_instances"] 14 | 15 | 16 | def custom_register_lvis_instances(name, metadata, json_file, image_root): 17 | """ 18 | """ 19 | DatasetCatalog.register(name, lambda: custom_load_lvis_json( 20 | json_file, image_root, name)) 21 | MetadataCatalog.get(name).set( 22 | json_file=json_file, image_root=image_root, 23 | evaluator_type="lvis", **metadata 24 | ) 25 | 26 | 27 | def custom_load_lvis_json(json_file, image_root, dataset_name=None): 28 | ''' 29 | Modifications: 30 | use `file_name` 31 | convert neg_category_ids 32 | add pos_category_ids 33 | ''' 34 | from lvis import LVIS 35 | 36 | json_file = PathManager.get_local_path(json_file) 37 | 38 | timer = Timer() 39 | lvis_api = LVIS(json_file) 40 | if timer.seconds() > 1: 41 | logger.info("Loading {} takes {:.2f} seconds.".format( 42 | json_file, timer.seconds())) 43 | 44 | catid2contid = {x['id']: i for i, x in enumerate( 45 | sorted(lvis_api.dataset['categories'], key=lambda x: x['id']))} 46 | if len(lvis_api.dataset['categories']) == 1203: 47 | for x in lvis_api.dataset['categories']: 48 | assert catid2contid[x['id']] == x['id'] - 1 49 | img_ids = sorted(lvis_api.imgs.keys()) 50 | imgs = lvis_api.load_imgs(img_ids) 51 | anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] 52 | 53 | ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] 54 | assert len(set(ann_ids)) == len(ann_ids), \ 55 | "Annotation ids in '{}' are not unique".format(json_file) 56 | 57 | imgs_anns = list(zip(imgs, anns)) 58 | logger.info("Loaded {} images in the LVIS v1 format from {}".format( 59 | len(imgs_anns), json_file)) 60 | 61 | dataset_dicts = [] 62 | 63 | for (img_dict, anno_dict_list) in imgs_anns: 64 | record = {} 65 | if "file_name" in img_dict: 66 | file_name = 
img_dict["file_name"] 67 | if img_dict["file_name"].startswith("COCO"): 68 | file_name = file_name[-16:] 69 | record["file_name"] = os.path.join(image_root, file_name) 70 | elif 'coco_url' in img_dict: 71 | # e.g., http://images.cocodataset.org/train2017/000000391895.jpg 72 | file_name = img_dict["coco_url"][30:] 73 | record["file_name"] = os.path.join(image_root, file_name) 74 | elif 'tar_index' in img_dict: 75 | record['tar_index'] = img_dict['tar_index'] 76 | 77 | record["height"] = img_dict["height"] 78 | record["width"] = img_dict["width"] 79 | record["not_exhaustive_category_ids"] = img_dict.get( 80 | "not_exhaustive_category_ids", []) 81 | record["neg_category_ids"] = img_dict.get("neg_category_ids", []) 82 | # NOTE: modified by Xingyi: convert to 0-based 83 | record["neg_category_ids"] = [ 84 | catid2contid[x] for x in record["neg_category_ids"]] 85 | if 'pos_category_ids' in img_dict: 86 | record['pos_category_ids'] = [ 87 | catid2contid[x] for x in img_dict.get("pos_category_ids", [])] 88 | if 'captions' in img_dict: 89 | record['captions'] = img_dict['captions'] 90 | if 'caption_features' in img_dict: 91 | record['caption_features'] = img_dict['caption_features'] 92 | image_id = record["image_id"] = img_dict["id"] 93 | 94 | objs = [] 95 | for anno in anno_dict_list: 96 | assert anno["image_id"] == image_id 97 | if anno.get('iscrowd', 0) > 0: 98 | continue 99 | obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS} 100 | obj["category_id"] = catid2contid[anno['category_id']] 101 | if 'segmentation' in anno: 102 | segm = anno["segmentation"] 103 | valid_segm = [poly for poly in segm 104 | if len(poly) % 2 == 0 and len(poly) >= 6] 105 | # assert len(segm) == len( 106 | # valid_segm 107 | # ), "Annotation contains an invalid polygon with < 3 points" 108 | if not len(segm) == len(valid_segm): 109 | print('Annotation contains an invalid polygon with < 3 points') 110 | assert len(segm) > 0 111 | obj["segmentation"] = segm 112 | objs.append(obj) 113 | record["annotations"] = objs 114 | dataset_dicts.append(record) 115 | 116 | return dataset_dicts 117 | 118 | 119 | _CUSTOM_SPLITS_LVIS = { 120 | "lvis_v1_train_norare": ("coco/", "lvis/lvis_v1_train_norare.json"), 121 | "lvis_v1_val_rareonly": ("coco/", "lvis/lvis_v1_val_rare_only.json"), 122 | "lvis_v1_val_norare": ("coco/", "lvis/lvis_v1_val_norare.json"), 123 | "lvis_v1_train_rareonly": ("coco/", "lvis/lvis_v1_train_rare_only.json"), 124 | } 125 | 126 | 127 | for key, (image_root, json_file) in _CUSTOM_SPLITS_LVIS.items(): 128 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) 129 | custom_register_lvis_instances( 130 | key, 131 | get_lvis_instances_meta(key), 132 | os.path.join(_root, json_file) if "://" not in json_file else json_file, 133 | os.path.join(_root, image_root), 134 | ) 135 | -------------------------------------------------------------------------------- /mmovod/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/transforms/__pycache__/custom_augmentation_impl.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/transforms/__pycache__/custom_transform.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/data/transforms/__pycache__/custom_transform.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/data/transforms/custom_augmentation_impl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py 4 | # Modified by Xingyi Zhou 5 | # The original code is under Apache-2.0 License 6 | import numpy as np 7 | import sys 8 | from fvcore.transforms.transform import ( 9 | BlendTransform, 10 | CropTransform, 11 | HFlipTransform, 12 | NoOpTransform, 13 | Transform, 14 | VFlipTransform, 15 | ) 16 | from PIL import Image 17 | 18 | from detectron2.data.transforms.augmentation import Augmentation 19 | from .custom_transform import EfficientDetResizeCropTransform 20 | 21 | __all__ = [ 22 | "EfficientDetResizeCrop", 23 | ] 24 | 25 | class EfficientDetResizeCrop(Augmentation): 26 | """ 27 | Scale the shorter edge to the given size, with a limit of `max_size` on the longer edge. 28 | If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. 29 | """ 30 | 31 | def __init__( 32 | self, size, scale, interp=Image.BILINEAR 33 | ): 34 | """ 35 | """ 36 | super().__init__() 37 | self.target_size = (size, size) 38 | self.scale = scale 39 | self.interp = interp 40 | 41 | def get_transform(self, img): 42 | # Select a random scale factor. 43 | scale_factor = np.random.uniform(*self.scale) 44 | scaled_target_height = scale_factor * self.target_size[0] 45 | scaled_target_width = scale_factor * self.target_size[1] 46 | # Recompute the accurate scale_factor using rounded scaled image size. 47 | width, height = img.shape[1], img.shape[0] 48 | img_scale_y = scaled_target_height / height 49 | img_scale_x = scaled_target_width / width 50 | img_scale = min(img_scale_y, img_scale_x) 51 | 52 | # Select non-zero random offset (x, y) if scaled image is larger than target size 53 | scaled_h = int(height * img_scale) 54 | scaled_w = int(width * img_scale) 55 | offset_y = scaled_h - self.target_size[0] 56 | offset_x = scaled_w - self.target_size[1] 57 | offset_y = int(max(0.0, float(offset_y)) * np.random.uniform(0, 1)) 58 | offset_x = int(max(0.0, float(offset_x)) * np.random.uniform(0, 1)) 59 | return EfficientDetResizeCropTransform( 60 | scaled_h, scaled_w, offset_y, offset_x, img_scale, self.target_size, self.interp) 61 | -------------------------------------------------------------------------------- /mmovod/data/transforms/custom_transform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py 4 | # Modified by Xingyi Zhou 5 | # The original code is under Apache-2.0 License 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from fvcore.transforms.transform import ( 10 | CropTransform, 11 | HFlipTransform, 12 | NoOpTransform, 13 | Transform, 14 | TransformList, 15 | ) 16 | from PIL import Image 17 | 18 | try: 19 | import cv2 # noqa 20 | except ImportError: 21 | # OpenCV is an optional dependency at the moment 22 | pass 23 | 24 | __all__ = [ 25 | "EfficientDetResizeCropTransform", 26 | ] 27 | 28 | class EfficientDetResizeCropTransform(Transform): 29 | """ 30 | """ 31 | 32 | def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, \ 33 | target_size, interp=None): 34 | """ 35 | Args: 36 | h, w (int): original image size 37 | new_h, new_w (int): new image size 38 | interp: PIL interpolation methods, defaults to bilinear. 39 | """ 40 | # TODO decide on PIL vs opencv 41 | super().__init__() 42 | if interp is None: 43 | interp = Image.BILINEAR 44 | self._set_attributes(locals()) 45 | 46 | def apply_image(self, img, interp=None): 47 | assert len(img.shape) <= 4 48 | 49 | if img.dtype == np.uint8: 50 | pil_image = Image.fromarray(img) 51 | interp_method = interp if interp is not None else self.interp 52 | pil_image = pil_image.resize((self.scaled_w, self.scaled_h), interp_method) 53 | ret = np.asarray(pil_image) 54 | right = min(self.scaled_w, self.offset_x + self.target_size[1]) 55 | lower = min(self.scaled_h, self.offset_y + self.target_size[0]) 56 | if len(ret.shape) <= 3: 57 | ret = ret[self.offset_y: lower, self.offset_x: right] 58 | else: 59 | ret = ret[..., self.offset_y: lower, self.offset_x: right, :] 60 | else: 61 | # PIL only supports uint8 62 | img = torch.from_numpy(img) 63 | shape = list(img.shape) 64 | shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] 65 | img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw 66 | _PIL_RESIZE_TO_INTERPOLATE_MODE = {Image.BILINEAR: "bilinear", Image.BICUBIC: "bicubic"} 67 | mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[self.interp] 68 | img = F.interpolate(img, (self.scaled_h, self.scaled_w), mode=mode, align_corners=False) 69 | shape[:2] = (self.scaled_h, self.scaled_w) 70 | ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) 71 | right = min(self.scaled_w, self.offset_x + self.target_size[1]) 72 | lower = min(self.scaled_h, self.offset_y + self.target_size[0]) 73 | if len(ret.shape) <= 3: 74 | ret = ret[self.offset_y: lower, self.offset_x: right] 75 | else: 76 | ret = ret[..., self.offset_y: lower, self.offset_x: right, :] 77 | return ret 78 | 79 | 80 | def apply_coords(self, coords): 81 | coords[:, 0] = coords[:, 0] * self.img_scale 82 | coords[:, 1] = coords[:, 1] * self.img_scale 83 | coords[:, 0] -= self.offset_x 84 | coords[:, 1] -= self.offset_y 85 | return coords 86 | 87 | 88 | def apply_segmentation(self, segmentation): 89 | segmentation = self.apply_image(segmentation, interp=Image.NEAREST) 90 | return segmentation 91 | 92 | 93 | def inverse(self): 94 | raise NotImplementedError 95 | 96 | 97 | def inverse_apply_coords(self, coords): 98 | coords[:, 0] += self.offset_x 99 | coords[:, 1] += self.offset_y 100 | coords[:, 0] = coords[:, 0] / self.img_scale 101 | coords[:, 1] = coords[:, 1] / self.img_scale 102 | return coords 103 | 104 | 105 | def inverse_apply_box(self, box: np.ndarray) -> np.ndarray: 106 | """ 107 | """ 108 | idxs = 
np.array([(0, 1), (2, 1), (0, 3), (2, 3)]).flatten() 109 | coords = np.asarray(box).reshape(-1, 4)[:, idxs].reshape(-1, 2) 110 | coords = self.inverse_apply_coords(coords).reshape((-1, 4, 2)) 111 | minxy = coords.min(axis=1) 112 | maxxy = coords.max(axis=1) 113 | trans_boxes = np.concatenate((minxy, maxxy), axis=1) 114 | return trans_boxes -------------------------------------------------------------------------------- /mmovod/modeling/__pycache__/debug.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/__pycache__/debug.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/backbone/__pycache__/swintransformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/backbone/__pycache__/swintransformer.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/backbone/__pycache__/timm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/backbone/__pycache__/timm.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/backbone/timm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
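# NOTE: this module exposes timm models (ResNet-50-IN21k, ConvNeXt, ...) as
# detectron2 backbones.  An illustrative config sketch for the FPN variant
# registered below (key names are the ones read in build_timm_backbone and
# build_p67_timm_fpn_backbone; the values are examples only, not the
# repository defaults):
#
#   cfg.MODEL.BACKBONE.NAME = 'build_p67_timm_fpn_backbone'
#   cfg.MODEL.TIMM.BASE_NAME = 'resnet50_in21k'
#   cfg.MODEL.TIMM.OUT_LEVELS = (3, 4, 5)        # exported as layer3..layer5
#   cfg.MODEL.TIMM.FREEZE_AT = 0
#   cfg.MODEL.FPN.IN_FEATURES = ['layer3', 'layer4', 'layer5']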
4 | import math 5 | from os.path import join 6 | import numpy as np 7 | import copy 8 | from functools import partial 9 | 10 | import torch 11 | from torch import nn 12 | import torch.utils.model_zoo as model_zoo 13 | import torch.nn.functional as F 14 | import fvcore.nn.weight_init as weight_init 15 | 16 | from detectron2.modeling.backbone import FPN 17 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 18 | from detectron2.layers.batch_norm import get_norm, FrozenBatchNorm2d 19 | from detectron2.modeling.backbone import Backbone 20 | 21 | from timm import create_model 22 | from timm.models.helpers import build_model_with_cfg 23 | from timm.models.registry import register_model 24 | from timm.models.resnet import ResNet, Bottleneck 25 | from timm.models.resnet import default_cfgs as default_cfgs_resnet 26 | from timm.models.convnext import ConvNeXt, default_cfgs, checkpoint_filter_fn 27 | 28 | 29 | @register_model 30 | def convnext_tiny_21k(pretrained=False, **kwargs): 31 | model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) 32 | cfg = default_cfgs['convnext_tiny'] 33 | cfg['url'] = 'https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth' 34 | model = build_model_with_cfg( 35 | ConvNeXt, 'convnext_tiny', pretrained, 36 | default_cfg=cfg, 37 | pretrained_filter_fn=checkpoint_filter_fn, 38 | feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), 39 | **model_args) 40 | return model 41 | 42 | class CustomResNet(ResNet): 43 | def __init__(self, **kwargs): 44 | self.out_indices = kwargs.pop('out_indices') 45 | super().__init__(**kwargs) 46 | 47 | 48 | def forward(self, x): 49 | x = self.conv1(x) 50 | x = self.bn1(x) 51 | x = self.act1(x) 52 | x = self.maxpool(x) 53 | ret = [x] 54 | x = self.layer1(x) 55 | ret.append(x) 56 | x = self.layer2(x) 57 | ret.append(x) 58 | x = self.layer3(x) 59 | ret.append(x) 60 | x = self.layer4(x) 61 | ret.append(x) 62 | return [ret[i] for i in self.out_indices] 63 | 64 | 65 | def load_pretrained(self, cached_file): 66 | data = torch.load(cached_file, map_location='cpu') 67 | if 'state_dict' in data: 68 | self.load_state_dict(data['state_dict']) 69 | else: 70 | self.load_state_dict(data) 71 | 72 | 73 | model_params = { 74 | 'resnet50_in21k': dict(block=Bottleneck, layers=[3, 4, 6, 3]), 75 | } 76 | 77 | 78 | def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs): 79 | params = model_params[variant] 80 | default_cfgs_resnet['resnet50_in21k'] = \ 81 | copy.deepcopy(default_cfgs_resnet['resnet50']) 82 | default_cfgs_resnet['resnet50_in21k']['url'] = \ 83 | 'https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth' 84 | default_cfgs_resnet['resnet50_in21k']['num_classes'] = 11221 85 | 86 | return build_model_with_cfg( 87 | CustomResNet, variant, pretrained, 88 | default_cfg=default_cfgs_resnet[variant], 89 | out_indices=out_indices, 90 | pretrained_custom_load=True, 91 | **params, 92 | **kwargs) 93 | 94 | 95 | class LastLevelP6P7_P5(nn.Module): 96 | """ 97 | """ 98 | def __init__(self, in_channels, out_channels): 99 | super().__init__() 100 | self.num_levels = 2 101 | self.in_feature = "p5" 102 | self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) 103 | self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) 104 | for module in [self.p6, self.p7]: 105 | weight_init.c2_xavier_fill(module) 106 | 107 | def forward(self, c5): 108 | p6 = self.p6(c5) 109 | p7 = self.p7(F.relu(p6)) 110 | return [p6, p7] 111 | 112 | 113 | def 
freeze_module(x): 114 | """ 115 | """ 116 | for p in x.parameters(): 117 | p.requires_grad = False 118 | FrozenBatchNorm2d.convert_frozen_batchnorm(x) 119 | return x 120 | 121 | 122 | class TIMM(Backbone): 123 | def __init__(self, base_name, out_levels, freeze_at=0, norm='FrozenBN', pretrained=False): 124 | super().__init__() 125 | out_indices = [x - 1 for x in out_levels] 126 | if base_name in model_params: 127 | self.base = create_timm_resnet( 128 | base_name, out_indices=out_indices, 129 | pretrained=False) 130 | elif 'eff' in base_name or 'resnet' in base_name or 'regnet' in base_name: 131 | self.base = create_model( 132 | base_name, features_only=True, 133 | out_indices=out_indices, pretrained=pretrained) 134 | elif 'convnext' in base_name: 135 | drop_path_rate = 0.2 \ 136 | if ('tiny' in base_name or 'small' in base_name) else 0.3 137 | self.base = create_model( 138 | base_name, features_only=True, 139 | out_indices=out_indices, pretrained=pretrained, 140 | drop_path_rate=drop_path_rate) 141 | else: 142 | assert 0, base_name 143 | feature_info = [dict(num_chs=f['num_chs'], reduction=f['reduction']) \ 144 | for i, f in enumerate(self.base.feature_info)] 145 | self._out_features = ['layer{}'.format(x) for x in out_levels] 146 | self._out_feature_channels = { 147 | 'layer{}'.format(l): feature_info[l - 1]['num_chs'] for l in out_levels} 148 | self._out_feature_strides = { 149 | 'layer{}'.format(l): feature_info[l - 1]['reduction'] for l in out_levels} 150 | self._size_divisibility = max(self._out_feature_strides.values()) 151 | if 'resnet' in base_name: 152 | self.freeze(freeze_at) 153 | if norm == 'FrozenBN': 154 | self = FrozenBatchNorm2d.convert_frozen_batchnorm(self) 155 | 156 | def freeze(self, freeze_at=0): 157 | """ 158 | """ 159 | if freeze_at >= 1: 160 | print('Frezing', self.base.conv1) 161 | self.base.conv1 = freeze_module(self.base.conv1) 162 | if freeze_at >= 2: 163 | print('Frezing', self.base.layer1) 164 | self.base.layer1 = freeze_module(self.base.layer1) 165 | 166 | def forward(self, x): 167 | features = self.base(x) 168 | ret = {k: v for k, v in zip(self._out_features, features)} 169 | return ret 170 | 171 | @property 172 | def size_divisibility(self): 173 | return self._size_divisibility 174 | 175 | 176 | @BACKBONE_REGISTRY.register() 177 | def build_timm_backbone(cfg, input_shape): 178 | model = TIMM( 179 | cfg.MODEL.TIMM.BASE_NAME, 180 | cfg.MODEL.TIMM.OUT_LEVELS, 181 | freeze_at=cfg.MODEL.TIMM.FREEZE_AT, 182 | norm=cfg.MODEL.TIMM.NORM, 183 | pretrained=cfg.MODEL.TIMM.PRETRAINED, 184 | ) 185 | return model 186 | 187 | 188 | @BACKBONE_REGISTRY.register() 189 | def build_p67_timm_fpn_backbone(cfg, input_shape): 190 | """ 191 | """ 192 | bottom_up = build_timm_backbone(cfg, input_shape) 193 | in_features = cfg.MODEL.FPN.IN_FEATURES 194 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 195 | backbone = FPN( 196 | bottom_up=bottom_up, 197 | in_features=in_features, 198 | out_channels=out_channels, 199 | norm=cfg.MODEL.FPN.NORM, 200 | top_block=LastLevelP6P7_P5(out_channels, out_channels), 201 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 202 | ) 203 | return backbone 204 | 205 | @BACKBONE_REGISTRY.register() 206 | def build_p35_timm_fpn_backbone(cfg, input_shape): 207 | """ 208 | """ 209 | bottom_up = build_timm_backbone(cfg, input_shape) 210 | 211 | in_features = cfg.MODEL.FPN.IN_FEATURES 212 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 213 | backbone = FPN( 214 | bottom_up=bottom_up, 215 | in_features=in_features, 216 | out_channels=out_channels, 217 | norm=cfg.MODEL.FPN.NORM, 218 | 
top_block=None, 219 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 220 | ) 221 | return backbone -------------------------------------------------------------------------------- /mmovod/modeling/debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import cv2 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | import os 7 | 8 | COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype( 9 | np.uint8).reshape(1300, 1, 1, 3) 10 | 11 | def _get_color_image(heatmap): 12 | heatmap = heatmap.reshape( 13 | heatmap.shape[0], heatmap.shape[1], heatmap.shape[2], 1) 14 | if heatmap.shape[0] == 1: 15 | color_map = (heatmap * np.ones((1, 1, 1, 3), np.uint8) * 255).max( 16 | axis=0).astype(np.uint8) # H, W, 3 17 | else: 18 | color_map = (heatmap * COLORS[:heatmap.shape[0]]).max(axis=0).astype(np.uint8) # H, W, 3 19 | 20 | return color_map 21 | 22 | def _blend_image(image, color_map, a=0.7): 23 | color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) 24 | ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8) 25 | return ret 26 | 27 | def _blend_image_heatmaps(image, color_maps, a=0.7): 28 | merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32) 29 | for color_map in color_maps: 30 | color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) 31 | merges = np.maximum(merges, color_map) 32 | ret = np.clip(image * (1 - a) + merges * a, 0, 255).astype(np.uint8) 33 | return ret 34 | 35 | def _decompose_level(x, shapes_per_level, N): 36 | ''' 37 | x: LNHiWi x C 38 | ''' 39 | x = x.view(x.shape[0], -1) 40 | ret = [] 41 | st = 0 42 | for l in range(len(shapes_per_level)): 43 | ret.append([]) 44 | h = shapes_per_level[l][0].int().item() 45 | w = shapes_per_level[l][1].int().item() 46 | for i in range(N): 47 | ret[l].append(x[st + h * w * i:st + h * w * (i + 1)].view( 48 | h, w, -1).permute(2, 0, 1)) 49 | st += h * w * N 50 | return ret 51 | 52 | def _imagelist_to_tensor(images): 53 | images = [x for x in images] 54 | image_sizes = [x.shape[-2:] for x in images] 55 | h = max([size[0] for size in image_sizes]) 56 | w = max([size[1] for size in image_sizes]) 57 | S = 32 58 | h, w = ((h - 1) // S + 1) * S, ((w - 1) // S + 1) * S 59 | images = [F.pad(x, (0, w - x.shape[2], 0, h - x.shape[1], 0, 0)) \ 60 | for x in images] 61 | images = torch.stack(images) 62 | return images 63 | 64 | 65 | def _ind2il(ind, shapes_per_level, N): 66 | r = ind 67 | l = 0 68 | S = 0 69 | while r - S >= N * shapes_per_level[l][0] * shapes_per_level[l][1]: 70 | S += N * shapes_per_level[l][0] * shapes_per_level[l][1] 71 | l += 1 72 | i = (r - S) // (shapes_per_level[l][0] * shapes_per_level[l][1]) 73 | return i, l 74 | 75 | def debug_train( 76 | images, gt_instances, flattened_hms, reg_targets, labels, pos_inds, 77 | shapes_per_level, locations, strides): 78 | ''' 79 | images: N x 3 x H x W 80 | flattened_hms: LNHiWi x C 81 | shapes_per_level: L x 2 [(H_i, W_i)] 82 | locations: LNHiWi x 2 83 | ''' 84 | reg_inds = torch.nonzero( 85 | reg_targets.max(dim=1)[0] > 0).squeeze(1) 86 | N = len(images) 87 | images = _imagelist_to_tensor(images) 88 | repeated_locations = [torch.cat([loc] * N, dim=0) \ 89 | for loc in locations] 90 | locations = torch.cat(repeated_locations, dim=0) 91 | gt_hms = _decompose_level(flattened_hms, shapes_per_level, N) 92 | masks = flattened_hms.new_zeros((flattened_hms.shape[0], 1)) 93 | masks[pos_inds] = 1 94 | masks = _decompose_level(masks, 
shapes_per_level, N) 95 | for i in range(len(images)): 96 | image = images[i].detach().cpu().numpy().transpose(1, 2, 0) 97 | color_maps = [] 98 | for l in range(len(gt_hms)): 99 | color_map = _get_color_image( 100 | gt_hms[l][i].detach().cpu().numpy()) 101 | color_maps.append(color_map) 102 | cv2.imshow('gthm_{}'.format(l), color_map) 103 | blend = _blend_image_heatmaps(image.copy(), color_maps) 104 | if gt_instances is not None: 105 | bboxes = gt_instances[i].gt_boxes.tensor 106 | for j in range(len(bboxes)): 107 | bbox = bboxes[j] 108 | cv2.rectangle( 109 | blend, 110 | (int(bbox[0]), int(bbox[1])), 111 | (int(bbox[2]), int(bbox[3])), 112 | (0, 0, 255), 3, cv2.LINE_AA) 113 | 114 | for j in range(len(pos_inds)): 115 | image_id, l = _ind2il(pos_inds[j], shapes_per_level, N) 116 | if image_id != i: 117 | continue 118 | loc = locations[pos_inds[j]] 119 | cv2.drawMarker( 120 | blend, (int(loc[0]), int(loc[1])), (0, 255, 255), 121 | markerSize=(l + 1) * 16) 122 | 123 | for j in range(len(reg_inds)): 124 | image_id, l = _ind2il(reg_inds[j], shapes_per_level, N) 125 | if image_id != i: 126 | continue 127 | ltrb = reg_targets[reg_inds[j]] 128 | ltrb *= strides[l] 129 | loc = locations[reg_inds[j]] 130 | bbox = [(loc[0] - ltrb[0]), (loc[1] - ltrb[1]), 131 | (loc[0] + ltrb[2]), (loc[1] + ltrb[3])] 132 | cv2.rectangle( 133 | blend, 134 | (int(bbox[0]), int(bbox[1])), 135 | (int(bbox[2]), int(bbox[3])), 136 | (255, 0, 0), 1, cv2.LINE_AA) 137 | cv2.circle(blend, (int(loc[0]), int(loc[1])), 2, (255, 0, 0), -1) 138 | 139 | cv2.imshow('blend', blend) 140 | cv2.waitKey() 141 | 142 | 143 | def debug_test( 144 | images, logits_pred, reg_pred, agn_hm_pred=[], preds=[], 145 | vis_thresh=0.3, debug_show_name=False, mult_agn=False): 146 | ''' 147 | images: N x 3 x H x W 148 | class_target: LNHiWi x C 149 | cat_agn_heatmap: LNHiWi 150 | shapes_per_level: L x 2 [(H_i, W_i)] 151 | ''' 152 | N = len(images) 153 | for i in range(len(images)): 154 | image = images[i].detach().cpu().numpy().transpose(1, 2, 0) 155 | result = image.copy().astype(np.uint8) 156 | pred_image = image.copy().astype(np.uint8) 157 | color_maps = [] 158 | L = len(logits_pred) 159 | for l in range(L): 160 | if logits_pred[0] is not None: 161 | stride = min(image.shape[0], image.shape[1]) / min( 162 | logits_pred[l][i].shape[1], logits_pred[l][i].shape[2]) 163 | else: 164 | stride = min(image.shape[0], image.shape[1]) / min( 165 | agn_hm_pred[l][i].shape[1], agn_hm_pred[l][i].shape[2]) 166 | stride = stride if stride < 60 else 64 if stride < 100 else 128 167 | if logits_pred[0] is not None: 168 | if mult_agn: 169 | logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i] 170 | color_map = _get_color_image( 171 | logits_pred[l][i].detach().cpu().numpy()) 172 | color_maps.append(color_map) 173 | cv2.imshow('predhm_{}'.format(l), color_map) 174 | 175 | if debug_show_name: 176 | from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES 177 | cat2name = [x['name'] for x in LVIS_CATEGORIES] 178 | for j in range(len(preds[i].scores) if preds is not None else 0): 179 | if preds[i].scores[j] > vis_thresh: 180 | bbox = preds[i].proposal_boxes[j] \ 181 | if preds[i].has('proposal_boxes') else \ 182 | preds[i].pred_boxes[j] 183 | bbox = bbox.tensor[0].detach().cpu().numpy().astype(np.int32) 184 | cat = int(preds[i].pred_classes[j]) \ 185 | if preds[i].has('pred_classes') else 0 186 | cl = COLORS[cat, 0, 0] 187 | cv2.rectangle( 188 | pred_image, (int(bbox[0]), int(bbox[1])), 189 | (int(bbox[2]), int(bbox[3])), 190 | (int(cl[0]), int(cl[1]), 
int(cl[2])), 2, cv2.LINE_AA) 191 | if debug_show_name: 192 | txt = '{}{:.1f}'.format( 193 | cat2name[cat] if cat > 0 else '', 194 | preds[i].scores[j]) 195 | font = cv2.FONT_HERSHEY_SIMPLEX 196 | cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] 197 | cv2.rectangle( 198 | pred_image, 199 | (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), 200 | (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), 201 | (int(cl[0]), int(cl[1]), int(cl[2])), -1) 202 | cv2.putText( 203 | pred_image, txt, (int(bbox[0]), int(bbox[1] - 2)), 204 | font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA) 205 | 206 | 207 | if agn_hm_pred[l] is not None: 208 | agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy() 209 | agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape( 210 | 1, 1, 3)).astype(np.uint8) 211 | cv2.imshow('agn_hm_{}'.format(l), agn_hm_) 212 | blend = _blend_image_heatmaps(image.copy(), color_maps) 213 | cv2.imshow('blend', blend) 214 | cv2.imshow('preds', pred_image) 215 | cv2.waitKey() 216 | 217 | global cnt 218 | cnt = 0 219 | 220 | def debug_second_stage(images, instances, proposals=None, vis_thresh=0.3, 221 | save_debug=False, debug_show_name=False, image_labels=[], 222 | save_debug_path='output/save_debug/', 223 | bgr=False): 224 | images = _imagelist_to_tensor(images) 225 | if 'COCO' in save_debug_path: 226 | from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES 227 | cat2name = [x['name'] for x in COCO_CATEGORIES] 228 | else: 229 | from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES 230 | cat2name = ['({}){}'.format(x['frequency'], x['name']) \ 231 | for x in LVIS_CATEGORIES] 232 | for i in range(len(images)): 233 | image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy() 234 | if bgr: 235 | image = image[:, :, ::-1].copy() 236 | if instances[i].has('gt_boxes'): 237 | bboxes = instances[i].gt_boxes.tensor.cpu().numpy() 238 | scores = np.ones(bboxes.shape[0]) 239 | cats = instances[i].gt_classes.cpu().numpy() 240 | else: 241 | bboxes = instances[i].pred_boxes.tensor.cpu().numpy() 242 | scores = instances[i].scores.cpu().numpy() 243 | cats = instances[i].pred_classes.cpu().numpy() 244 | for j in range(len(bboxes)): 245 | if scores[j] > vis_thresh: 246 | bbox = bboxes[j] 247 | cl = COLORS[cats[j], 0, 0] 248 | cl = (int(cl[0]), int(cl[1]), int(cl[2])) 249 | cv2.rectangle( 250 | image, 251 | (int(bbox[0]), int(bbox[1])), 252 | (int(bbox[2]), int(bbox[3])), 253 | cl, 2, cv2.LINE_AA) 254 | if debug_show_name: 255 | cat = cats[j] 256 | txt = '{}{:.1f}'.format( 257 | cat2name[cat] if cat > 0 else '', 258 | scores[j]) 259 | font = cv2.FONT_HERSHEY_SIMPLEX 260 | cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] 261 | cv2.rectangle( 262 | image, 263 | (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), 264 | (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), 265 | (int(cl[0]), int(cl[1]), int(cl[2])), -1) 266 | cv2.putText( 267 | image, txt, (int(bbox[0]), int(bbox[1] - 2)), 268 | font, 0.5, (0, 0, 0), thickness=1, lineType=cv2.LINE_AA) 269 | if proposals is not None: 270 | proposal_image = images[i].detach().cpu().numpy().transpose(1, 2, 0).astype(np.uint8).copy() 271 | if bgr: 272 | proposal_image = proposal_image.copy() 273 | else: 274 | proposal_image = proposal_image[:, :, ::-1].copy() 275 | bboxes = proposals[i].proposal_boxes.tensor.cpu().numpy() 276 | if proposals[i].has('scores'): 277 | scores = proposals[i].scores.detach().cpu().numpy() 278 | else: 279 | scores = proposals[i].objectness_logits.detach().cpu().numpy() 280 | # selected = -1 
281 | # if proposals[i].has('image_loss'): 282 | # selected = proposals[i].image_loss.argmin() 283 | if proposals[i].has('selected'): 284 | selected = proposals[i].selected 285 | else: 286 | selected = [-1 for _ in range(len(bboxes))] 287 | for j in range(len(bboxes)): 288 | if scores[j] > vis_thresh or selected[j] >= 0: 289 | bbox = bboxes[j] 290 | cl = (209, 159, 83) 291 | th = 2 292 | if selected[j] >= 0: 293 | cl = (0, 0, 0xa4) 294 | th = 4 295 | cv2.rectangle( 296 | proposal_image, 297 | (int(bbox[0]), int(bbox[1])), 298 | (int(bbox[2]), int(bbox[3])), 299 | cl, th, cv2.LINE_AA) 300 | if selected[j] >= 0 and debug_show_name: 301 | cat = selected[j].item() 302 | txt = '{}'.format(cat2name[cat]) 303 | font = cv2.FONT_HERSHEY_SIMPLEX 304 | cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] 305 | cv2.rectangle( 306 | proposal_image, 307 | (int(bbox[0]), int(bbox[1] - cat_size[1] - 2)), 308 | (int(bbox[0] + cat_size[0]), int(bbox[1] - 2)), 309 | (int(cl[0]), int(cl[1]), int(cl[2])), -1) 310 | cv2.putText( 311 | proposal_image, txt, 312 | (int(bbox[0]), int(bbox[1] - 2)), 313 | font, 0.5, (0, 0, 0), thickness=1, 314 | lineType=cv2.LINE_AA) 315 | 316 | if save_debug: 317 | global cnt 318 | cnt = (cnt + 1) % 5000 319 | if not os.path.exists(save_debug_path): 320 | os.mkdir(save_debug_path) 321 | save_name = '{}/{:05d}.jpg'.format(save_debug_path, cnt) 322 | if i < len(image_labels): 323 | image_label = image_labels[i] 324 | save_name = '{}/{:05d}'.format(save_debug_path, cnt) 325 | for x in image_label: 326 | class_name = cat2name[x] 327 | save_name = save_name + '|{}'.format(class_name) 328 | save_name = save_name + '.jpg' 329 | cv2.imwrite(save_name, proposal_image) 330 | else: 331 | cv2.imshow('image', image) 332 | if proposals is not None: 333 | cv2.imshow('proposals', proposal_image) 334 | cv2.waitKey() -------------------------------------------------------------------------------- /mmovod/modeling/meta_arch/__pycache__/custom_rcnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/meta_arch/__pycache__/custom_rcnn.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/meta_arch/__pycache__/d2_deformable_detr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/meta_arch/__pycache__/d2_deformable_detr.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/meta_arch/custom_rcnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
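# NOTE: CustomRCNN below extends GeneralizedRCNN so that a training batch can
# come either from a box-annotated dataset or from an image-labelled one
# (ann_type != 'box'), in which case the proposal losses are zeroed out.
# Illustrative config flags (names are the ones read in from_config below;
# values are examples, not the repository defaults):
#
#   cfg.MODEL.META_ARCHITECTURE = 'CustomRCNN'
#   cfg.WITH_IMAGE_LABELS = True            # allow image-label batches
#   cfg.MODEL.DATASET_LOSS_WEIGHT = []      # optional per-source loss scaling
#   cfg.FP16 = True                         # run the backbone under autocast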
2 | import copy 3 | import logging 4 | import numpy as np 5 | from typing import Dict, List, Optional, Tuple 6 | import torch 7 | from torch import nn 8 | import json 9 | from detectron2.utils.events import get_event_storage 10 | from detectron2.config import configurable 11 | from detectron2.structures import ImageList, Instances, Boxes 12 | import detectron2.utils.comm as comm 13 | 14 | from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY 15 | from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN 16 | from detectron2.modeling.postprocessing import detector_postprocess 17 | from detectron2.utils.visualizer import Visualizer, _create_text_labels 18 | from detectron2.data.detection_utils import convert_image_to_rgb 19 | 20 | from torch.cuda.amp import autocast 21 | from ..text.text_encoder import build_text_encoder 22 | from ..utils import load_class_freq, get_fed_loss_inds 23 | 24 | @META_ARCH_REGISTRY.register() 25 | class CustomRCNN(GeneralizedRCNN): 26 | ''' 27 | Add image labels 28 | ''' 29 | @configurable 30 | def __init__( 31 | self, 32 | with_image_labels = False, 33 | dataset_loss_weight = [], 34 | fp16 = False, 35 | sync_caption_batch = False, 36 | roi_head_name = '', 37 | cap_batch_ratio = 4, 38 | with_caption = False, 39 | dynamic_classifier = False, 40 | **kwargs): 41 | """ 42 | """ 43 | self.with_image_labels = with_image_labels 44 | self.dataset_loss_weight = dataset_loss_weight 45 | self.fp16 = fp16 46 | self.with_caption = with_caption 47 | self.sync_caption_batch = sync_caption_batch 48 | self.roi_head_name = roi_head_name 49 | self.cap_batch_ratio = cap_batch_ratio 50 | self.dynamic_classifier = dynamic_classifier 51 | self.return_proposal = False 52 | if self.dynamic_classifier: 53 | self.freq_weight = kwargs.pop('freq_weight') 54 | self.num_classes = kwargs.pop('num_classes') 55 | self.num_sample_cats = kwargs.pop('num_sample_cats') 56 | super().__init__(**kwargs) 57 | assert self.proposal_generator is not None 58 | if self.with_caption: 59 | assert not self.dynamic_classifier 60 | self.text_encoder = build_text_encoder(pretrain=True) 61 | for v in self.text_encoder.parameters(): 62 | v.requires_grad = False 63 | 64 | 65 | @classmethod 66 | def from_config(cls, cfg): 67 | ret = super().from_config(cfg) 68 | ret.update({ 69 | 'with_image_labels': cfg.WITH_IMAGE_LABELS, 70 | 'dataset_loss_weight': cfg.MODEL.DATASET_LOSS_WEIGHT, 71 | 'fp16': cfg.FP16, 72 | 'with_caption': cfg.MODEL.WITH_CAPTION, 73 | 'sync_caption_batch': cfg.MODEL.SYNC_CAPTION_BATCH, 74 | 'dynamic_classifier': cfg.MODEL.DYNAMIC_CLASSIFIER, 75 | 'roi_head_name': cfg.MODEL.ROI_HEADS.NAME, 76 | 'cap_batch_ratio': cfg.MODEL.CAP_BATCH_RATIO, 77 | }) 78 | if ret['dynamic_classifier']: 79 | ret['freq_weight'] = load_class_freq( 80 | cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH, 81 | cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT) 82 | ret['num_classes'] = cfg.MODEL.ROI_HEADS.NUM_CLASSES 83 | ret['num_sample_cats'] = cfg.MODEL.NUM_SAMPLE_CATS 84 | return ret 85 | 86 | 87 | def inference( 88 | self, 89 | batched_inputs: Tuple[Dict[str, torch.Tensor]], 90 | detected_instances: Optional[List[Instances]] = None, 91 | do_postprocess: bool = True, 92 | ): 93 | assert not self.training 94 | assert detected_instances is None 95 | 96 | images = self.preprocess_image(batched_inputs) 97 | features = self.backbone(images.tensor) 98 | proposals, _ = self.proposal_generator(images, features, None) 99 | results, _ = self.roi_heads(images, features, proposals) 100 | if do_postprocess: 101 | assert not 
torch.jit.is_scripting(), \ 102 | "Scripting is not supported for postprocess." 103 | return CustomRCNN._postprocess( 104 | results, batched_inputs, images.image_sizes) 105 | else: 106 | return results 107 | 108 | 109 | def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): 110 | """ 111 | Add ann_type 112 | Ignore proposal loss when training with image labels 113 | """ 114 | if not self.training: 115 | return self.inference(batched_inputs) 116 | 117 | images = self.preprocess_image(batched_inputs) 118 | 119 | ann_type = 'box' 120 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 121 | if self.with_image_labels: 122 | for inst, x in zip(gt_instances, batched_inputs): 123 | inst._ann_type = x['ann_type'] 124 | inst._pos_category_ids = x['pos_category_ids'] 125 | ann_types = [x['ann_type'] for x in batched_inputs] 126 | assert len(set(ann_types)) == 1 127 | ann_type = ann_types[0] 128 | if ann_type in ['prop', 'proptag']: 129 | for t in gt_instances: 130 | t.gt_classes *= 0 131 | 132 | if self.fp16: # TODO (zhouxy): improve 133 | with autocast(): 134 | features = self.backbone(images.tensor.half()) 135 | features = {k: v.float() for k, v in features.items()} 136 | else: 137 | features = self.backbone(images.tensor) 138 | 139 | cls_features, cls_inds, caption_features = None, None, None 140 | 141 | if self.with_caption and 'caption' in ann_type: 142 | inds = [torch.randint(len(x['captions']), (1,))[0].item() \ 143 | for x in batched_inputs] 144 | caps = [x['captions'][ind] for ind, x in zip(inds, batched_inputs)] 145 | caption_features = self.text_encoder(caps).float() 146 | if self.sync_caption_batch: 147 | caption_features = self._sync_caption_features( 148 | caption_features, ann_type, len(batched_inputs)) 149 | 150 | if self.dynamic_classifier and ann_type != 'caption': 151 | cls_inds = self._sample_cls_inds(gt_instances, ann_type) # inds, inv_inds 152 | ind_with_bg = cls_inds[0].tolist() + [-1] 153 | cls_features = self.roi_heads.box_predictor[ 154 | 0].cls_score.zs_weight[:, ind_with_bg].permute(1, 0).contiguous() 155 | 156 | classifier_info = cls_features, cls_inds, caption_features 157 | proposals, proposal_losses = self.proposal_generator( 158 | images, features, gt_instances) 159 | 160 | if self.roi_head_name in ['StandardROIHeads', 'CascadeROIHeads']: 161 | proposals, detector_losses = self.roi_heads( 162 | images, features, proposals, gt_instances) 163 | else: 164 | proposals, detector_losses = self.roi_heads( 165 | images, features, proposals, gt_instances, 166 | ann_type=ann_type, classifier_info=classifier_info) 167 | 168 | if self.vis_period > 0: 169 | storage = get_event_storage() 170 | if storage.iter % self.vis_period == 0: 171 | self.visualize_training(batched_inputs, proposals) 172 | 173 | losses = {} 174 | losses.update(detector_losses) 175 | if self.with_image_labels: 176 | if ann_type in ['box', 'prop', 'proptag']: 177 | losses.update(proposal_losses) 178 | else: # ignore proposal loss for non-bbox data 179 | losses.update({k: v * 0 for k, v in proposal_losses.items()}) 180 | else: 181 | losses.update(proposal_losses) 182 | if len(self.dataset_loss_weight) > 0: 183 | dataset_sources = [x['dataset_source'] for x in batched_inputs] 184 | assert len(set(dataset_sources)) == 1 185 | dataset_source = dataset_sources[0] 186 | for k in losses: 187 | losses[k] *= self.dataset_loss_weight[dataset_source] 188 | 189 | if self.return_proposal: 190 | return proposals, losses 191 | else: 192 | return losses 193 | 194 | 195 | def 
_sync_caption_features(self, caption_features, ann_type, BS): 196 | has_caption_feature = (caption_features is not None) 197 | BS = (BS * self.cap_batch_ratio) if (ann_type == 'box') else BS 198 | rank = torch.full( 199 | (BS, 1), comm.get_rank(), dtype=torch.float32, 200 | device=self.device) 201 | if not has_caption_feature: 202 | caption_features = rank.new_zeros((BS, 512)) 203 | caption_features = torch.cat([caption_features, rank], dim=1) 204 | global_caption_features = comm.all_gather(caption_features) 205 | caption_features = torch.cat( 206 | [x.to(self.device) for x in global_caption_features], dim=0) \ 207 | if has_caption_feature else None # (NB) x (D + 1) 208 | return caption_features 209 | 210 | 211 | def _sample_cls_inds(self, gt_instances, ann_type='box'): 212 | if ann_type == 'box': 213 | gt_classes = torch.cat( 214 | [x.gt_classes for x in gt_instances]) 215 | C = len(self.freq_weight) 216 | freq_weight = self.freq_weight 217 | else: 218 | gt_classes = torch.cat( 219 | [torch.tensor( 220 | x._pos_category_ids, 221 | dtype=torch.long, device=x.gt_classes.device) \ 222 | for x in gt_instances]) 223 | C = self.num_classes 224 | freq_weight = None 225 | assert gt_classes.max() < C, '{} {}'.format(gt_classes.max(), C) 226 | inds = get_fed_loss_inds( 227 | gt_classes, self.num_sample_cats, C, 228 | weight=freq_weight) 229 | cls_id_map = gt_classes.new_full( 230 | (self.num_classes + 1,), len(inds)) 231 | cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device) 232 | return inds, cls_id_map -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/__pycache__/detic_fast_rcnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/detic_fast_rcnn.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/__pycache__/detic_roi_heads.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/detic_roi_heads.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/__pycache__/res5_roi_heads.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/res5_roi_heads.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/__pycache__/zero_shot_classifier.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/roi_heads/__pycache__/zero_shot_classifier.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/detic_roi_heads.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
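# NOTE: DeticCascadeROIHeads below is a CascadeROIHeads variant whose box
# predictors (DeticFastRCNNOutputLayers) can be trained either from boxes or
# from image-level labels.  Illustrative config keys (names as read in
# from_config below; values are examples, not the repository defaults):
#
#   cfg.MODEL.ROI_HEADS.NAME = 'DeticCascadeROIHeads'
#   cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS = 128   # proposals used for image-label losses
#   cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE = False
#   cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False
#   cfg.MODEL.ROI_HEADS.MASK_WEIGHT = 1.0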
2 | import copy 3 | import numpy as np 4 | import json 5 | import math 6 | import torch 7 | from torch import nn 8 | from torch.autograd.function import Function 9 | from typing import Dict, List, Optional, Tuple, Union 10 | from torch.nn import functional as F 11 | 12 | from detectron2.config import configurable 13 | from detectron2.layers import ShapeSpec 14 | from detectron2.layers import batched_nms 15 | from detectron2.structures import Boxes, Instances, pairwise_iou 16 | from detectron2.utils.events import get_event_storage 17 | 18 | from detectron2.modeling.box_regression import Box2BoxTransform 19 | from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference 20 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads 21 | from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient 22 | from detectron2.modeling.roi_heads.box_head import build_box_head 23 | from .detic_fast_rcnn import DeticFastRCNNOutputLayers 24 | from ..debug import debug_second_stage 25 | 26 | from torch.cuda.amp import autocast 27 | 28 | @ROI_HEADS_REGISTRY.register() 29 | class DeticCascadeROIHeads(CascadeROIHeads): 30 | @configurable 31 | def __init__( 32 | self, 33 | *, 34 | mult_proposal_score: bool = False, 35 | with_image_labels: bool = False, 36 | add_image_box: bool = False, 37 | image_box_size: float = 1.0, 38 | ws_num_props: int = 512, 39 | add_feature_to_prop: bool = False, 40 | mask_weight: float = 1.0, 41 | one_class_per_proposal: bool = False, 42 | **kwargs, 43 | ): 44 | super().__init__(**kwargs) 45 | self.mult_proposal_score = mult_proposal_score 46 | self.with_image_labels = with_image_labels 47 | self.add_image_box = add_image_box 48 | self.image_box_size = image_box_size 49 | self.ws_num_props = ws_num_props 50 | self.add_feature_to_prop = add_feature_to_prop 51 | self.mask_weight = mask_weight 52 | self.one_class_per_proposal = one_class_per_proposal 53 | 54 | @classmethod 55 | def from_config(cls, cfg, input_shape): 56 | ret = super().from_config(cfg, input_shape) 57 | ret.update({ 58 | 'mult_proposal_score': cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE, 59 | 'with_image_labels': cfg.WITH_IMAGE_LABELS, 60 | 'add_image_box': cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX, 61 | 'image_box_size': cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE, 62 | 'ws_num_props': cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS, 63 | 'add_feature_to_prop': cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP, 64 | 'mask_weight': cfg.MODEL.ROI_HEADS.MASK_WEIGHT, 65 | 'one_class_per_proposal': cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL, 66 | }) 67 | return ret 68 | 69 | 70 | @classmethod 71 | def _init_box_head(self, cfg, input_shape): 72 | ret = super()._init_box_head(cfg, input_shape) 73 | del ret['box_predictors'] 74 | cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS 75 | box_predictors = [] 76 | for box_head, bbox_reg_weights in zip(ret['box_heads'], \ 77 | cascade_bbox_reg_weights): 78 | box_predictors.append( 79 | DeticFastRCNNOutputLayers( 80 | cfg, box_head.output_shape, 81 | box2box_transform=Box2BoxTransform(weights=bbox_reg_weights) 82 | )) 83 | ret['box_predictors'] = box_predictors 84 | return ret 85 | 86 | def _forward_box(self, features, proposals, targets=None, 87 | ann_type='box', classifier_info=(None,None,None)): 88 | """ 89 | Add mult proposal scores at testing 90 | Add ann_type 91 | """ 92 | if (not self.training) and self.mult_proposal_score: 93 | if len(proposals) > 0 and proposals[0].has('scores'): 94 | proposal_scores = 
[p.get('scores') for p in proposals] 95 | else: 96 | proposal_scores = [p.get('objectness_logits') for p in proposals] 97 | 98 | features = [features[f] for f in self.box_in_features] 99 | head_outputs = [] # (predictor, predictions, proposals) 100 | prev_pred_boxes = None 101 | image_sizes = [x.image_size for x in proposals] 102 | 103 | for k in range(self.num_cascade_stages): 104 | if k > 0: 105 | proposals = self._create_proposals_from_boxes( 106 | prev_pred_boxes, image_sizes, 107 | logits=[p.objectness_logits for p in proposals]) 108 | if self.training and ann_type in ['box']: 109 | proposals = self._match_and_label_boxes( 110 | proposals, k, targets) 111 | predictions = self._run_stage(features, proposals, k, 112 | classifier_info=classifier_info) 113 | prev_pred_boxes = self.box_predictor[k].predict_boxes( 114 | (predictions[0], predictions[1]), proposals) 115 | head_outputs.append((self.box_predictor[k], predictions, proposals)) 116 | 117 | if self.training: 118 | losses = {} 119 | storage = get_event_storage() 120 | for stage, (predictor, predictions, proposals) in enumerate(head_outputs): 121 | with storage.name_scope("stage{}".format(stage)): 122 | if ann_type != 'box': 123 | stage_losses = {} 124 | if ann_type in ['image', 'caption', 'captiontag']: 125 | image_labels = [x._pos_category_ids for x in targets] 126 | # import pdb; pdb.set_trace() 127 | weak_losses = predictor.image_label_losses( 128 | predictions, proposals, image_labels, 129 | classifier_info=classifier_info, 130 | ann_type=ann_type) 131 | stage_losses.update(weak_losses) 132 | else: # supervised 133 | stage_losses = predictor.losses( 134 | (predictions[0], predictions[1]), proposals, 135 | classifier_info=classifier_info) 136 | if self.with_image_labels: 137 | stage_losses['image_loss'] = \ 138 | predictions[0].new_zeros([1])[0] 139 | losses.update({k + "_stage{}".format(stage): v \ 140 | for k, v in stage_losses.items()}) 141 | return losses 142 | else: 143 | # Each is a list[Tensor] of length #image. 
Each tensor is Ri x (K+1) 144 | scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] 145 | scores = [ 146 | sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) 147 | for scores_per_image in zip(*scores_per_stage) 148 | ] 149 | if self.mult_proposal_score: 150 | scores = [(s * ps[:, None]) ** 0.5 \ 151 | for s, ps in zip(scores, proposal_scores)] 152 | if self.one_class_per_proposal: 153 | scores = [s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() for s in scores] 154 | predictor, predictions, proposals = head_outputs[-1] 155 | boxes = predictor.predict_boxes( 156 | (predictions[0], predictions[1]), proposals) 157 | pred_instances, _ = fast_rcnn_inference( 158 | boxes, 159 | scores, 160 | image_sizes, 161 | predictor.test_score_thresh, 162 | predictor.test_nms_thresh, 163 | predictor.test_topk_per_image, 164 | ) 165 | return pred_instances 166 | 167 | def forward(self, images, features, proposals, targets=None, 168 | ann_type='box', classifier_info=(None,None,None)): 169 | ''' 170 | enable debug and image labels 171 | classifier_info is shared across the batch 172 | ''' 173 | if self.training: 174 | if ann_type in ['box', 'prop', 'proptag']: 175 | proposals = self.label_and_sample_proposals( 176 | proposals, targets) 177 | else: 178 | proposals = self.get_top_proposals(proposals) 179 | 180 | losses = self._forward_box(features, proposals, targets, \ 181 | ann_type=ann_type, classifier_info=classifier_info) 182 | if ann_type == 'box' and targets[0].has('gt_masks'): 183 | mask_losses = self._forward_mask(features, proposals) 184 | losses.update({k: v * self.mask_weight \ 185 | for k, v in mask_losses.items()}) 186 | losses.update(self._forward_keypoint(features, proposals)) 187 | else: 188 | losses.update(self._get_empty_mask_loss( 189 | features, proposals, 190 | device=proposals[0].objectness_logits.device)) 191 | return proposals, losses 192 | else: 193 | pred_instances = self._forward_box( 194 | features, proposals, classifier_info=classifier_info) 195 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 196 | return pred_instances, {} 197 | 198 | 199 | def get_top_proposals(self, proposals): 200 | for i in range(len(proposals)): 201 | proposals[i].proposal_boxes.clip(proposals[i].image_size) 202 | proposals = [p[:self.ws_num_props] for p in proposals] 203 | for i, p in enumerate(proposals): 204 | p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach() 205 | if self.add_image_box: 206 | proposals[i] = self._add_image_box(p) 207 | return proposals 208 | 209 | 210 | def _add_image_box(self, p): 211 | image_box = Instances(p.image_size) 212 | n = 1 213 | h, w = p.image_size 214 | f = self.image_box_size 215 | image_box.proposal_boxes = Boxes( 216 | p.proposal_boxes.tensor.new_tensor( 217 | [w * (1. - f) / 2., 218 | h * (1. - f) / 2., 219 | w * (1. - (1. - f) / 2.), 220 | h * (1. - (1. 
- f) / 2.)] 221 | ).view(n, 4)) 222 | image_box.objectness_logits = p.objectness_logits.new_ones(n) 223 | return Instances.cat([p, image_box]) 224 | 225 | 226 | def _get_empty_mask_loss(self, features, proposals, device): 227 | if self.mask_on: 228 | return {'loss_mask': torch.zeros( 229 | (1, ), device=device, dtype=torch.float32)[0]} 230 | else: 231 | return {} 232 | 233 | 234 | def _create_proposals_from_boxes(self, boxes, image_sizes, logits): 235 | """ 236 | Add objectness_logits 237 | """ 238 | boxes = [Boxes(b.detach()) for b in boxes] 239 | proposals = [] 240 | for boxes_per_image, image_size, logit in zip( 241 | boxes, image_sizes, logits): 242 | boxes_per_image.clip(image_size) 243 | if self.training: 244 | inds = boxes_per_image.nonempty() 245 | boxes_per_image = boxes_per_image[inds] 246 | logit = logit[inds] 247 | prop = Instances(image_size) 248 | prop.proposal_boxes = boxes_per_image 249 | prop.objectness_logits = logit 250 | proposals.append(prop) 251 | return proposals 252 | 253 | 254 | def _run_stage(self, features, proposals, stage, \ 255 | classifier_info=(None,None,None)): 256 | """ 257 | Support classifier_info and add_feature_to_prop 258 | """ 259 | pool_boxes = [x.proposal_boxes for x in proposals] 260 | box_features = self.box_pooler(features, pool_boxes) 261 | box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages) 262 | box_features = self.box_head[stage](box_features) 263 | if self.add_feature_to_prop: 264 | feats_per_image = box_features.split( 265 | [len(p) for p in proposals], dim=0) 266 | for feat, p in zip(feats_per_image, proposals): 267 | p.feat = feat 268 | return self.box_predictor[stage]( 269 | box_features, 270 | classifier_info=classifier_info) 271 | -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/res5_roi_heads.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
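# Note (illustrative sketch of the control flow below): CustomRes5ROIHeads routes
# training batches by annotation type. Fully-labelled data (ann_type == 'box') goes
# through the usual label-and-sample step plus detection losses; image-labelled data
# keeps only the top `ws_num_props` proposals (plus an optional whole-image box) and
# is trained with image-label losses, roughly:
#
#   if ann_type != 'box':
#       losses = box_predictor.image_label_losses(
#           predictions, proposals, image_labels,
#           classifier_info=classifier_info, ann_type=ann_type)
#   else:
#       losses = box_predictor.losses((scores, proposal_deltas), proposals)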
2 | import inspect 3 | import logging 4 | import numpy as np 5 | from typing import Dict, List, Optional, Tuple 6 | import torch 7 | from torch import nn 8 | 9 | from detectron2.config import configurable 10 | from detectron2.layers import ShapeSpec, nonzero_tuple 11 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou 12 | from detectron2.utils.events import get_event_storage 13 | from detectron2.utils.registry import Registry 14 | 15 | from detectron2.modeling.box_regression import Box2BoxTransform 16 | from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference 17 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 18 | from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient 19 | from detectron2.modeling.roi_heads.box_head import build_box_head 20 | 21 | from .detic_fast_rcnn import DeticFastRCNNOutputLayers 22 | from ..debug import debug_second_stage 23 | 24 | from torch.cuda.amp import autocast 25 | 26 | @ROI_HEADS_REGISTRY.register() 27 | class CustomRes5ROIHeads(Res5ROIHeads): 28 | @configurable 29 | def __init__(self, **kwargs): 30 | cfg = kwargs.pop('cfg') 31 | super().__init__(**kwargs) 32 | stage_channel_factor = 2 ** 3 33 | out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor 34 | 35 | self.with_image_labels = cfg.WITH_IMAGE_LABELS 36 | self.ws_num_props = cfg.MODEL.ROI_BOX_HEAD.WS_NUM_PROPS 37 | self.add_image_box = cfg.MODEL.ROI_BOX_HEAD.ADD_IMAGE_BOX 38 | self.add_feature_to_prop = cfg.MODEL.ROI_BOX_HEAD.ADD_FEATURE_TO_PROP 39 | self.image_box_size = cfg.MODEL.ROI_BOX_HEAD.IMAGE_BOX_SIZE 40 | self.box_predictor = DeticFastRCNNOutputLayers( 41 | cfg, ShapeSpec(channels=out_channels, height=1, width=1) 42 | ) 43 | 44 | self.save_debug = cfg.SAVE_DEBUG 45 | self.save_debug_path = cfg.SAVE_DEBUG_PATH 46 | if self.save_debug: 47 | self.debug_show_name = cfg.DEBUG_SHOW_NAME 48 | self.vis_thresh = cfg.VIS_THRESH 49 | self.pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to( 50 | torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) 51 | self.pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to( 52 | torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) 53 | self.bgr = (cfg.INPUT.FORMAT == 'BGR') 54 | 55 | @classmethod 56 | def from_config(cls, cfg, input_shape): 57 | ret = super().from_config(cfg, input_shape) 58 | ret['cfg'] = cfg 59 | return ret 60 | 61 | def forward(self, images, features, proposals, targets=None, 62 | ann_type='box', classifier_info=(None,None,None)): 63 | ''' 64 | enable debug and image labels 65 | classifier_info is shared across the batch 66 | ''' 67 | if not self.save_debug: 68 | del images 69 | 70 | if self.training: 71 | if ann_type in ['box']: 72 | proposals = self.label_and_sample_proposals( 73 | proposals, targets) 74 | else: 75 | proposals = self.get_top_proposals(proposals) 76 | 77 | proposal_boxes = [x.proposal_boxes for x in proposals] 78 | box_features = self._shared_roi_transform( 79 | [features[f] for f in self.in_features], proposal_boxes 80 | ) 81 | predictions = self.box_predictor( 82 | box_features.mean(dim=[2, 3]), 83 | classifier_info=classifier_info) 84 | 85 | if self.add_feature_to_prop: 86 | feats_per_image = box_features.mean(dim=[2, 3]).split( 87 | [len(p) for p in proposals], dim=0) 88 | for feat, p in zip(feats_per_image, proposals): 89 | p.feat = feat 90 | 91 | if self.training: 92 | del features 93 | if (ann_type != 'box'): 94 | image_labels = [x._pos_category_ids for x in targets] 95 | losses = 
self.box_predictor.image_label_losses( 96 | predictions, proposals, image_labels, 97 | classifier_info=classifier_info, 98 | ann_type=ann_type) 99 | else: 100 | losses = self.box_predictor.losses( 101 | (predictions[0], predictions[1]), proposals) 102 | if self.with_image_labels: 103 | assert 'image_loss' not in losses 104 | losses['image_loss'] = predictions[0].new_zeros([1])[0] 105 | if self.save_debug: 106 | denormalizer = lambda x: x * self.pixel_std + self.pixel_mean 107 | if ann_type != 'box': 108 | image_labels = [x._pos_category_ids for x in targets] 109 | else: 110 | image_labels = [[] for x in targets] 111 | debug_second_stage( 112 | [denormalizer(x.clone()) for x in images], 113 | targets, proposals=proposals, 114 | save_debug=self.save_debug, 115 | debug_show_name=self.debug_show_name, 116 | vis_thresh=self.vis_thresh, 117 | image_labels=image_labels, 118 | save_debug_path=self.save_debug_path, 119 | bgr=self.bgr) 120 | return proposals, losses 121 | else: 122 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 123 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 124 | if self.save_debug: 125 | denormalizer = lambda x: x * self.pixel_std + self.pixel_mean 126 | debug_second_stage( 127 | [denormalizer(x.clone()) for x in images], 128 | pred_instances, proposals=proposals, 129 | save_debug=self.save_debug, 130 | debug_show_name=self.debug_show_name, 131 | vis_thresh=self.vis_thresh, 132 | save_debug_path=self.save_debug_path, 133 | bgr=self.bgr) 134 | return pred_instances, {} 135 | 136 | def get_top_proposals(self, proposals): 137 | for i in range(len(proposals)): 138 | proposals[i].proposal_boxes.clip(proposals[i].image_size) 139 | proposals = [p[:self.ws_num_props] for p in proposals] 140 | for i, p in enumerate(proposals): 141 | p.proposal_boxes.tensor = p.proposal_boxes.tensor.detach() 142 | if self.add_image_box: 143 | proposals[i] = self._add_image_box(p) 144 | return proposals 145 | 146 | def _add_image_box(self, p, use_score=False): 147 | image_box = Instances(p.image_size) 148 | n = 1 149 | h, w = p.image_size 150 | if self.image_box_size < 1.0: 151 | f = self.image_box_size 152 | image_box.proposal_boxes = Boxes( 153 | p.proposal_boxes.tensor.new_tensor( 154 | [w * (1. - f) / 2., 155 | h * (1. - f) / 2., 156 | w * (1. - (1. - f) / 2.), 157 | h * (1. - (1. - f) / 2.)] 158 | ).view(n, 4)) 159 | else: 160 | image_box.proposal_boxes = Boxes( 161 | p.proposal_boxes.tensor.new_tensor( 162 | [0, 0, w, h]).view(n, 4)) 163 | if use_score: 164 | image_box.scores = \ 165 | p.objectness_logits.new_ones(n) 166 | image_box.pred_classes = \ 167 | p.objectness_logits.new_zeros(n, dtype=torch.long) 168 | image_box.objectness_logits = \ 169 | p.objectness_logits.new_ones(n) 170 | else: 171 | image_box.objectness_logits = \ 172 | p.objectness_logits.new_ones(n) 173 | return Instances.cat([p, image_box]) -------------------------------------------------------------------------------- /mmovod/modeling/roi_heads/zero_shot_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
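# Shape sketch (illustrative): ZeroShotClassifier projects pooled box features to the
# CLIP embedding dimension and scores them against a fixed D x (C + 1) class-embedding
# matrix loaded from ZEROSHOT_WEIGHT_PATH (the appended last column is the background
# class). With weight normalisation enabled, the logits are temperature-scaled cosine
# similarities, roughly:
#
#   x = self.linear(x)                                 # B x D
#   x = self.norm_temperature * F.normalize(x, dim=1)
#   logits = x @ zs_weight                             # B x (C + 1)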
2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from detectron2.config import configurable 7 | from detectron2.layers import ShapeSpec 8 | 9 | 10 | class ZeroShotClassifier(nn.Module): 11 | @configurable 12 | def __init__( 13 | self, 14 | input_shape: ShapeSpec, 15 | *, 16 | num_classes: int, 17 | zs_weight_path: str, 18 | zs_weight_dim: int = 512, 19 | use_bias: float = 0.0, 20 | norm_weight: bool = True, 21 | norm_temperature: float = 50.0, 22 | ): 23 | super().__init__() 24 | if isinstance(input_shape, int): # some backward compatibility 25 | input_shape = ShapeSpec(channels=input_shape) 26 | input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1) 27 | self.norm_weight = norm_weight 28 | self.norm_temperature = norm_temperature 29 | 30 | self.use_bias = use_bias < 0 31 | if self.use_bias: 32 | self.cls_bias = nn.Parameter(torch.ones(1) * use_bias) 33 | 34 | self.linear = nn.Linear(input_size, zs_weight_dim) 35 | 36 | if zs_weight_path == 'rand': 37 | zs_weight = torch.randn((zs_weight_dim, num_classes)) 38 | nn.init.normal_(zs_weight, std=0.01) 39 | else: 40 | zs_weight = torch.tensor( 41 | np.load(zs_weight_path), 42 | dtype=torch.float32).permute(1, 0).contiguous() # D x C 43 | zs_weight = torch.cat( 44 | [zs_weight, zs_weight.new_zeros((zs_weight_dim, 1))], 45 | dim=1) # D x (C + 1) 46 | 47 | if self.norm_weight: 48 | zs_weight = F.normalize(zs_weight, p=2, dim=0) 49 | 50 | if zs_weight_path == 'rand': 51 | self.zs_weight = nn.Parameter(zs_weight) 52 | else: 53 | self.register_buffer('zs_weight', zs_weight) 54 | 55 | assert self.zs_weight.shape[1] == num_classes + 1, self.zs_weight.shape 56 | 57 | @classmethod 58 | def from_config(cls, cfg, input_shape): 59 | return { 60 | 'input_shape': input_shape, 61 | 'num_classes': cfg.MODEL.ROI_HEADS.NUM_CLASSES, 62 | 'zs_weight_path': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH, 63 | 'zs_weight_dim': cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_DIM, 64 | 'use_bias': cfg.MODEL.ROI_BOX_HEAD.USE_BIAS, 65 | 'norm_weight': cfg.MODEL.ROI_BOX_HEAD.NORM_WEIGHT, 66 | 'norm_temperature': cfg.MODEL.ROI_BOX_HEAD.NORM_TEMP, 67 | } 68 | 69 | def forward(self, x, classifier=None): 70 | ''' 71 | Inputs: 72 | x: B x D' 73 | classifier_info: (C', C' x D) 74 | ''' 75 | x = self.linear(x) 76 | if classifier is not None: 77 | zs_weight = classifier.permute(1, 0).contiguous() # D x C' 78 | zs_weight = F.normalize(zs_weight, p=2, dim=0) \ 79 | if self.norm_weight else zs_weight 80 | else: 81 | zs_weight = self.zs_weight 82 | if self.norm_weight: 83 | x = self.norm_temperature * F.normalize(x, p=2, dim=1) 84 | x = torch.mm(x, zs_weight) 85 | if self.use_bias: 86 | x = x + self.cls_bias 87 | return x 88 | -------------------------------------------------------------------------------- /mmovod/modeling/text/__pycache__/text_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prannaykaul/mm-ovod/c5cabd5f2c80463b1c9f87aaca0a49c189af2202/mmovod/modeling/text/__pycache__/text_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /mmovod/modeling/text/text_encoder.py: -------------------------------------------------------------------------------- 1 | # This code is modified from https://github.com/openai/CLIP/blob/main/clip/clip.py 2 | # Modified by Xingyi Zhou 3 | # The original code is under MIT license 4 | # Copyright (c) Facebook, Inc. 
and its affiliates. 5 | from typing import Union, List 6 | from collections import OrderedDict 7 | import torch 8 | from torch import nn 9 | import torch 10 | 11 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 12 | 13 | __all__ = ["tokenize"] 14 | 15 | count = 0 16 | 17 | class LayerNorm(nn.LayerNorm): 18 | """Subclass torch's LayerNorm to handle fp16.""" 19 | 20 | def forward(self, x: torch.Tensor): 21 | orig_type = x.dtype 22 | ret = super().forward(x.type(torch.float32)) 23 | return ret.type(orig_type) 24 | 25 | 26 | class QuickGELU(nn.Module): 27 | def forward(self, x: torch.Tensor): 28 | return x * torch.sigmoid(1.702 * x) 29 | 30 | 31 | class ResidualAttentionBlock(nn.Module): 32 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 33 | super().__init__() 34 | 35 | self.attn = nn.MultiheadAttention(d_model, n_head) 36 | self.ln_1 = LayerNorm(d_model) 37 | self.mlp = nn.Sequential(OrderedDict([ 38 | ("c_fc", nn.Linear(d_model, d_model * 4)), 39 | ("gelu", QuickGELU()), 40 | ("c_proj", nn.Linear(d_model * 4, d_model)) 41 | ])) 42 | self.ln_2 = LayerNorm(d_model) 43 | self.attn_mask = attn_mask 44 | 45 | def attention(self, x: torch.Tensor): 46 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 47 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 48 | 49 | def forward(self, x: torch.Tensor): 50 | x = x + self.attention(self.ln_1(x)) 51 | x = x + self.mlp(self.ln_2(x)) 52 | return x 53 | 54 | 55 | class Transformer(nn.Module): 56 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 57 | super().__init__() 58 | self.width = width 59 | self.layers = layers 60 | self.resblocks = nn.Sequential( 61 | *[ResidualAttentionBlock(width, heads, attn_mask) \ 62 | for _ in range(layers)]) 63 | 64 | def forward(self, x: torch.Tensor): 65 | return self.resblocks(x) 66 | 67 | class CLIPTEXT(nn.Module): 68 | def __init__(self, 69 | embed_dim=512, 70 | # text 71 | context_length=77, 72 | vocab_size=49408, 73 | transformer_width=512, 74 | transformer_heads=8, 75 | transformer_layers=12 76 | ): 77 | super().__init__() 78 | 79 | self._tokenizer = _Tokenizer() 80 | self.context_length = context_length 81 | 82 | self.transformer = Transformer( 83 | width=transformer_width, 84 | layers=transformer_layers, 85 | heads=transformer_heads, 86 | attn_mask=self.build_attention_mask() 87 | ) 88 | 89 | self.vocab_size = vocab_size 90 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 91 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 92 | self.ln_final = LayerNorm(transformer_width) 93 | 94 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 95 | # self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 96 | 97 | self.initialize_parameters() 98 | 99 | def initialize_parameters(self): 100 | nn.init.normal_(self.token_embedding.weight, std=0.02) 101 | nn.init.normal_(self.positional_embedding, std=0.01) 102 | 103 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 104 | attn_std = self.transformer.width ** -0.5 105 | fc_std = (2 * self.transformer.width) ** -0.5 106 | for block in self.transformer.resblocks: 107 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 108 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 109 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 110 | 
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 111 | 112 | if self.text_projection is not None: 113 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 114 | 115 | def build_attention_mask(self): 116 | # lazily create causal attention mask, with full attention between the vision tokens 117 | # pytorch uses additive attention mask; fill with -inf 118 | mask = torch.empty(self.context_length, self.context_length) 119 | mask.fill_(float("-inf")) 120 | mask.triu_(1) # zero out the lower diagonal 121 | return mask 122 | 123 | @property 124 | def device(self): 125 | return self.text_projection.device 126 | 127 | @property 128 | def dtype(self): 129 | return self.text_projection.dtype 130 | 131 | def tokenize(self, 132 | texts: Union[str, List[str]], \ 133 | context_length: int = 77) -> torch.LongTensor: 134 | """ 135 | """ 136 | if isinstance(texts, str): 137 | texts = [texts] 138 | 139 | sot_token = self._tokenizer.encoder["<|startoftext|>"] 140 | eot_token = self._tokenizer.encoder["<|endoftext|>"] 141 | all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts] 142 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 143 | 144 | for i, tokens in enumerate(all_tokens): 145 | if len(tokens) > context_length: 146 | st = torch.randint( 147 | len(tokens) - context_length + 1, (1,))[0].item() 148 | tokens = tokens[st: st + context_length] 149 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 150 | result[i, :len(tokens)] = torch.tensor(tokens) 151 | 152 | return result 153 | 154 | def encode_text(self, text): 155 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 156 | x = x + self.positional_embedding.type(self.dtype) 157 | x = x.permute(1, 0, 2) # NLD -> LND 158 | x = self.transformer(x) 159 | x = x.permute(1, 0, 2) # LND -> NLD 160 | x = self.ln_final(x).type(self.dtype) 161 | # take features from the eot embedding (eot_token is the highest number in each sequence) 162 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 163 | return x 164 | 165 | def forward(self, captions): 166 | ''' 167 | captions: list of strings 168 | ''' 169 | text = self.tokenize(captions).to(self.device) # B x L x D 170 | features = self.encode_text(text) # B x D 171 | return features 172 | 173 | 174 | def build_text_encoder(pretrain=True): 175 | text_encoder = CLIPTEXT() 176 | if pretrain: 177 | import clip 178 | pretrained_model, _ = clip.load("ViT-B/32", device='cpu') 179 | state_dict = pretrained_model.state_dict() 180 | to_delete_keys = ["logit_scale", "input_resolution", \ 181 | "context_length", "vocab_size"] + \ 182 | [k for k in state_dict.keys() if k.startswith('visual.')] 183 | for k in to_delete_keys: 184 | if k in state_dict: 185 | del state_dict[k] 186 | print('Loading pretrained CLIP') 187 | text_encoder.load_state_dict(state_dict) 188 | # import pdb; pdb.set_trace() 189 | return text_encoder -------------------------------------------------------------------------------- /mmovod/modeling/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
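# Usage sketch (illustrative): `load_class_freq` converts per-category image counts
# into sampling weights (count ** freq_weight, with categories sorted by id), and
# `get_fed_loss_inds` picks the categories used by the federated loss: every class
# present in `gt_classes` is kept and the remainder of `num_sample_cats` is drawn
# without replacement in proportion to those weights, e.g.
#
#   freq = load_class_freq(freq_weight=0.5)
#   inds = get_fed_loss_inds(gt_classes, num_sample_cats=50, C=len(freq), weight=freq)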
2 | import torch 3 | import json 4 | import numpy as np 5 | from torch.nn import functional as F 6 | 7 | def load_class_freq( 8 | path='datasets/metadata/lvis_v1_train_cat_info.json', freq_weight=1.0): 9 | cat_info = json.load(open(path, 'r')) 10 | cat_info = torch.tensor( 11 | [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])]) 12 | freq_weight = cat_info.float() ** freq_weight 13 | return freq_weight 14 | 15 | 16 | def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None): 17 | appeared = torch.unique(gt_classes) # C' 18 | prob = appeared.new_ones(C + 1).float() 19 | prob[-1] = 0 20 | if len(appeared) < num_sample_cats: 21 | if weight is not None: 22 | prob[:C] = weight.float().clone() 23 | prob[appeared] = 0 24 | more_appeared = torch.multinomial( 25 | prob, num_sample_cats - len(appeared), 26 | replacement=False) 27 | appeared = torch.cat([appeared, more_appeared]) 28 | return appeared 29 | 30 | 31 | 32 | def reset_cls_test(model, cls_path, num_classes): 33 | model.roi_heads.num_classes = num_classes 34 | if type(cls_path) == str: 35 | print('Resetting zs_weight', cls_path) 36 | zs_weight = torch.tensor( 37 | np.load(cls_path), 38 | dtype=torch.float32).permute(1, 0).contiguous() # D x C 39 | else: 40 | zs_weight = cls_path 41 | zs_weight = torch.cat( 42 | [zs_weight, zs_weight.new_zeros((zs_weight.shape[0], 1))], 43 | dim=1) # D x (C + 1) 44 | if model.roi_heads.box_predictor[0].cls_score.norm_weight: 45 | zs_weight = F.normalize(zs_weight, p=2, dim=0) 46 | zs_weight = zs_weight.to(model.device) 47 | for k in range(len(model.roi_heads.box_predictor)): 48 | del model.roi_heads.box_predictor[k].cls_score.zs_weight 49 | model.roi_heads.box_predictor[k].cls_score.zs_weight = zs_weight -------------------------------------------------------------------------------- /mmovod/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import atexit 3 | import bisect 4 | import multiprocessing as mp 5 | from collections import deque 6 | import cv2 7 | import torch 8 | 9 | from detectron2.data import MetadataCatalog 10 | from detectron2.engine.defaults import DefaultPredictor 11 | from detectron2.utils.video_visualizer import VideoVisualizer 12 | from detectron2.utils.visualizer import ColorMode, Visualizer 13 | 14 | from .modeling.utils import reset_cls_test 15 | 16 | 17 | def get_clip_embeddings(vocabulary, prompt='a '): 18 | from detic.modeling.text.text_encoder import build_text_encoder 19 | text_encoder = build_text_encoder(pretrain=True) 20 | text_encoder.eval() 21 | texts = [prompt + x for x in vocabulary] 22 | emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() 23 | return emb 24 | 25 | BUILDIN_CLASSIFIER = { 26 | 'lvis': 'datasets/metadata/lvis_v1_clip_a+cname.npy', 27 | 'objects365': 'datasets/metadata/o365_clip_a+cnamefix.npy', 28 | 'openimages': 'datasets/metadata/oid_clip_a+cname.npy', 29 | 'coco': 'datasets/metadata/coco_clip_a+cname.npy', 30 | } 31 | 32 | BUILDIN_METADATA_PATH = { 33 | 'lvis': 'lvis_v1_val', 34 | 'objects365': 'objects365_v2_val', 35 | 'openimages': 'oid_val_expanded', 36 | 'coco': 'coco_2017_val', 37 | } 38 | 39 | class VisualizationDemo(object): 40 | def __init__(self, cfg, args, 41 | instance_mode=ColorMode.IMAGE, parallel=False): 42 | """ 43 | Args: 44 | cfg (CfgNode): 45 | instance_mode (ColorMode): 46 | parallel (bool): whether to run the model in different processes from visualization. 
47 | Useful since the visualization logic can be slow. 48 | """ 49 | if args.vocabulary == 'custom': 50 | self.metadata = MetadataCatalog.get("__unused") 51 | self.metadata.thing_classes = args.custom_vocabulary.split(',') 52 | classifier = get_clip_embeddings(self.metadata.thing_classes) 53 | else: 54 | self.metadata = MetadataCatalog.get( 55 | BUILDIN_METADATA_PATH[args.vocabulary]) 56 | classifier = BUILDIN_CLASSIFIER[args.vocabulary] 57 | 58 | num_classes = len(self.metadata.thing_classes) 59 | self.cpu_device = torch.device("cpu") 60 | self.instance_mode = instance_mode 61 | 62 | self.parallel = parallel 63 | if parallel: 64 | num_gpu = torch.cuda.device_count() 65 | self.predictor = AsyncPredictor(cfg, num_gpus=num_gpu) 66 | else: 67 | self.predictor = DefaultPredictor(cfg) 68 | reset_cls_test(self.predictor.model, classifier, num_classes) 69 | 70 | def run_on_image(self, image): 71 | """ 72 | Args: 73 | image (np.ndarray): an image of shape (H, W, C) (in BGR order). 74 | This is the format used by OpenCV. 75 | 76 | Returns: 77 | predictions (dict): the output of the model. 78 | vis_output (VisImage): the visualized image output. 79 | """ 80 | vis_output = None 81 | predictions = self.predictor(image) 82 | # Convert image from OpenCV BGR format to Matplotlib RGB format. 83 | image = image[:, :, ::-1] 84 | visualizer = Visualizer(image, self.metadata, instance_mode=self.instance_mode) 85 | if "panoptic_seg" in predictions: 86 | panoptic_seg, segments_info = predictions["panoptic_seg"] 87 | vis_output = visualizer.draw_panoptic_seg_predictions( 88 | panoptic_seg.to(self.cpu_device), segments_info 89 | ) 90 | else: 91 | if "sem_seg" in predictions: 92 | vis_output = visualizer.draw_sem_seg( 93 | predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 94 | ) 95 | if "instances" in predictions: 96 | instances = predictions["instances"].to(self.cpu_device) 97 | vis_output = visualizer.draw_instance_predictions(predictions=instances) 98 | 99 | return predictions, vis_output 100 | 101 | def _frame_from_video(self, video): 102 | while video.isOpened(): 103 | success, frame = video.read() 104 | if success: 105 | yield frame 106 | else: 107 | break 108 | 109 | def run_on_video(self, video): 110 | """ 111 | Visualizes predictions on frames of the input video. 112 | 113 | Args: 114 | video (cv2.VideoCapture): a :class:`VideoCapture` object, whose source can be 115 | either a webcam or a video file. 116 | 117 | Yields: 118 | ndarray: BGR visualizations of each video frame. 
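Example (illustrative sketch; `demo`, the video path and `out_writer` are placeholders):
    video = cv2.VideoCapture("input.mp4")
    for vis_frame in demo.run_on_video(video):
        out_writer.write(vis_frame)  # e.g. a cv2.VideoWriter opened elsewhere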
119 | """ 120 | video_visualizer = VideoVisualizer(self.metadata, self.instance_mode) 121 | 122 | def process_predictions(frame, predictions): 123 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 124 | if "panoptic_seg" in predictions: 125 | panoptic_seg, segments_info = predictions["panoptic_seg"] 126 | vis_frame = video_visualizer.draw_panoptic_seg_predictions( 127 | frame, panoptic_seg.to(self.cpu_device), segments_info 128 | ) 129 | elif "instances" in predictions: 130 | predictions = predictions["instances"].to(self.cpu_device) 131 | vis_frame = video_visualizer.draw_instance_predictions(frame, predictions) 132 | elif "sem_seg" in predictions: 133 | vis_frame = video_visualizer.draw_sem_seg( 134 | frame, predictions["sem_seg"].argmax(dim=0).to(self.cpu_device) 135 | ) 136 | 137 | # Converts Matplotlib RGB format to OpenCV BGR format 138 | vis_frame = cv2.cvtColor(vis_frame.get_image(), cv2.COLOR_RGB2BGR) 139 | return vis_frame 140 | 141 | frame_gen = self._frame_from_video(video) 142 | if self.parallel: 143 | buffer_size = self.predictor.default_buffer_size 144 | 145 | frame_data = deque() 146 | 147 | for cnt, frame in enumerate(frame_gen): 148 | frame_data.append(frame) 149 | self.predictor.put(frame) 150 | 151 | if cnt >= buffer_size: 152 | frame = frame_data.popleft() 153 | predictions = self.predictor.get() 154 | yield process_predictions(frame, predictions) 155 | 156 | while len(frame_data): 157 | frame = frame_data.popleft() 158 | predictions = self.predictor.get() 159 | yield process_predictions(frame, predictions) 160 | else: 161 | for frame in frame_gen: 162 | yield process_predictions(frame, self.predictor(frame)) 163 | 164 | 165 | class AsyncPredictor: 166 | """ 167 | A predictor that runs the model asynchronously, possibly on >1 GPUs. 168 | Because rendering the visualization takes considerably amount of time, 169 | this helps improve throughput a little bit when rendering videos. 
170 | """ 171 | 172 | class _StopToken: 173 | pass 174 | 175 | class _PredictWorker(mp.Process): 176 | def __init__(self, cfg, task_queue, result_queue): 177 | self.cfg = cfg 178 | self.task_queue = task_queue 179 | self.result_queue = result_queue 180 | super().__init__() 181 | 182 | def run(self): 183 | predictor = DefaultPredictor(self.cfg) 184 | 185 | while True: 186 | task = self.task_queue.get() 187 | if isinstance(task, AsyncPredictor._StopToken): 188 | break 189 | idx, data = task 190 | result = predictor(data) 191 | self.result_queue.put((idx, result)) 192 | 193 | def __init__(self, cfg, num_gpus: int = 1): 194 | """ 195 | Args: 196 | cfg (CfgNode): 197 | num_gpus (int): if 0, will run on CPU 198 | """ 199 | num_workers = max(num_gpus, 1) 200 | self.task_queue = mp.Queue(maxsize=num_workers * 3) 201 | self.result_queue = mp.Queue(maxsize=num_workers * 3) 202 | self.procs = [] 203 | for gpuid in range(max(num_gpus, 1)): 204 | cfg = cfg.clone() 205 | cfg.defrost() 206 | cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" 207 | self.procs.append( 208 | AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) 209 | ) 210 | 211 | self.put_idx = 0 212 | self.get_idx = 0 213 | self.result_rank = [] 214 | self.result_data = [] 215 | 216 | for p in self.procs: 217 | p.start() 218 | atexit.register(self.shutdown) 219 | 220 | def put(self, image): 221 | self.put_idx += 1 222 | self.task_queue.put((self.put_idx, image)) 223 | 224 | def get(self): 225 | self.get_idx += 1 # the index needed for this request 226 | if len(self.result_rank) and self.result_rank[0] == self.get_idx: 227 | res = self.result_data[0] 228 | del self.result_data[0], self.result_rank[0] 229 | return res 230 | 231 | while True: 232 | # make sure the results are returned in the correct order 233 | idx, res = self.result_queue.get() 234 | if idx == self.get_idx: 235 | return res 236 | insert = bisect.bisect(self.result_rank, idx) 237 | self.result_rank.insert(insert, idx) 238 | self.result_data.insert(insert, res) 239 | 240 | def __len__(self): 241 | return self.put_idx - self.get_idx 242 | 243 | def __call__(self, image): 244 | self.put(image) 245 | return self.get() 246 | 247 | def shutdown(self): 248 | for _ in self.procs: 249 | self.task_queue.put(AsyncPredictor._StopToken()) 250 | 251 | @property 252 | def default_buffer_size(self): 253 | return len(self.procs) * 5 254 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | mss 3 | timm==0.5.4 4 | dataclasses 5 | ftfy 6 | regex 7 | fasttext 8 | scikit-learn 9 | lvis 10 | nltk 11 | openai 12 | numpy==1.23.3 13 | git+https://github.com/openai/CLIP.git 14 | -------------------------------------------------------------------------------- /tools/collate_exemplar_dict.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import glob 4 | import copy 5 | import pickle 6 | import os 7 | import argparse 8 | 9 | from collections import defaultdict 10 | from multiprocessing import Pool 11 | 12 | import numpy as np 13 | 14 | from nltk.corpus import wordnet as wn 15 | from lvis import LVIS 16 | from tqdm.auto import tqdm 17 | 18 | 19 | def get_code(syn): 20 | return syn.pos() + str(syn.offset()).zfill(8) 21 | 22 | 23 | def get_lemma_names(syn): 24 | return [lemma.name() for lemma in syn.lemmas()] 25 | 26 | 27 | def get_args(): 28 | parser = 
argparse.ArgumentParser() 29 | parser.add_argument( 30 | "--lvis-dir", 31 | type=str, 32 | default="datasets/lvis", 33 | ) 34 | parser.add_argument( 35 | "--imagenet-dir", 36 | type=str, 37 | default="datasets/imagenet", 38 | ) 39 | parser.add_argument( 40 | "--visual-genome-dir", 41 | type=str, 42 | default="datasets/VisualGenome", 43 | ) 44 | parser.add_argument( 45 | "--output-path", 46 | type=str, 47 | default="datasets/metadata/exemplar_dict.json", 48 | ) 49 | args = parser.parse_args() 50 | return args 51 | 52 | 53 | def main(args): 54 | lvis_train_dataset = LVIS(os.path.join(args.lvis_dir, "lvis_v1_train.json")) 55 | # In LVIS the 'stopsign' category is not a wordnet synset and so replace with 'street_sign', a near cousin 56 | street_sign_synset = wn.synset("street_sign.n.01") 57 | lvis_synsets = [ 58 | wn.synset(c['synset']) if "stop_sign" not in c['name'] else street_sign_synset 59 | for c in lvis_train_dataset.cats.values() 60 | ] 61 | synset2catid = {v['synset']: v['id'] for v in lvis_train_dataset.cats.values()} 62 | catid2synset = {v['id']: v['synset'] for v in lvis_train_dataset.cats.values()} 63 | 64 | # Lets start by collecting images in ImageNet 65 | # we do this using all ImageNet21k with tree hierarchy as defined in the dataset preparation file 66 | synsets_pos_off_paths = ( 67 | glob.glob(os.path.join(args.imagenet_dir, "train/*")) 68 | + glob.glob(os.path.join(args.imagenet_dir, "imagenet21k_small_classes/*")) 69 | ) 70 | # import pdb; pdb.set_trace() 71 | synsets_pos_off2path = {os.path.basename(d): d for d in synsets_pos_off_paths} 72 | synsets_pos_off = [os.path.basename(d) for d in synsets_pos_off_paths] 73 | available_imagenet_synsets = [wn.synset_from_pos_and_offset(a[0], int(a[1:])) for a in synsets_pos_off] 74 | direct_match_imagenet_synsets = [v for v in available_imagenet_synsets if v.name() in synset2catid.keys()] 75 | assert len(direct_match_imagenet_synsets) == 997 76 | exemplar_dict_imagenet = defaultdict(list) 77 | 78 | for v in tqdm(direct_match_imagenet_synsets, total=len(direct_match_imagenet_synsets)): 79 | code = get_code(v) 80 | filenames = glob.glob(os.path.join(synsets_pos_off2path[code], "*.JPEG")) 81 | anns = [] 82 | for filename in filenames: 83 | # need the last three parts of the path 84 | ann = { 85 | "file_name": "/".join(filename.split("/")[-3:]), 86 | "category_id": synset2catid[v.name()], 87 | "dataset": "imagenet21k", 88 | "synset": v.name(), 89 | } 90 | anns.append(ann) 91 | exemplar_dict_imagenet[v.name()].extend(anns) 92 | 93 | # Now lets collect images from LVIS with area > 32*32 94 | exemplar_dict_lvis = defaultdict(list) 95 | lvis_anns = lvis_train_dataset.load_anns( 96 | lvis_train_dataset.get_ann_ids( 97 | cat_ids=lvis_train_dataset.get_cat_ids(), 98 | area_rng=[32.0**2, float('inf')] 99 | )) 100 | 101 | for ann in tqdm(lvis_anns): 102 | img = lvis_train_dataset.load_imgs([ann['image_id']])[0] 103 | ann['dataset'] = "lvis" 104 | ann['file_name'] = "/".join(img['coco_url'].split("/")[-2:]) 105 | ann['synset'] = catid2synset[ann['category_id']] 106 | exemplar_dict_lvis[catid2synset[ann['category_id']]].append(ann) 107 | 108 | # get keys from both dictionaries 109 | keys = list(set(exemplar_dict_imagenet.keys()).union(set(exemplar_dict_lvis.keys()))) 110 | exemplar_dict_combined_two = defaultdict(list) 111 | for key in keys: 112 | exemplar_dict_combined_two[key] = exemplar_dict_imagenet[key] + exemplar_dict_lvis[key] 113 | 114 | # let's find the lacking synsets 115 | lacking_synsets = [k.name() for k in lvis_synsets if 
len(exemplar_dict_combined_two[k.name()]) < 40] 116 | print(f"After collecting exemplars from ImageNet and LVIS, there are still {len(lacking_synsets)} without at least 40 exemplars") 117 | 118 | # Now let's collect images from Visual Genome 119 | exemplar_dict_vg = defaultdict(list) 120 | vg_objects_path = os.path.join(args.visual_genome_dir, "objects.json") 121 | with open(vg_objects_path, "r") as f: 122 | visual_genome_objects = json.load(f) 123 | 124 | vg_images_path = os.path.join(args.visual_genome_dir, "image_data.json") 125 | with open(vg_images_path, "r") as f: 126 | visual_genome_images = json.load(f) 127 | visual_genome_iid2path = {v['image_id']: "/".join(v['url'].split("/")[-2:]) for v in visual_genome_images} 128 | 129 | synsets2boxes = defaultdict(list) 130 | for i, img in tqdm(enumerate(visual_genome_objects), total=len(visual_genome_objects)): 131 | for j, obj in enumerate(img['objects']): 132 | if len(obj['synsets']) != 1: 133 | continue 134 | synsets2boxes[obj['synsets'][0]].append((i, j, obj['w'] * obj['h'])) 135 | 136 | # We shall only use visual genome for synsets with less than 40 exemplars 137 | for k in tqdm(lacking_synsets, total=len(lacking_synsets)): 138 | visual_genome_ids = synsets2boxes[k] 139 | anns = [] 140 | for a in visual_genome_ids: 141 | if a[-1] < 32**2: 142 | continue 143 | img_objects = visual_genome_objects[a[0]] 144 | iid = img_objects['image_id'] 145 | ann = { 146 | 'image_id': iid, 147 | 'dataset': 'visual_genome', 148 | 'file_name': visual_genome_iid2path[iid], 149 | 'category_id': synset2catid[k], 150 | 'synset': k, 151 | } 152 | ann.update(img_objects['objects'][a[1]]) 153 | assert ann['synsets'][0] == k 154 | anns.append(ann) 155 | exemplar_dict_vg[k] = anns 156 | 157 | # Now let's combine all the dictionaries 158 | exemplar_dict_combined_three = defaultdict(list) 159 | for syn in synset2catid.keys(): 160 | exemplar_dict_combined_three[syn].extend( 161 | exemplar_dict_lvis[syn] 162 | + exemplar_dict_imagenet[syn] 163 | + exemplar_dict_vg[syn] 164 | ) 165 | 166 | # At this point there should be at least TEN exemplars in 1160 out of 1203 synsets 167 | # as described in the appendix of the paper, some other synsets in imagenet are suitable 168 | # and in some cases we use close cousins 169 | 170 | manual_synsets = { 171 | "anklet.n.03": ["anklet.n.02"], 172 | "beach_ball.n.01": ["volleyball.n.02"], 173 | "bible.n.01": ["book.n.11"], 174 | "black_flag.n.01": ["flag.n.01"], 175 | "bob.n.05": ["spinner.n.03"], 176 | "bowl.n.08": ["pipe_smoker.n.01"], 177 | "brooch.n.01": ["pectoral.n.02", "bling.n.01"], 178 | "card.n.02": ["business_card.n.01", "library_card.n.01"], 179 | "checkbook.n.01": ["daybook.n.02"], 180 | "coil.n.05": ["coil_spring.n.01"], 181 | "coloring_material.n.01": ["crayon.n.01"], 182 | "crab.n.05": ["shellfish.n.01", "lobster.n.01"], 183 | "cube.n.05": ["die.n.01"], 184 | "cufflink.n.01": ["bling.n.01"], 185 | "dishwasher_detergent.n.01": ["laundry_detergent.n.01", "cleansing_agent.n.01"], 186 | "diving_board.n.01": ["springboard.n.01"], 187 | "dollar.n.02": ["money.n.01", "paper_money.n.01"], 188 | "eel.n.01": ["electric_eel.n.01"], 189 | "escargot.n.01": ["snail.n.01"], 190 | "gargoyle.n.02": ["statue.n.01"], 191 | "gem.n.02": ["crystal.n.01", "bling.n.01"], 192 | "grits.n.01": ["congee.n.01"], 193 | "hardback.n.01": ["book.n.07"], 194 | "jewel.n.01": ["bling.n.01"], 195 | "keycard.n.01": ["magnetic_stripe.n.01"], 196 | "lamb_chop.n.01": ["porkchop.n.01", "rib.n.03"], 197 | "mail_slot.n.01": ["maildrop.n.01", "mailbox.n.01"], 198 
| "milestone.n.01": ["cairn.n.01"], 199 | "pad.n.03": ["handstamp.n.01"], 200 | "paperback_book.n.01": ["book.n.07"], 201 | "paperweight.n.01": ["letter_opener.n.01"], 202 | "pennant.n.02": ["bunting.n.01"], 203 | "penny.n.02": ["coin.n.01"], 204 | "plume.n.02": ["headdress.n.01"], 205 | "poker.n.01": ["fire_tongs.n.01"], 206 | "rag_doll.n.01": ["doll.n.01"], 207 | "road_map.n.02": ["map.n.01"], 208 | "scarecrow.n.01": ["creche.n.02"], 209 | "sparkler.n.02": ["firework.n.01"], 210 | "sugarcane.n.01": ["cane_sugar.n.02"], 211 | "water_pistol.n.01": ["pistol.n.01"], 212 | "wedding_ring.n.01": ["ring.n.01"], 213 | "windsock.n.01": ["weathervane.n.01"], 214 | } 215 | 216 | manual_exemplar_dict = defaultdict(list) 217 | 218 | remaining_lacking_synsets = [(k, len(v)) for k, v in exemplar_dict_combined_three.items() if len(v) < 10] 219 | assert len(remaining_lacking_synsets) == len(manual_synsets) == 43 220 | 221 | for k, v in tqdm(manual_synsets.items(), total=len(remaining_lacking_synsets)): 222 | for manual_synset in v: 223 | code = get_code(wn.synset(manual_synset)) 224 | if code not in synsets_pos_off2path: 225 | continue 226 | 227 | filenames = glob.glob(os.path.join(synsets_pos_off2path[code], "*.JPEG")) 228 | anns = [] 229 | for filename in filenames: 230 | ann = { 231 | "file_name": "/".join(filename.split("/")[-3:]), 232 | "category_id": synset2catid[k], 233 | "dataset": "imagenet21k", 234 | "synset": manual_synset, 235 | } 236 | anns.append(ann) 237 | # if k in remaining_lacking_synsets: 238 | # print(k, manual_synset, len(anns)) 239 | manual_exemplar_dict[k].extend(anns) 240 | 241 | for k, v in tqdm(manual_synsets.items(), total=len(remaining_lacking_synsets)): 242 | for manual_synset in v: 243 | if manual_synset not in synsets2boxes.keys(): 244 | continue 245 | visual_genome_ids = synsets2boxes[manual_synset] 246 | anns = [] 247 | for a in visual_genome_ids: 248 | if a[-1] < 32**2: 249 | continue 250 | img_objects = visual_genome_objects[a[0]] 251 | iid = img_objects['image_id'] 252 | ann = { 253 | 'image_id': iid, 254 | 'dataset': 'visual_genome', 255 | 'file_name': visual_genome_iid2path[iid], 256 | 'category_id': synset2catid[k], 257 | 'synset': manual_synset, 258 | } 259 | ann.update(img_objects['objects'][a[1]]) 260 | # assert ann['synsets'][0] == k 261 | anns.append(ann) 262 | manual_exemplar_dict[k].extend(anns) 263 | 264 | # combine manual_exemplar_dict with exemplar_dict_combined_three 265 | for k, v in manual_exemplar_dict.items(): 266 | exemplar_dict_combined_three[k].extend(v) 267 | 268 | assert min([len(v) for k, v in exemplar_dict_combined_three.items()]) >= 10, "some synsets have less than 10 exemplars" 269 | with open(args.output_path, "w") as f: 270 | json.dump(exemplar_dict_combined_three, f) 271 | 272 | 273 | if __name__ == "__main__": 274 | args = get_args() 275 | main(args) 276 | -------------------------------------------------------------------------------- /tools/convert-thirdparty-pretrained-model-to-d2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | import argparse 4 | import pickle 5 | import torch 6 | 7 | """ 8 | Usage: 9 | 10 | cd DETIC_ROOT/models/ 11 | wget https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth 12 | python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path resnet50_miil_21k.pth 13 | 14 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth 15 | python ../tools/convert-thirdparty-pretrained-model-to-d2.py --path swin_base_patch4_window7_224_22k.pth 16 | 17 | """ 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--path', default='') 23 | args = parser.parse_args() 24 | 25 | print('Loading', args.path) 26 | model = torch.load(args.path, map_location="cpu") 27 | # import pdb; pdb.set_trace() 28 | if 'model' in model: 29 | model = model['model'] 30 | if 'state_dict' in model: 31 | model = model['state_dict'] 32 | ret = { 33 | "model": model, 34 | "__author__": "third_party", 35 | "matching_heuristics": True 36 | } 37 | out_path = args.path.replace('.pth', '.pkl') 38 | print('Saving to', out_path) 39 | pickle.dump(ret, open(out_path, "wb")) 40 | -------------------------------------------------------------------------------- /tools/create_imagenetlvis_json.py: -------------------------------------------------------------------------------- 1 | # edited from script in Detic from Facebook, Inc. and its affiliates. 2 | import argparse 3 | import json 4 | import os 5 | from nltk.corpus import wordnet as wn 6 | from detectron2.data.detection_utils import read_image 7 | 8 | 9 | def get_code(syn): 10 | return syn.pos() + str(syn.offset()).zfill(8) 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--imagenet-path', default='datasets/imagenet/imagenet21k_P') 16 | parser.add_argument('--lvis-meta-path', default='datasets/lvis/lvis_v1_val.json') 17 | parser.add_argument('--out-path', default='datasets/imagenet/annotations/imagenet_lvis_image_info.json') 18 | args = parser.parse_args() 19 | 20 | print('Loading LVIS meta') 21 | data = json.load(open(args.lvis_meta_path, 'r')) 22 | print('Done') 23 | synset2cat = {x['synset']: x for x in data['categories']} 24 | count = 0 25 | images = [] 26 | image_counts = {} 27 | synset2folders = {} 28 | for synset in synset2cat: 29 | if synset == "stop_sign.n.01": 30 | synset2folders["stop_sign.n.01"] = [] 31 | continue 32 | code = get_code(wn.synset(synset)) 33 | folders = [ 34 | os.path.join(args.imagenet_path, folder, code) for folder in ["train", "val", "imagenet21k_small_classes"] 35 | ] 36 | synset2folders[synset] = [f for f in folders if os.path.exists(f)] 37 | 38 | for synset, synset_folders in synset2folders.items(): 39 | if len(synset_folders) == 0: 40 | continue 41 | cat = synset2cat[synset] 42 | cat_id = cat['id'] 43 | cat_name = cat['name'] 44 | cat_images = [] 45 | files = [] 46 | for folder in synset_folders: 47 | folder_files = os.listdir(folder) 48 | for file in folder_files: 49 | count = count + 1 50 | # file_name only needs to be last two parts of path 51 | # import pdb; pdb.set_trace() 52 | file_name = '{}/{}/{}'.format(*(folder.split("/")[-2:]), file) 53 | assert os.path.join(folder, file) == os.path.join(args.imagenet_path, file_name) 54 | img = read_image(os.path.join(args.imagenet_path, file_name)) 55 | h, w = img.shape[:2] 56 | image = { 57 | 'id': count, 58 | 'file_name': file_name, 59 | 'pos_category_ids': 
[cat_id], 60 | 'width': w, 61 | 'height': h 62 | } 63 | cat_images.append(image) 64 | images.extend(cat_images) 65 | image_counts[cat_id] = len(cat_images) 66 | print(cat_id, cat_name, len(cat_images)) 67 | print('# Images', len(images)) 68 | for x in data['categories']: 69 | x['image_count'] = image_counts[x['id']] if x['id'] in image_counts else 0 70 | out = {'categories': data['categories'], 'images': images, 'annotations': []} 71 | print('Writing to', args.out_path) 72 | json.dump(out, open(args.out_path, 'w')) 73 | -------------------------------------------------------------------------------- /tools/dump_clip_features.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import argparse 3 | import json 4 | import torch 5 | import numpy as np 6 | import itertools 7 | from nltk.corpus import wordnet 8 | import sys 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--ann', default='datasets/lvis/lvis_v1_val.json') 13 | parser.add_argument('--out_path', default='') 14 | parser.add_argument('--prompt', default='a') 15 | parser.add_argument('--model', default='clip') 16 | parser.add_argument('--clip_model', default="ViT-B/32") 17 | parser.add_argument('--fix_space', action='store_true') 18 | parser.add_argument('--use_underscore', action='store_true') 19 | parser.add_argument('--avg_synonyms', action='store_true') 20 | parser.add_argument('--use_wn_name', action='store_true') 21 | args = parser.parse_args() 22 | 23 | print('Loading', args.ann) 24 | data = json.load(open(args.ann, 'r')) 25 | cat_names = [x['name'] for x in \ 26 | sorted(data['categories'], key=lambda x: x['id'])] 27 | if 'synonyms' in data['categories'][0]: 28 | if args.use_wn_name: 29 | synonyms = [ 30 | [xx.name() for xx in wordnet.synset(x['synset']).lemmas()] \ 31 | if x['synset'] != 'stop_sign.n.01' else ['stop_sign'] \ 32 | for x in sorted(data['categories'], key=lambda x: x['id'])] 33 | else: 34 | synonyms = [x['synonyms'] for x in \ 35 | sorted(data['categories'], key=lambda x: x['id'])] 36 | else: 37 | synonyms = [] 38 | if args.fix_space: 39 | cat_names = [x.replace('_', ' ') for x in cat_names] 40 | if args.use_underscore: 41 | cat_names = [x.strip().replace('/ ', '/').replace(' ', '_') for x in cat_names] 42 | print('cat_names', cat_names) 43 | device = "cuda" if torch.cuda.is_available() else "cpu" 44 | 45 | if args.prompt == 'a': 46 | sentences = ['a ' + x for x in cat_names] 47 | sentences_synonyms = [['a ' + xx for xx in x] for x in synonyms] 48 | if args.prompt == 'none': 49 | sentences = [x for x in cat_names] 50 | sentences_synonyms = [[xx for xx in x] for x in synonyms] 51 | elif args.prompt == 'photo': 52 | sentences = ['a photo of a {}'.format(x) for x in cat_names] 53 | sentences_synonyms = [['a photo of a {}'.format(xx) for xx in x] \ 54 | for x in synonyms] 55 | elif args.prompt == 'scene': 56 | sentences = ['a photo of a {} in the scene'.format(x) for x in cat_names] 57 | sentences_synonyms = [['a photo of a {} in the scene'.format(xx) for xx in x] \ 58 | for x in synonyms] 59 | 60 | print('sentences_synonyms', len(sentences_synonyms), \ 61 | sum(len(x) for x in sentences_synonyms)) 62 | if args.model == 'clip': 63 | import clip 64 | print('Loading CLIP') 65 | model, preprocess = clip.load(args.clip_model, device=device) 66 | if args.avg_synonyms: 67 | sentences = list(itertools.chain.from_iterable(sentences_synonyms)) 68 | print('flattened_sentences', 
len(sentences)) 69 | text = clip.tokenize(sentences).to(device) 70 | with torch.no_grad(): 71 | if len(text) > 10000: 72 | text_features = torch.cat([ 73 | model.encode_text(text[:len(text) // 2]), 74 | model.encode_text(text[len(text) // 2:])], 75 | dim=0) 76 | else: 77 | text_features = model.encode_text(text) 78 | print('text_features.shape', text_features.shape) 79 | if args.avg_synonyms: 80 | synonyms_per_cat = [len(x) for x in sentences_synonyms] 81 | text_features = text_features.split(synonyms_per_cat, dim=0) 82 | text_features = [x.mean(dim=0) for x in text_features] 83 | text_features = torch.stack(text_features, dim=0) 84 | print('after stack', text_features.shape) 85 | text_features = text_features.cpu().numpy() 86 | elif args.model in ['bert', 'roberta']: 87 | from transformers import AutoTokenizer, AutoModel 88 | if args.model == 'bert': 89 | model_name = 'bert-large-uncased' 90 | if args.model == 'roberta': 91 | model_name = 'roberta-large' 92 | tokenizer = AutoTokenizer.from_pretrained(model_name) 93 | model = AutoModel.from_pretrained(model_name) 94 | model.eval() 95 | if args.avg_synonyms: 96 | sentences = list(itertools.chain.from_iterable(sentences_synonyms)) 97 | print('flattened_sentences', len(sentences)) 98 | inputs = tokenizer(sentences, padding=True, return_tensors="pt") 99 | with torch.no_grad(): 100 | model_outputs = model(**inputs) 101 | outputs = model_outputs.pooler_output 102 | text_features = outputs.detach().cpu() 103 | if args.avg_synonyms: 104 | synonyms_per_cat = [len(x) for x in sentences_synonyms] 105 | text_features = text_features.split(synonyms_per_cat, dim=0) 106 | text_features = [x.mean(dim=0) for x in text_features] 107 | text_features = torch.stack(text_features, dim=0) 108 | print('after stack', text_features.shape) 109 | text_features = text_features.numpy() 110 | print('text_features.shape', text_features.shape) 111 | else: 112 | assert 0, args.model 113 | if args.out_path != '': 114 | print('saveing to', args.out_path) 115 | np.save(open(args.out_path, 'wb'), text_features) 116 | import pdb; pdb.set_trace() 117 | -------------------------------------------------------------------------------- /tools/dump_clip_features_lvis_sentences.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import torch 4 | import numpy as np 5 | from tqdm.auto import tqdm 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--descriptions-path", 11 | type=str, 12 | default="datasets/metadata/lvis_gpt3_text-davinci-002_descriptions_author.json" 13 | ) 14 | parser.add_argument( 15 | "--ann-path", 16 | type=str, 17 | default="datasets/lvis/lvis_v1_val.json" 18 | ) 19 | parser.add_argument( 20 | "--out-path", 21 | type=str, 22 | default="datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy" 23 | ) 24 | parser.add_argument('--model', default='clip') 25 | parser.add_argument('--clip_model', default="ViT-B/32") 26 | 27 | args = parser.parse_args() 28 | 29 | print("Loading descriptions from: {}".format(args.descriptions_path)) 30 | with open(args.descriptions_path, 'r') as f: 31 | descriptions = json.load(f) 32 | 33 | print("Loading annotations from: {}".format(args.ann_path)) 34 | with open(args.ann_path, 'r') as f: 35 | ann_data = json.load(f) 36 | 37 | lvis_cats = ann_data['categories'] 38 | lvis_cats = [(c['id'], c['synonyms'][0].replace("_", " ")) for c in lvis_cats] 39 | sentences_per_cat = [descriptions[c[1]] for c in lvis_cats] 
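# Note (illustrative): `sentences_per_cat` holds one list of GPT-3 descriptions per
# LVIS category, keyed by the category's first synonym with underscores replaced by
# spaces; the loop below encodes each list with CLIP and mean-pools the sentence
# embeddings into a single classifier vector per category.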
40 | 
41 |     device = "cuda" if torch.cuda.is_available() else "cpu"
42 |     print(
43 |         "Total number of sentences for {} classes: {}".format(
44 |             len(sentences_per_cat), sum(len(x) for x in sentences_per_cat))
45 |     )
46 | 
47 |     if args.model == 'clip':
48 |         import clip
49 |         print('Loading CLIP')
50 |         model, preprocess = clip.load(args.clip_model, device=device)
51 |         model.eval()
52 |         all_text_features = []
53 |         for cat_sentences in tqdm(sentences_per_cat):
54 |             text = clip.tokenize(cat_sentences, truncate=True)
55 |             with torch.no_grad():
56 |                 if len(text) > 10000:
57 |                     split_text = text.split(128)
58 |                     split_features = []
59 |                     for t in tqdm(split_text, total=len(split_text)):
60 |                         split_features.append(model.encode_text(t.to(device)).cpu())
61 |                     text_features = torch.cat(split_features, dim=0)
62 |                     # text_features = torch.cat([
63 |                     #     model.encode_text(t) for t in text.split(128)
64 |                     # ], dim=0)
65 |                 else:
66 |                     text_features = model.encode_text(text.to(device))
67 |             all_text_features.append(text_features.mean(dim=0))
68 |         all_text_features = torch.stack(all_text_features)
69 |         print("Output text features shape: ", all_text_features.shape)
70 |         text_features = all_text_features.cpu().numpy()  # (num_categories, embed_dim)
71 | 
72 |     else:
73 |         assert 0, "Model {} is not supported; only 'clip' is".format(args.model)
74 |     if args.out_path != '':
75 |         print('saving to', args.out_path)
76 |         np.save(open(args.out_path, 'wb'), text_features)
77 | 
--------------------------------------------------------------------------------
/tools/generate_descriptions.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import json
4 | 
5 | import openai
6 | from openai.error import RateLimitError
7 | from lvis import LVIS
8 | 
9 | 
10 | API_KEY = "YOUR_API_KEY"
11 | openai.api_key = API_KEY
12 | 
13 | def get_args():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument(
16 |         "--ann-path",
17 |         type=str,
18 |         default="datasets/lvis/lvis_v1_val.json"
19 |     )
20 |     parser.add_argument(
21 |         "--output-path",
22 |         type=str,
23 |         default="datasets/metadata/lvis_gpt3_{}_descriptions_own.json"
24 |     )
25 |     parser.add_argument(
26 |         "--openai-model",
27 |         type=str,
28 |         default="text-davinci-002"
29 |     )
30 |     return parser.parse_args()
31 | 
32 | def main(args):
33 |     lvis_gt = LVIS(args.ann_path)
34 |     categories = sorted(lvis_gt.cats.values(), key=lambda x: x["id"])
35 | 
36 |     category_list = [c['synonyms'][0].replace('_', ' ') for c in categories]
37 |     all_responses = {}
38 |     vowel_list = ['a', 'e', 'i', 'o', 'u']
39 | 
40 |     for i, category in enumerate(category_list):
41 |         if category[0] in vowel_list:
42 |             article = 'an'
43 |         else:
44 |             article = 'a'
45 |         prompts = []
46 |         prompts.append("Describe what " + article + " " + category + " looks like.")
47 | 
48 |         all_result = []
49 |         # call the OpenAI API, backing off when the rate limit is hit
50 |         for curr_prompt in prompts:
51 |             try:
52 |                 response = openai.Completion.create(
53 |                     model=args.openai_model,
54 |                     prompt=curr_prompt,
55 |                     temperature=0.99,
56 |                     max_tokens=50,
57 |                     n=10,
58 |                     stop="."
59 |                 )
60 |             except RateLimitError:
61 |                 print("Hit rate limit. Waiting 15 seconds.")
62 |                 time.sleep(15)
63 |                 response = openai.Completion.create(
64 |                     model=args.openai_model,
65 |                     prompt=curr_prompt,
66 |                     temperature=0.99,
67 |                     max_tokens=50,
68 |                     n=10,
69 |                     stop="."
70 | ) 71 | 72 | time.sleep(0.15) 73 | 74 | for r in range(len(response["choices"])): 75 | result = response["choices"][r]["text"] 76 | all_result.append(result.replace("\n\n", "") + ".") 77 | all_responses[category] = all_result 78 | 79 | output_path = args.output_path.format(args.openai_model) 80 | with open(output_path, 'w') as f: 81 | json.dump(all_responses, f, indent=4) 82 | 83 | 84 | if __name__ == "__main__": 85 | args = get_args() 86 | main(args) 87 | -------------------------------------------------------------------------------- /tools/get_exemplars_tta.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | import clip 5 | import json 6 | from PIL import Image 7 | import sys 8 | import argparse 9 | # from lvis import LVIS 10 | import os 11 | # from pprint import pprint 12 | import random 13 | import numpy as np 14 | from tqdm.auto import tqdm 15 | import torch 16 | import torchvision.transforms as tvt 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | try: 20 | from torchvision.transforms import InterpolationMode 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | # PATHS = { 27 | # "imagenet21k": "", 28 | # "visual_genome": "/scratch/local/hdd/prannay/datasets/VisualGenome/", 29 | # "lvis": "/scratch/local/hdd/prannay/datasets/coco/", 30 | # } 31 | 32 | 33 | def _convert_image_to_rgb(image: Image.Image): 34 | return image.convert("RGB") 35 | 36 | 37 | def get_crop(img, bb, context=0.0, square=True): 38 | # print(bb) 39 | x1, y1, w, h = bb 40 | W, H = img.size 41 | y, x = y1 + h / 2.0, x1 + w / 2.0 42 | h, w = h * (1. + context), w * (1. + context) 43 | if square: 44 | w = max(w, h) 45 | h = max(w, h) 46 | # print(x, y, w, h) 47 | x1, x2 = x - w / 2.0, x + w / 2.0 48 | y1, y2 = y - h / 2.0, y + h / 2.0 49 | # print([x1, y1, x2, y2]) 50 | x1, x2 = max(0, x1), min(W, x2) 51 | y1, y2 = max(0, y1), min(H, y2) 52 | # print([x1, y1, x2, y2]) 53 | bb_new = [int(c) for c in [x1, y1, x2, y2]] 54 | # print(bb_new) 55 | crop = img.crop(bb_new) 56 | return crop 57 | 58 | 59 | def run_crop(d, paths, context=0.4, square=True): 60 | dataset = d['dataset'] 61 | file_name = os.path.join(paths[dataset], d['file_name']) 62 | # with open(file_name, "rb") as f: 63 | img = Image.open(file_name) 64 | if dataset == "imagenet21k": 65 | bb = [0, 0, 0, 0] 66 | return img 67 | elif dataset == "lvis": 68 | bb = [ 69 | int(c) 70 | for c in [ 71 | d['bbox'][0] // 1, 72 | d['bbox'][1] // 1, 73 | d['bbox'][2] // 1 + 1, 74 | d['bbox'][3] // 1 + 1 75 | ] 76 | ] 77 | elif dataset == "visual_genome": 78 | bb = [int(c) for c in [d['x'], d['y'], d['w'], d['h']]] 79 | return get_crop(img, bb, context=context, square=square) 80 | 81 | 82 | def get_args(): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument( 85 | "--ann-path", 86 | type=str, 87 | default="datasets/metadata/lvis_image_exemplar_dict_K-005_own.json" 88 | ) 89 | parser.add_argument( 90 | "--output-path", 91 | type=str, 92 | default="datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy" 93 | ) 94 | parser.add_argument( 95 | "--lvis-img-dir", 96 | type=str, 97 | default="datasets/coco/" 98 | ) 99 | parser.add_argument( 100 | "--imagenet-img-dir", 101 | type=str, 102 | default="datasets/imagenet" 103 | ) 104 | parser.add_argument( 105 | "--visual-genome-img-dir", 106 | type=str, 107 | default="datasets/VisualGenome/" 108 | ) 109 | parser.add_argument("--num-augs", type=int, 
default=5) 110 | parser.add_argument("--nw", type=int, default=8) 111 | 112 | args = parser.parse_args() 113 | return args 114 | 115 | 116 | def main(args): 117 | 118 | anns_path = args.ann_path 119 | num_augs = args.num_augs 120 | 121 | model, transform = clip.load("ViT-B/32", device="cpu") 122 | del model.transformer 123 | 124 | model = model.to("cuda:0") 125 | 126 | run(anns_path, num_augs, model, transform, args) 127 | 128 | 129 | class CropDataset(Dataset): 130 | def __init__( 131 | self, 132 | exemplar_dict, 133 | num_augs, 134 | transform, 135 | transform2, 136 | args, 137 | ): 138 | self.exemplar_dict = exemplar_dict 139 | self.transform = transform 140 | self.transform2 = transform2 141 | self.num_augs = num_augs 142 | self.paths = { 143 | "imagenet21k": args.imagenet_img_dir, 144 | "visual_genome": args.visual_genome_img_dir, 145 | "lvis": args.lvis_img_dir, 146 | } 147 | 148 | def __len__(self): 149 | return len(self.exemplar_dict) 150 | 151 | def __getitem__(self, idx): 152 | chosen_anns = self.exemplar_dict[idx] 153 | crops = [run_crop(ann, self.paths) for ann in chosen_anns] 154 | # add the tta in here somewhere 155 | crops = [ 156 | self.transform(self.transform2(crop)) 157 | for crop in crops for _ in range(self.num_augs) 158 | ] 159 | return torch.stack(crops) 160 | 161 | 162 | def run(anns_path, num_augs, model, transform, args): 163 | random.seed(100000 + num_augs) 164 | torch.manual_seed(100000 + num_augs) 165 | s = 0.25 166 | color_jitter = tvt.ColorJitter( 167 | 0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s 168 | ) 169 | transform2 = tvt.Compose([ 170 | tvt.RandomResizedCrop(size=224 * 4, scale=(0.8, 1.0), interpolation=BICUBIC), 171 | tvt.RandomHorizontalFlip(), # with 0.5 probability 172 | _convert_image_to_rgb, 173 | tvt.RandomApply([color_jitter], p=0.8), 174 | ]) 175 | with open(anns_path, "r") as fp: 176 | exemplar_dict = json.load(fp) 177 | dataset = CropDataset(exemplar_dict, num_augs, transform, transform2, args) 178 | dataloader = DataLoader( 179 | dataset, batch_size=1, shuffle=False, pin_memory=True, num_workers=args.nw) 180 | feats = [] 181 | # chosen_anns_all = [] 182 | for crops in tqdm(dataloader, total=len(dataloader)): 183 | # rng = np.random.default_rng(seed) 184 | # synset = catid2synset[cat_id] 185 | # trial_anns = exemplar_dict[synset] 186 | # probs = [a['area'] for a in trial_anns] 187 | # probs = np.array(probs) / sum(probs) 188 | # chosen_anns = rng.choice(trial_anns, size=K, p=probs, replace=len(trial_anns) < K) 189 | # if len(trial_anns) < K: 190 | # chosen_anns = rng.choice(trial_anns, size=K, p=[a['area'] for a in trial_anns], replace=False) 191 | # else: 192 | # chosen_anns = rng.sample(trial_anns, k=K, counts=[a['area'] for a in trial_anns 193 | # crops = [run_crop(ann) for ann in chosen_anns] 194 | # # add the tta in here somewhere 195 | # crops = [transform(transform2(crop)) for crop in crops for _ in range(num_augs)] 196 | with torch.no_grad(): 197 | image_embeddings = model.encode_image(crops[0].to("cuda:0")) 198 | # print(image_embeddings.size()) 199 | feats.append(image_embeddings.cpu()) 200 | # chosen_anns_all.append(chosen_anns.tolist()) 201 | # crops.append([run_crop(ann) for ann in chosen_anns]) 202 | 203 | feats_all = torch.stack(feats) 204 | 205 | save_basename = args.output_path 206 | 207 | np.save(save_basename, feats_all.mean(dim=1).numpy()) 208 | 209 | 210 | if __name__ == "__main__": 211 | args = get_args() 212 | main(args) 213 | -------------------------------------------------------------------------------- 
/tools/get_lvis_cat_info.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import argparse 3 | import json 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--ann", default='datasets/lvis/lvis_v1_train.json') 8 | parser.add_argument("--add_freq", action='store_true') 9 | parser.add_argument("--r_thresh", type=int, default=10) 10 | parser.add_argument("--c_thresh", type=int, default=100) 11 | parser.add_argument("--bypass", action='store_true') 12 | args = parser.parse_args() 13 | 14 | print('Loading', args.ann) 15 | data = json.load(open(args.ann, 'r')) 16 | cats = data['categories'] 17 | image_count = {x['id']: set() for x in cats} 18 | ann_count = {x['id']: 0 for x in cats} 19 | if args.bypass: 20 | for x in data['images']: 21 | for y in x['pos_category_ids']: 22 | image_count[y].add(x['id']) 23 | ann_count[y] += 1 24 | else: 25 | for x in data['annotations']: 26 | image_count[x['category_id']].add(x['image_id']) 27 | ann_count[x['category_id']] += 1 28 | num_freqs = {x: 0 for x in ['r', 'f', 'c']} 29 | for x in cats: 30 | x['image_count'] = len(image_count[x['id']]) 31 | x['instance_count'] = ann_count[x['id']] 32 | if args.add_freq: 33 | freq = 'f' 34 | if x['image_count'] < args.c_thresh: 35 | freq = 'c' 36 | if x['image_count'] < args.r_thresh: 37 | freq = 'r' 38 | x['frequency'] = freq 39 | num_freqs[freq] += 1 40 | print(cats) 41 | image_counts = sorted([x['image_count'] for x in cats]) 42 | # print('image count', image_counts) 43 | # import pdb; pdb.set_trace() 44 | if args.add_freq: 45 | for x in ['r', 'c', 'f']: 46 | print(x, num_freqs[x]) 47 | out = cats # {'categories': cats} 48 | out_path = args.ann[:-5] + '_cat_info.json' 49 | print('Saving to', out_path) 50 | json.dump(out, open(out_path, 'w')) 51 | 52 | -------------------------------------------------------------------------------- /tools/norm_feat_sum_norm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | def get_args(): 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument( 7 | "--feat1-path", 8 | type=str, 9 | default="datasets/metadata/lvis_gpt3_text-davinci-002_features_author.npy" 10 | ) 11 | parser.add_argument( 12 | "--feat2-path", 13 | type=str, 14 | default="datasets/metadata/lvis_image_exemplar_features_avg_K-005_own.npy" 15 | ) 16 | parser.add_argument( 17 | "--out-path", 18 | type=str, 19 | default="datasets/metadata/lvis_multi-modal_avg_K-005_own.npy" 20 | ) 21 | return parser.parse_args() 22 | 23 | 24 | def main(args): 25 | 26 | feat1 = np.load(args.feat1_path) 27 | feat2 = np.load(args.feat2_path) 28 | # l2 normalize each 29 | feat1 = feat1 / np.linalg.norm(feat1, axis=1, keepdims=True) 30 | feat2 = feat2 / np.linalg.norm(feat2, axis=1, keepdims=True) 31 | # take sum 32 | feat = feat1 + feat2 33 | # l2 normalize again 34 | feat = feat / np.linalg.norm(feat, axis=1, keepdims=True) 35 | np.save(args.out_path, feat) 36 | 37 | 38 | if __name__ == '__main__': 39 | args = get_args() 40 | main(args) 41 | -------------------------------------------------------------------------------- /tools/remove_lvis_rare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import argparse 3 | import json 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--ann', default='datasets/lvis/lvis_v1_train.json') 8 | args = parser.parse_args() 9 | 10 | print('Loading', args.ann) 11 | data = json.load(open(args.ann, 'r')) 12 | catid2freq = {x['id']: x['frequency'] for x in data['categories']} 13 | print('ori #anns', len(data['annotations'])) 14 | exclude = ['r'] 15 | data['annotations'] = [x for x in data['annotations'] \ 16 | if catid2freq[x['category_id']] not in exclude] 17 | print('filtered #anns', len(data['annotations'])) 18 | out_path = args.ann[:-5] + '_norare.json' 19 | print('Saving to', out_path) 20 | json.dump(data, open(out_path, 'w')) 21 | -------------------------------------------------------------------------------- /tools/sample_exemplars.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from lvis import LVIS 4 | from tqdm.auto import tqdm 5 | import numpy as np 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | "--exemplar-dict-path", 12 | type=str, 13 | default="datasets/metadata/exemplar_dict.json" 14 | ) 15 | parser.add_argument( 16 | "--lvis-ann-path", 17 | type=str, 18 | default="datasets/lvis/lvis_v1_val.json" 19 | ) 20 | parser.add_argument( 21 | "--out-path", 22 | type=str, 23 | default="datasets/metadata/lvis_image_exemplar_dict_K-005_own.json" 24 | ) 25 | parser.add_argument( 26 | "-K", 27 | type=int, 28 | default=5 29 | ) 30 | parser.add_argument( 31 | "--seed", 32 | type=int, 33 | default=42 34 | ) 35 | args = parser.parse_args() 36 | return args 37 | 38 | 39 | def run(catid2synset, exemplar_dict, K, seed, out_path): 40 | chosen_anns_all = [] 41 | for cat_id in tqdm(sorted(catid2synset.keys()), total=len(catid2synset)): 42 | rng = np.random.default_rng(seed) 43 | synset = catid2synset[cat_id] 44 | trial_anns = exemplar_dict[synset] 45 | probs = [a['area'] for a in trial_anns] 46 | probs = np.array(probs) / sum(probs) 47 | chosen_anns = rng.choice(trial_anns, size=K, p=probs, replace=False) 48 | chosen_anns_all.append(chosen_anns.tolist()) 49 | with open(out_path, "w") as f: 50 | json.dump(chosen_anns_all, f) 51 | 52 | 53 | def main(args): 54 | lvis_cats = LVIS(args.lvis_ann_path).cats 55 | catid2synset = {v['id']: v['synset'] for v in lvis_cats.values()} 56 | with open(args.exemplar_dict_path, 'r') as f: 57 | exemplar_dict = json.load(f) 58 | for k, v in exemplar_dict.items(): 59 | for ann in v: 60 | dataset = ann['dataset'] 61 | if dataset == "imagenet21k": 62 | ann['area'] = 100. * 100. 63 | elif dataset == "lvis": 64 | ann['area'] = float(ann['bbox'][2] * ann['bbox'][3]) 65 | elif dataset == "visual_genome": 66 | ann['area'] = float(ann['w'] * ann['h']) 67 | 68 | run(catid2synset, exemplar_dict, args.K, args.seed, args.out_path) 69 | 70 | 71 | if __name__ == "__main__": 72 | args = get_args() 73 | main(args) 74 | -------------------------------------------------------------------------------- /tools/unzip_imagenet_lvis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import os
3 | import argparse
4 | 
5 | if __name__ == '__main__':
6 |     parser = argparse.ArgumentParser()
7 |     parser.add_argument('--src_path', default='datasets/imagenet/ImageNet-21K/')
8 |     parser.add_argument('--dst_path', default='datasets/imagenet/ImageNet-LVIS/')
9 |     parser.add_argument('--data_path', default='datasets/imagenet_lvis_wnid.txt')
10 |     args = parser.parse_args()
11 | 
12 |     f = open(args.data_path)
13 |     for i, line in enumerate(f):
14 |         cmd = 'mkdir {x} && tar -xf {src}/{l}.tar -C {x}'.format(
15 |             src=args.src_path,
16 |             l=line.strip(),
17 |             x=args.dst_path + '/' + line.strip())
18 |         print(i, cmd)
19 |         os.system(cmd)
20 | 
--------------------------------------------------------------------------------
/train_net_auto.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import logging
3 | import os
4 | import sys
5 | from collections import OrderedDict
6 | import torch
7 | from torch.nn.parallel import DistributedDataParallel
8 | import time
9 | import datetime
10 | from typing import Any
11 | 
12 | from fvcore.common.timer import Timer
13 | from iopath.common.file_io import PathManager
14 | import detectron2.utils.comm as comm
15 | from detectron2.checkpoint import (
16 |     DetectionCheckpointer, PeriodicCheckpointer, Checkpointer
17 | )
18 | from detectron2.config import get_cfg
19 | from detectron2.data import (
20 |     MetadataCatalog,
21 |     build_detection_test_loader,
22 | )
23 | from detectron2.engine import default_argument_parser, default_setup, launch
24 | 
25 | from detectron2.evaluation import (
26 |     inference_on_dataset,
27 |     print_csv_format,
28 |     LVISEvaluator,
29 |     COCOEvaluator,
30 | )
31 | from detectron2.modeling import build_model
32 | from detectron2.solver import build_lr_scheduler, build_optimizer
33 | from detectron2.utils.events import (
34 |     CommonMetricPrinter,
35 |     EventStorage,
36 |     JSONWriter,
37 |     TensorboardXWriter,
38 | )
39 | from detectron2.data.dataset_mapper import DatasetMapper
40 | from detectron2.data.build import build_detection_train_loader
41 | from detectron2.utils.logger import setup_logger
42 | from torch.cuda.amp import GradScaler
43 | 
44 | sys.path.insert(0, 'third_party/CenterNet2/')
45 | from centernet.config import add_centernet_config
46 | 
47 | from mmovod.config import add_mmovod_config
48 | from mmovod.data.custom_build_augmentation import build_custom_augmentation
49 | from mmovod.data.custom_dataset_dataloader import build_custom_train_loader
50 | from mmovod.data.custom_dataset_mapper import CustomDatasetMapper
51 | from mmovod.custom_solver import build_custom_optimizer
52 | from mmovod.modeling.utils import reset_cls_test
53 | 
54 | 
55 | logger = logging.getLogger("detectron2")
56 | 
57 | 
58 | class LatestCheckpointer:
59 |     """
60 |     Keep a single rolling "latest" checkpoint. When `.step(iteration)` is
61 |     called, it executes `checkpointer.save` with the fixed name
62 |     "{file_prefix}_latest" whenever `iteration + 1` is a multiple of `period`.
63 | 
64 |     Attributes:
65 |         checkpointer (Checkpointer): the underlying checkpointer object
66 |     """
67 | 
68 |     def __init__(
69 |         self,
70 |         checkpointer: Checkpointer,
71 |         period: int,
72 |         file_prefix: str = "model",
73 |     ) -> None:
74 |         """
75 |         Args:
76 |             checkpointer: the checkpointer object used to save checkpoints.
77 |             period (int): how often (in iterations) to write the rolling
78 |                 "{file_prefix}_latest" checkpoint; each save overwrites the
79 |                 previous one, so only the most recent checkpoint is kept
80 |                 on disk regardless of how long training runs.
81 |             file_prefix (str): the prefix of the checkpoint's filename
82 |                 (default: "model").
83 |         """
84 |         self.checkpointer = checkpointer
85 |         self.period = int(period)
86 |         self.path_manager: PathManager = checkpointer.path_manager
87 |         self.file_prefix = file_prefix
88 | 
89 |     def step(self, iteration: int, **kwargs: Any) -> None:
90 |         """
91 |         Perform the appropriate action at the given iteration.
92 | 
93 |         Args:
94 |             iteration (int): the current training iteration.
95 |             kwargs (Any): extra data to save, same as in
96 |                 :meth:`Checkpointer.save`.
97 |         """
98 |         iteration = int(iteration)
99 |         additional_state = {"iteration": iteration}
100 |         additional_state.update(kwargs)
101 | 
102 |         if (iteration + 1) % self.period == 0:
103 |             self.checkpointer.save(
104 |                 f"{self.file_prefix}_latest", **additional_state
105 |             )
106 | 
107 |     def save(self, name: str, **kwargs: Any) -> None:
108 |         """
109 |         Same argument as :meth:`Checkpointer.save`.
110 |         Use this method to manually save checkpoints outside the schedule.
111 | 
112 |         Args:
113 |             name (str): file name.
114 |             kwargs (Any): extra data to save, same as in
115 |                 :meth:`Checkpointer.save`.
116 |         """
117 |         self.checkpointer.save(name, **kwargs)
118 | 
119 | 
120 | def do_test(cfg, model):
121 |     results = OrderedDict()
122 |     for d, dataset_name in enumerate(cfg.DATASETS.TEST):
123 |         if cfg.MODEL.RESET_CLS_TESTS:
124 |             reset_cls_test(
125 |                 model,
126 |                 cfg.MODEL.TEST_CLASSIFIERS[d],
127 |                 cfg.MODEL.TEST_NUM_CLASSES[d])
128 |         mapper = None if cfg.INPUT.TEST_INPUT_TYPE == 'default' \
129 |             else DatasetMapper(
130 |                 cfg, False, augmentations=build_custom_augmentation(cfg, False))
131 |         data_loader = build_detection_test_loader(cfg, dataset_name, mapper=mapper)
132 |         output_folder = os.path.join(
133 |             cfg.OUTPUT_DIR, "inference_{}".format(dataset_name))
134 |         evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
135 | 
136 |         if evaluator_type == "lvis" or cfg.GEN_PSEDO_LABELS:
137 |             evaluator = LVISEvaluator(dataset_name, cfg, True, output_folder)
138 |         elif evaluator_type == 'coco':
139 |             evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder)
140 |         else:
141 |             assert 0, evaluator_type
142 | 
143 |         results[dataset_name] = inference_on_dataset(
144 |             model, data_loader, evaluator)
145 |         if comm.is_main_process():
146 |             logger.info("Evaluation results for {} in csv format:".format(
147 |                 dataset_name))
148 |             print_csv_format(results[dataset_name])
149 |     if len(results) == 1:
150 |         results = list(results.values())[0]
151 |     return results
152 | 
153 | 
154 | def do_train(cfg, model, resume=False):
155 |     model.train()
156 |     if cfg.SOLVER.USE_CUSTOM_SOLVER:
157 |         optimizer = build_custom_optimizer(cfg, model)
158 |     else:
159 |         assert cfg.SOLVER.OPTIMIZER == 'SGD'
160 |         assert cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE != 'full_model'
161 |         assert cfg.SOLVER.BACKBONE_MULTIPLIER == 1.
162 | optimizer = build_optimizer(cfg, model) 163 | scheduler = build_lr_scheduler(cfg, optimizer) 164 | logger.info("Following parameters will be trained:") 165 | for n, p in model.named_parameters(): 166 | if p.requires_grad: 167 | logger.info("{}".format(n)) 168 | 169 | checkpointer = DetectionCheckpointer( 170 | model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler 171 | ) 172 | 173 | start_iter = checkpointer.resume_or_load( 174 | cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 175 | if not resume: 176 | start_iter = 0 177 | max_iter = cfg.SOLVER.MAX_ITER if cfg.SOLVER.TRAIN_ITER < 0 else cfg.SOLVER.TRAIN_ITER 178 | 179 | periodic_checkpointer = PeriodicCheckpointer( 180 | checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter 181 | ) 182 | 183 | latest_checkpointer = LatestCheckpointer( 184 | checkpointer, 15000,) 185 | 186 | writers = ( 187 | [ 188 | CommonMetricPrinter(max_iter), 189 | JSONWriter(os.path.join(cfg.OUTPUT_DIR, "metrics.json")), 190 | TensorboardXWriter(cfg.OUTPUT_DIR), 191 | ] 192 | if comm.is_main_process() 193 | else [] 194 | ) 195 | 196 | use_custom_mapper = cfg.WITH_IMAGE_LABELS 197 | MapperClass = CustomDatasetMapper if use_custom_mapper else DatasetMapper 198 | mapper = MapperClass(cfg, True) if cfg.INPUT.CUSTOM_AUG == '' else \ 199 | MapperClass(cfg, True, augmentations=build_custom_augmentation(cfg, True)) 200 | if cfg.DATALOADER.SAMPLER_TRAIN in ['TrainingSampler', 'RepeatFactorTrainingSampler']: 201 | data_loader = build_detection_train_loader(cfg, mapper=mapper) 202 | else: 203 | data_loader = build_custom_train_loader(cfg, mapper=mapper) 204 | 205 | if cfg.FP16: 206 | scaler = GradScaler() 207 | 208 | logger.info("Starting training from iteration {}".format(start_iter)) 209 | with EventStorage(start_iter) as storage: 210 | step_timer = Timer() 211 | data_timer = Timer() 212 | start_time = time.perf_counter() 213 | for data, iteration in zip(data_loader, range(start_iter, max_iter)): 214 | data_time = data_timer.seconds() 215 | storage.put_scalars(data_time=data_time) 216 | step_timer.reset() 217 | iteration = iteration + 1 218 | storage.step() 219 | loss_dict = model(data) 220 | 221 | losses = sum( 222 | loss for k, loss in loss_dict.items()) 223 | assert torch.isfinite(losses).all(), loss_dict 224 | 225 | loss_dict_reduced = { 226 | k: v.item() 227 | for k, v in comm.reduce_dict(loss_dict).items() 228 | } 229 | losses_reduced = sum(loss for loss in loss_dict_reduced.values()) 230 | if comm.is_main_process(): 231 | storage.put_scalars( 232 | total_loss=losses_reduced, **loss_dict_reduced) 233 | 234 | optimizer.zero_grad() 235 | if cfg.FP16: 236 | scaler.scale(losses).backward() 237 | scaler.step(optimizer) 238 | scaler.update() 239 | else: 240 | losses.backward() 241 | optimizer.step() 242 | 243 | storage.put_scalar( 244 | "lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) 245 | 246 | step_time = step_timer.seconds() 247 | storage.put_scalars(time=step_time) 248 | data_timer.reset() 249 | scheduler.step() 250 | 251 | if (cfg.TEST.EVAL_PERIOD > 0 252 | and iteration % cfg.TEST.EVAL_PERIOD == 0 253 | and iteration != max_iter): 254 | do_test(cfg, model) 255 | comm.synchronize() 256 | 257 | if (iteration - start_iter > 5 258 | and (iteration % 20 == 0 or iteration == max_iter)): 259 | for writer in writers: 260 | writer.write() 261 | latest_checkpointer.step(iteration) 262 | periodic_checkpointer.step(iteration) 263 | 264 | total_time = time.perf_counter() - start_time 265 | logger.info( 266 | "Total training time: 
{}".format( 267 | str(datetime.timedelta(seconds=int(total_time))))) 268 | 269 | 270 | def setup(args): 271 | """ 272 | Create configs and perform basic setups. 273 | """ 274 | cfg = get_cfg() 275 | add_centernet_config(cfg) 276 | add_mmovod_config(cfg) 277 | cfg.merge_from_file(args.config_file) 278 | cfg.merge_from_list(args.opts) 279 | if '/auto' in cfg.OUTPUT_DIR: 280 | if "configs/" in args.config_file: 281 | new_sub_folder = args.config_file.replace("configs", "")[:-5] 282 | new_output_dir = cfg.OUTPUT_DIR.replace("/auto", new_sub_folder) 283 | cfg.OUTPUT_DIR = new_output_dir 284 | else: 285 | file_name = os.path.basename(args.config_file)[:-5] 286 | cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace('/auto', '/{}'.format(file_name)) 287 | print(cfg.OUTPUT_DIR) 288 | cfg.freeze() 289 | default_setup(cfg, args) 290 | setup_logger( 291 | output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="mmovod") 292 | return cfg 293 | 294 | 295 | def main(args): 296 | cfg = setup(args) 297 | 298 | model = build_model(cfg) 299 | logger.info("Model:\n{}".format(model)) 300 | if args.eval_only: 301 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 302 | cfg.MODEL.WEIGHTS, resume=args.resume 303 | ) 304 | 305 | return do_test(cfg, model) 306 | 307 | distributed = comm.get_world_size() > 1 308 | if distributed: 309 | model = DistributedDataParallel( 310 | model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, 311 | find_unused_parameters=cfg.FIND_UNUSED_PARAM 312 | ) 313 | 314 | do_train(cfg, model, resume=args.resume) 315 | return do_test(cfg, model) 316 | 317 | 318 | if __name__ == "__main__": 319 | args = default_argument_parser() 320 | args = args.parse_args() 321 | if args.num_machines == 1: 322 | args.dist_url = 'tcp://127.0.0.1:{}'.format( 323 | torch.randint(11111, 60000, (1,))[0].item()) 324 | else: 325 | if args.dist_url == 'host': 326 | args.dist_url = 'tcp://{}:12345'.format( 327 | os.environ['SLURM_JOB_NODELIST']) 328 | elif not args.dist_url.startswith('tcp'): 329 | tmp = os.popen( 330 | 'echo $(scontrol show job {} | grep BatchHost)'.format( 331 | args.dist_url) 332 | ).read() 333 | tmp = tmp[tmp.find('=') + 1: -1] 334 | args.dist_url = 'tcp://{}:12345'.format(tmp) 335 | print("Command Line Args:", args) 336 | launch( 337 | main, 338 | args.num_gpus, 339 | num_machines=args.num_machines, 340 | machine_rank=args.machine_rank, 341 | dist_url=args.dist_url, 342 | args=(args,), 343 | ) 344 | --------------------------------------------------------------------------------