├── DATA.md ├── GETTING_STARTED.md ├── LICENSE ├── README.md ├── assets ├── framework.jpg ├── image.jpg └── results.jpg ├── configs ├── Base-RegionSpot.yaml ├── eval.yaml ├── objects365_bb.yaml ├── objects365_bl.yaml ├── objects365_bl_336.yaml ├── objects365_v3det_openimages_bb.yaml ├── objects365_v3det_openimages_bl.yaml └── objects365_v3det_openimages_bl_336.yaml ├── demo.py ├── regionspot ├── __init__.py ├── build.py ├── config.py ├── data │ ├── custom_dataset_dataloader.py │ ├── dataset_mapper.py │ ├── objects365.py │ ├── openimages.py │ ├── openimages_categories.py │ ├── v3det.py │ └── v3det_categories.py ├── detector.py ├── modeling │ ├── clip │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── model.py │ │ ├── simple_tokenizer.py │ │ ├── utils.py │ │ └── vit.py │ ├── constants.py │ ├── decoder.py │ ├── regionspot.py │ └── segment_anything │ │ ├── __init__.py │ │ ├── automatic_mask_generator.py │ │ ├── build_sam.py │ │ ├── modeling │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── common.cpython-38.pyc │ │ │ ├── common.cpython-39.pyc │ │ │ ├── image_encoder.cpython-38.pyc │ │ │ ├── image_encoder.cpython-39.pyc │ │ │ ├── mask_decoder.cpython-38.pyc │ │ │ ├── mask_decoder.cpython-39.pyc │ │ │ ├── prompt_encoder.cpython-38.pyc │ │ │ ├── prompt_encoder.cpython-39.pyc │ │ │ ├── prompt_engineering.cpython-38.pyc │ │ │ ├── sam.cpython-38.pyc │ │ │ ├── sam.cpython-39.pyc │ │ │ ├── transformer.cpython-38.pyc │ │ │ └── transformer.cpython-39.pyc │ │ ├── common.py │ │ ├── image_encoder.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ ├── prompt_engineering.py │ │ ├── sam.py │ │ ├── transformer.py │ │ └── utils.py │ │ ├── predictor.py │ │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── amg.cpython-38.pyc │ │ ├── amg.cpython-39.pyc │ │ ├── transforms.cpython-38.pyc │ │ └── transforms.cpython-39.pyc │ │ ├── amg.py │ │ ├── onnx.py │ │ └── transforms.py ├── predictor.py ├── test_time_augmentation.py └── util │ ├── __init__.py │ ├── box_ops.py │ ├── colormap.py │ ├── misc.py │ ├── model_ema.py │ ├── plot_utils.py │ ├── postprocessing.py │ ├── preprocessing.py │ └── transforms.py ├── tools └── re_save_ckpt.py └── train_net.py /DATA.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | Our model was trained using three datasets: Objects365v1, V3Det, and OpenImages. We conducted tests on the LVIS dataset in a zero-shot manner. Please organize the datasets as follows. 3 | ## Pretrained Weights 4 | SAM Pretrain Weights (ViT-base) 5 | ```bash 6 | mkdir -p sam_checkpoints 7 | cd sam_checkpoints 8 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth 9 | cd .. 10 | ``` 11 | ## Data 12 | ### Training 13 | 1. Datasets preparation 14 | Download the datasets from their respective official websites. Ensure that you have [objects365](https://www.objects365.org/overview.html), [V3Det](https://v3det.openxlab.org.cn/) and [OpenImages V6](https://storage.googleapis.com/openimages/web/download_v6.html). Organize the downloaded datasets as follows: 15 | ``` 16 | ${ROOT} 17 | -- datasets 18 | --objects365 19 | --v3det 20 | --openimages 21 | ``` 22 | 23 | 2. Mask Token Preparation 24 | As the SAM (Segment Anything Model) has been set to a frozen state, we've optimized our resource usage by pre-extracting the image mask tokens. 
This step significantly reduces memory consumption during model training and inference. We provide the pre-extracted mask tokens for download: 25 | [Download Mask Tokens from OneDrive](https://1drv.ms/f/s!AgWqwlwga-5Ka9-HT1L83INBHsU?e=wTbJz5) 26 | Organize the downloaded tokens as follows: 27 | 28 | ```bash 29 | ${ROOT} 30 | -- datasets 31 | -- datasets_mask_tokens_vit_b 32 | --objects365 33 | --v3det 34 | --openimages 35 | 36 | ``` 37 | ### Evaluation 38 | For model evaluation, download the LVIS dataset (COCO images from [COCO](https://cocodataset.org/#home), annotations from the [LVIS Dataset](https://www.lvisdataset.org/)) and place it in the `datasets` folder at the project root: 39 | ``` 40 | ${ROOT} 41 | -- datasets 42 | --coco 43 | --lvis 44 | ``` 45 | After downloading the LVIS dataset, also obtain the bounding box results from GLIP by downloading the provided JSON file: 46 | 47 | - Download the file from [GLIP Box Results](https://1drv.ms/u/s!AgWqwlwga-5KdWacuP6dTKajYRg?e=PIBdYd). 48 | 49 | Once downloaded, place the JSON file in the `glip_results` directory within `datasets`: 50 | ``` 51 | ${ROOT} 52 | -- datasets 53 | --glip_results 54 | nms_results_glip_tiny_model_o365_goldg_cc_sbu_lvis_val.json 55 | ``` 56 |
-------------------------------------------------------------------------------- /GETTING_STARTED.md: --------------------------------------------------------------------------------
1 | ## Getting Started with RegionSpot 2 | 3 | 4 | ### Installation 5 | 6 | The codebase is built on top of [Detectron2](https://github.com/facebookresearch/detectron2). 7 | 8 | #### Requirements 9 | - **Operating System**: Linux or macOS 10 | - **Python**: Version 3.6 or newer 11 | - **PyTorch**: Version 1.9.0 or newer, along with a compatible version of [torchvision](https://github.com/pytorch/vision/). You can install both from [pytorch.org](https://pytorch.org). 12 | 13 | 14 | #### Steps 15 | 1. **Detectron2 Installation**: 16 | Install Detectron2 by following the official installation guide: 17 | [Detectron2 Installation Guide](https://github.com/facebookresearch/detectron2/blob/main/INSTALL.md#installation). 18 | 2. **CLIP Installation**: 19 | 20 | Install CLIP by following the official installation guide: 21 | [CLIP Installation](https://github.com/openai/CLIP). 22 | 3. **Data Preparation**: 23 | Organize your data according to the instructions in [DATA.md](./DATA.md). 24 | 25 | 4. **Model Training**: 26 | To train the RegionSpot model, use the following command templates: 27 | 28 | ```bash 29 | # Stage 1 Training: 30 | python3 train_net.py --num-gpus 8 \ 31 | --config-file configs/objects365_bl.yaml 32 | 33 | # Stage 2 Training: 34 | python3 train_net.py --num-gpus 8 \ 35 | --config-file configs/objects365_v3det_openimages_bl.yaml 36 | ``` 37 | 5. **Model Evaluation**: 38 | To evaluate the trained RegionSpot model, use the following command. Ensure that `MODEL.CLIP_TYPE` and `MODEL.CLIP_INPUT_SIZE` correspond to the particular `MODEL.WEIGHTS` you are using; for example, checkpoints trained with the `*_bl_336` configs require `MODEL.CLIP_TYPE CLIP_400M_Large_336` and `MODEL.CLIP_INPUT_SIZE 336`, while the `*_bb` configs use `CLIP_400M` at 224.
39 | 40 | ```bash 41 | python3 train_net.py --num-gpus 8 \ 42 | --config-file configs/eval.yaml \ 43 | --eval-only \ 44 | MODEL.WEIGHTS /path/to/model_weights.pth \ 45 | MODEL.CLIP_TYPE CLIP_400M_Large \ 46 | MODEL.CLIP_INPUT_SIZE 224 47 | ```
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Recognize Any Regions (NeurIPS 2024) 2 | 3 | ![teaser](assets/framework.jpg) 4 | 5 | 6 | 7 | > [**Recognize Any Regions**](https://arxiv.org/pdf/2311.01373.pdf) 8 | > Haosen Yang, Chuofan Ma, Bin Wen, Yi Jiang, Zehuan Yuan, Xiatian Zhu 9 | 10 | ## Updates 11 | - **`2023/11/7`**: Checkpoints are available on both [Google Drive](https://drive.google.com/drive/folders/1jPfdzsZTRd95xOX7YcSJkw1IuI_OhjY4?usp=sharing) and [OneDrive](https://onedrive.live.com/?id=4AEE6B205CC2AA05%21106&cid=4AEE6B205CC2AA05). 12 | - **`2023/11/6`**: Code is now available! 13 | 14 | ## Models 15 | Method | Box AP_rare | Box AP_all | Mask AP_rare | Mask AP_all | Download 16 | --- |:---:|:---:|:---:|:---:|:---: 17 | RegionSpot-BB | 19.1 | 20.9 | 17.5 | 17.8 | [model](https://1drv.ms/u/s!AgWqwlwga-5Kc1WU0Q_iFsc_O-w?e=DgN1xI) 18 | RegionSpot-BL | 26.0 | 23.7 | 22.8 | 20.2 | [model](https://1drv.ms/u/s!AgWqwlwga-5KdKXfWxTuronI6ts?e=aGhZxj) 19 | RegionSpot-BL@336px | 26.3 | 25.0 | 23.4 | 21.3 | [model](https://1drv.ms/u/s!AgWqwlwga-5KcrrUYKtvTFf4MGY?e=QHZd7u) 20 | 21 | 22 | 23 | ## Getting Started 24 | 25 | Installation instructions and usage are described in [Getting Started with Recognize Any Regions](GETTING_STARTED.md). 26 | 27 | ## Demo 28 | 29 | First download a model checkpoint. The model can then be used in just a few lines to get masks and region labels for given box prompts: 30 | 31 | ```python 32 | from regionspot.modeling.regionspot import build_regionspot_model 33 | from regionspot import RegionSpot_Predictor 34 | custom_vocabulary = ['<your class names>'] 35 | clip_type = '<CLIP type matching the checkpoint, e.g. CLIP_400M_Large>' 36 | model, _ = build_regionspot_model(is_training=False, image_size=<clip_input_size>, clip_type=clip_type, pretrain_ckpt='<path/to/checkpoint>', custom_vocabulary=custom_vocabulary) 37 | predictor = RegionSpot_Predictor(model.cuda()) 38 | predictor.set_image(<image>, clip_input_size=<clip_input_size>) 39 | masks, mask_iou_score, class_score, class_index = predictor.predict(point_coords=None, point_labels=None, box=<box_prompt>, multimask_output=False) 40 | ``` 41 | 42 | See demo.py for a complete example of using RegionSpot with box prompts. 43 | ![teaser](assets/results.jpg) 44 | 45 | 46 | ## Citing Recognize Any Regions 47 | 48 | If you use Recognize Any Regions in your research or wish to refer to the baseline results published here, please use the following BibTeX entry.
49 | 50 | ```BibTeX 51 | @inproceedings{yang2023recognize, 52 | title={Recognize any regions}, 53 | author={Yang, Haosen and Ma, Chuofan and Wen, Bin and Jiang, Yi and Yuan, Zehuan and Zhu, Xiatian}, 54 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 55 | year={2024} 56 | } 57 | ``` 58 |
-------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/assets/framework.jpg
-------------------------------------------------------------------------------- /assets/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/assets/image.jpg
-------------------------------------------------------------------------------- /assets/results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/assets/results.jpg
-------------------------------------------------------------------------------- /configs/Base-RegionSpot.yaml: --------------------------------------------------------------------------------
1 | MODEL: 2 | META_ARCHITECTURE: "RegionSpot" 3 | PIXEL_MEAN: [123.675, 116.280, 103.530] 4 | PIXEL_STD: [58.395, 57.120, 57.375] 5 | 6 | SOLVER: 7 | IMS_PER_BATCH: 16 8 | BASE_LR: 0.000025 9 | CHECKPOINT_PERIOD: 50000 10 | STEPS: (210000, 250000) 11 | MAX_ITER: 270000 12 | WARMUP_FACTOR: 0.01 13 | WARMUP_ITERS: 1000 14 | WEIGHT_DECAY: 0.0001 15 | OPTIMIZER: "ADAMW" 16 | BACKBONE_MULTIPLIER: 1.0 # keep the same as BASE_LR.
17 | CLIP_GRADIENTS: 18 | ENABLED: True 19 | CLIP_TYPE: "full_model" 20 | CLIP_VALUE: 1.0 21 | NORM_TYPE: 2.0 22 | SEED: 40244023 23 | INPUT: 24 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 25 | CROP: 26 | ENABLED: False 27 | TYPE: "absolute_range" 28 | SIZE: (384, 600) 29 | FORMAT: "RGB" 30 | TEST: 31 | EVAL_PERIOD: 733000000 32 | DATALOADER: 33 | FILTER_EMPTY_ANNOTATIONS: False 34 | NUM_WORKERS: 4 35 | VERSION: 2 36 | 37 | -------------------------------------------------------------------------------- /configs/eval.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | CLIP_TYPE: CLIP_400M_Large 4 | TRAINING: False 5 | BOX_TYPE: 'PRED_BOX' 6 | MASK_ON: True 7 | DATASETS: # LVIS 8 | TRAIN: ("lvis_v1_train",) 9 | TEST: ("lvis_v1_val",) 10 | DATALOADER: 11 | SAMPLER_TRAIN: "RepeatFactorTrainingSampler" 12 | REPEAT_THRESHOLD: 0.001 13 | INPUT: 14 | CROP: 15 | ENABLED: True 16 | FORMAT: "RGB" 17 | TEST: # LVIS 18 | EVAL_PERIOD: 0 # disable eval during train since long time 19 | 20 | OUTPUT_DIR: './output/eval' 21 | 22 | -------------------------------------------------------------------------------- /configs/objects365_bb.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | CLIP_TYPE: CLIP_400M 4 | DATALOADER: 5 | SAMPLER_TRAIN: "MultiDatasetSampler" 6 | DATASETS: 7 | TRAIN: ("objects365_train",) 8 | TEST: () 9 | TEST: 10 | EVAL_PERIOD: 0 11 | SOLVER: 12 | STEPS: (350000, 420000) 13 | MAX_ITER: 450000 14 | INPUT: 15 | CROP: 16 | ENABLED: True 17 | FORMAT: "RGB" 18 | OUTPUT_DIR: './output/regionspot_obj365_bb' 19 | -------------------------------------------------------------------------------- /configs/objects365_bl.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | CLIP_TYPE: CLIP_400M_Large 4 | DATALOADER: 5 | SAMPLER_TRAIN: "MultiDatasetSampler" 6 | DATASETS: 7 | TRAIN: ("objects365_train",) 8 | TEST: () 9 | TEST: 10 | EVAL_PERIOD: 0 11 | SOLVER: 12 | STEPS: (350000, 420000) 13 | MAX_ITER: 450000 14 | INPUT: 15 | CROP: 16 | ENABLED: True 17 | FORMAT: "RGB" 18 | OUTPUT_DIR: './output/regionspot_obj365_bl' 19 | -------------------------------------------------------------------------------- /configs/objects365_bl_336.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | CLIP_TYPE: CLIP_400M_Large_336 4 | CLIP_INPUT_SIZE: 336 5 | DATALOADER: 6 | SAMPLER_TRAIN: "MultiDatasetSampler" 7 | DATASETS: 8 | TRAIN: ("objects365_train",) 9 | TEST: () 10 | TEST: 11 | EVAL_PERIOD: 0 12 | SOLVER: 13 | STEPS: (350000, 420000) 14 | MAX_ITER: 450000 15 | INPUT: 16 | CROP: 17 | ENABLED: True 18 | FORMAT: "RGB" 19 | OUTPUT_DIR: './output/regionspot_obj365_bl_336' 20 | -------------------------------------------------------------------------------- /configs/objects365_v3det_openimages_bb.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | WEIGHTS: "./output/regionspot_obj365_bb/model_final.pth" 4 | CLIP_TYPE: CLIP_400M 5 | DATALOADER: 6 | SAMPLER_TRAIN: "MultiDatasetSampler" 7 | DATASETS: 8 | TRAIN: ("objects365_train", "v3det_train","openimages_train",) 9 | TEST: () 10 | TEST: 11 | EVAL_PERIOD: 0 12 | SOLVER: 13 | STEPS: (350000, 420000) 14 | MAX_ITER: 450000 15 | 
INPUT: 16 | CROP: 17 | ENABLED: True 18 | FORMAT: "RGB" 19 | OUTPUT_DIR: './output/regionspot_alldata_bb' 20 | -------------------------------------------------------------------------------- /configs/objects365_v3det_openimages_bl.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | WEIGHTS: "./output/regionspot_obj365_bl/model_final.pth" 4 | CLIP_TYPE: CLIP_400M_Large 5 | DATALOADER: 6 | SAMPLER_TRAIN: "MultiDatasetSampler" 7 | DATASETS: 8 | TRAIN: ("objects365_train", "v3det_train","openimages_train",) 9 | TEST: () 10 | TEST: 11 | EVAL_PERIOD: 0 12 | SOLVER: 13 | STEPS: (350000, 420000) 14 | MAX_ITER: 450000 15 | INPUT: 16 | CROP: 17 | ENABLED: True 18 | FORMAT: "RGB" 19 | OUTPUT_DIR: './output/regionspot_alldata_bl' 20 | -------------------------------------------------------------------------------- /configs/objects365_v3det_openimages_bl_336.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RegionSpot.yaml" 2 | MODEL: 3 | WEIGHTS: "./output/regionspot_obj365_bl_336/model_final.pth" 4 | CLIP_TYPE: CLIP_400M_Large_336 5 | CLIP_INPUT_SIZE: 336 6 | DATALOADER: 7 | SAMPLER_TRAIN: "MultiDatasetSampler" 8 | DATASETS: 9 | TRAIN: ("objects365_train", "v3det_train","openimages_train",) 10 | TEST: () 11 | TEST: 12 | EVAL_PERIOD: 0 13 | SOLVER: 14 | STEPS: (350000, 420000) 15 | MAX_ITER: 450000 16 | INPUT: 17 | CROP: 18 | ENABLED: True 19 | FORMAT: "RGB" 20 | OUTPUT_DIR: './output/regionspot_alldata_bl_336' 21 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import matplotlib.pyplot as plt 4 | from regionspot.modeling.regionspot import build_regionspot_model 5 | from regionspot import RegionSpot_Predictor 6 | # Function to show masks on an image 7 | def show_mask(mask, ax, random_color=False): 8 | if random_color: 9 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 10 | else: 11 | color = np.array([30/255, 144/255, 255/255, 0.6]) 12 | h, w = mask.shape[-2:] 13 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 14 | ax.imshow(mask_image) 15 | 16 | # Function to show points on an image 17 | def show_points(coords, labels, ax, marker_size=375): 18 | pos_points = coords[labels == 1] 19 | neg_points = coords[labels == 0] 20 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 21 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 22 | 23 | # Function to show bounding boxes on an image 24 | def show_box(box, ax): 25 | x0, y0 = box[0], box[1] 26 | w, h = box[2] - x0, box[3] - y0 27 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor='none', lw=2)) 28 | 29 | # Read image and set up model 30 | image = cv2.imread('assets/image.jpg') 31 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert image to RGB format 32 | # Multiple boxes 33 | box_prompt = np.array([[64, 926, 804, 1978], [1237, 490, 1615, 771.], [1510, 64, 1670, 167]]) 34 | ckpt_path = '/path/to/model_weights.pth' 35 | clip_type = 'CLIP_400M_Large_336' 36 | clip_input_size = 336 37 | custom_vocabulary = ["Smoothie bowl", "Banana", "Strawberry", "Chia seeds", "Shredded coconut", "Wooden spoons", "Grapefruit", "Goji berries", "Flaxseeds seeds"] 
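# ckpt_path above is a placeholder; point it at a downloaded RegionSpot checkpoint and keep
# clip_type / clip_input_size consistent with that checkpoint (the BL@336px model uses
# CLIP_400M_Large_336 at 336, while the base models use 224).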
38 | 39 | # Build and initialize the model 40 | model, msg = build_regionspot_model(is_training=False, image_size=clip_input_size, clip_type=clip_type, pretrain_ckpt=ckpt_path, custom_vocabulary=custom_vocabulary) 41 | 42 | # Create predictor and set image 43 | predictor = RegionSpot_Predictor(model.cuda()) 44 | predictor.set_image(image, clip_input_size=clip_input_size) 45 | 46 | # Prediction based on box prompt 47 | masks, mask_iou_score, class_score, class_index = predictor.predict( 48 | point_coords=None, 49 | point_labels=None, 50 | box=box_prompt, 51 | multimask_output=False, 52 | ) 53 | # Extract class name and display image with masks and box 54 | fig, ax = plt.subplots(figsize=(10, 10)) 55 | ax.imshow(image) 56 | for idx in range(len(box_prompt)): 57 | show_mask(masks[idx], ax) 58 | show_box(box_prompt[idx], ax) # Assuming box_prompt contains all your boxes 59 | # You might want to modify the text display for multiple classes as well 60 | class_name = custom_vocabulary[int(class_index[idx])] 61 | ax.text(box_prompt[idx][0], box_prompt[idx][1] - 10, class_name, color='white', fontsize=14, bbox=dict(facecolor='green', edgecolor='green', alpha=0.6)) 62 | 63 | ax.axis('off') 64 | plt.show() 65 | fig.savefig('result.png') 66 | plt.close(fig) 67 | 68 | -------------------------------------------------------------------------------- /regionspot/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import add_regionspot_config 2 | from .detector import RegionSpot 3 | from .data.dataset_mapper import RegionSpotDatasetMapper 4 | from .test_time_augmentation import RegionSpotWithTTA 5 | from .build import * 6 | from .data.custom_dataset_dataloader import * 7 | from .predictor import RegionSpot_Predictor 8 | 9 | 10 | -------------------------------------------------------------------------------- /regionspot/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | 4 | def add_regionspot_config(cfg): 5 | """ 6 | Add config for RegionSpot 7 | """ 8 | cfg.MODEL.RegionSpot = CN() 9 | cfg.MODEL.CLIP_TYPE = 'CLIP_400M_Large' 10 | cfg.MODEL.CLIP_INPUT_SIZE = 224 11 | # Inference 12 | cfg.MODEL.TRAINING = True 13 | cfg.MODEL.BOX_TYPE = 'GT' 14 | 15 | #Dataloder 16 | cfg.DATALOADER.DATASET_RATIO = [1,1,1] # sample ratio 17 | cfg.DATALOADER.USE_RFS = [False, False, False] 18 | cfg.DATALOADER.MULTI_DATASET_GROUPING = True # Always true when multi-dataset is enabled 19 | cfg.DATALOADER.DATASET_ANN = ['box', 'box', 'box'] # Annotation type of each dataset 20 | cfg.DATALOADER.USE_DIFF_BS_SIZE = False # Use different batchsize for each dataset 21 | cfg.DATALOADER.DATASET_BS = [8, 32] # Used when USE_DIFF_BS_SIZE is on 22 | 23 | 24 | 25 | # Optimizer. 26 | cfg.SOLVER.OPTIMIZER = "ADAMW" 27 | cfg.SOLVER.BACKBONE_MULTIPLIER = 1.0 28 | 29 | # TTA. 
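    # Multi-scale test-time augmentation settings, consumed when inference is run
    # through RegionSpotWithTTA (see test_time_augmentation.py).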
30 | cfg.TEST.AUG.MIN_SIZES = (400, 500, 600, 640, 700, 900, 1000, 1100, 1200, 1300, 1400, 1800, 800) 31 | cfg.TEST.AUG.CVPODS_TTA = True 32 | cfg.TEST.AUG.SCALE_FILTER = True 33 | cfg.TEST.AUG.SCALE_RANGES = ([96, 10000], [96, 10000], 34 | [64, 10000], [64, 10000], 35 | [64, 10000], [0, 10000], 36 | [0, 10000], [0, 256], 37 | [0, 256], [0, 192], 38 | [0, 192], [0, 96], 39 | [0, 10000]) 40 | -------------------------------------------------------------------------------- /regionspot/data/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import numpy as np 4 | import torch 5 | import os 6 | from detectron2.data import detection_utils as utils 7 | from detectron2.data import transforms as T 8 | 9 | 10 | __all__ = ["RegionSpotDatasetMapper"] 11 | 12 | 13 | def build_transform_gen(cfg, is_train): 14 | """ 15 | Create a list of :class:`TransformGen` from config. 16 | Returns: 17 | list[TransformGen] 18 | """ 19 | if is_train: 20 | min_size = cfg.INPUT.MIN_SIZE_TRAIN 21 | max_size = cfg.INPUT.MAX_SIZE_TRAIN 22 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 23 | else: 24 | min_size = cfg.INPUT.MIN_SIZE_TEST 25 | max_size = cfg.INPUT.MAX_SIZE_TEST 26 | sample_style = "choice" 27 | if sample_style == "range": 28 | assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) 29 | 30 | logger = logging.getLogger(__name__) 31 | tfm_gens = [] 32 | if is_train: 33 | tfm_gens.append(T.RandomFlip()) 34 | # ResizeShortestEdge 35 | tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 36 | 37 | if is_train: 38 | logger.info("TransformGens used in training: " + str(tfm_gens)) 39 | return tfm_gens 40 | 41 | 42 | class RegionSpotDatasetMapper: 43 | """ 44 | A callable which takes a dataset dict in Detectron2 Dataset format, 45 | and map it into a format used by RegionSpot. 46 | 47 | The callable currently does the following: 48 | 49 | 1. Read the image from "file_name" 50 | 2. Applies geometric transforms to the image and annotation 51 | 3. Find and applies suitable cropping to the image and annotation 52 | 4. Prepare image and annotation to Tensors 53 | """ 54 | 55 | def __init__(self, cfg, is_train=True): 56 | if cfg.INPUT.CROP.ENABLED and is_train: 57 | self.crop_gen = [ 58 | T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), 59 | T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), 60 | ] 61 | else: 62 | self.crop_gen = None 63 | 64 | self.tfm_gens = build_transform_gen(cfg, is_train) 65 | logging.getLogger(__name__).info( 66 | "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen)) 67 | ) 68 | 69 | self.img_format = cfg.INPUT.FORMAT 70 | self.is_train = is_train 71 | # if self.is_train: 72 | # for dataset_name in cfg.DATASETS.TRAIN: 73 | # if dataset_name.startswith("coco"): 74 | self.mask_tokens_dir = os.path.join('./datasets/datasets_mask_tokens_vit_b/') 75 | 76 | def __call__(self, dataset_dict): 77 | """ 78 | Args: 79 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
80 | 81 | Returns: 82 | dict: a format that builtin models in detectron2 accept 83 | """ 84 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 85 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 86 | # utils.check_image_size(dataset_dict, image) 87 | # 88 | #get mask token and responsed label 89 | image_id = dataset_dict["image_id"] 90 | dataset_name = dataset_dict["file_name"].split('/')[1] 91 | #datasets/coco/train2017/000000566174.jpg 92 | #read pth 93 | pth_file = os.path.join(self.mask_tokens_dir, os.path.join(dataset_name, str(image_id)+'.pth')) 94 | offline_token = torch.load(pth_file) 95 | # 96 | if self.crop_gen is None: 97 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 98 | else: 99 | if np.random.rand() > 0.5: 100 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 101 | else: 102 | image, transforms = T.apply_transform_gens( 103 | self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image 104 | ) 105 | 106 | image_shape = image.shape[:2] # h, w 107 | 108 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 109 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 110 | # Therefore it's important to use torch.Tensor. 111 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 112 | dataset_dict["dataset_name"] = dataset_name 113 | dataset_dict["extra_info"] = offline_token 114 | if not self.is_train: 115 | # USER: Modify this if you want to keep them for some reason. 116 | dataset_dict.pop("annotations", None) 117 | return dataset_dict 118 | 119 | if "annotations" in dataset_dict: 120 | # USER: Modify this if you want to keep them for some reason. 121 | for anno in dataset_dict["annotations"]: 122 | anno.pop("segmentation", None) 123 | anno.pop("keypoints", None) 124 | 125 | # USER: Implement additional transformations if you have other types of data 126 | annos = [ 127 | utils.transform_instance_annotations(obj, transforms, image_shape) 128 | for obj in dataset_dict.pop("annotations") 129 | if obj.get("iscrowd", 0) == 0 130 | ] 131 | instances = utils.annotations_to_instances(annos, image_shape) 132 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 133 | return dataset_dict 134 | -------------------------------------------------------------------------------- /regionspot/data/openimages.py: -------------------------------------------------------------------------------- 1 | from detectron2.data.datasets.register_coco import register_coco_instances 2 | import os 3 | from .openimages_categories import categories 4 | 5 | def _get_builtin_metadata(categories): 6 | id_to_name = {x['id']: x['name'] for x in categories} 7 | thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))} 8 | thing_classes = [id_to_name[k] for k in sorted(id_to_name)] 9 | 10 | return { 11 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 12 | "thing_classes": thing_classes} 13 | 14 | def _get_builtin_metadata(): 15 | id_to_name = {x['id']: x['name'] for x in categories} 16 | thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))} 17 | thing_classes = [id_to_name[k] for k in sorted(id_to_name)] 18 | return { 19 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 20 | "thing_classes": thing_classes} 21 | 22 | 23 | _PREDEFINED_SPLITS_OPENIMAGES = { 24 | "openimages_train": ("openimages/detection/", 
"re_openimages_v6_train_bbox_splitdir_int_ids.json"), 25 | "openimages_val": ("openimages/detection/", "re_openimages_v6_train_bbox_splitdir_int_ids.json"), 26 | } 27 | 28 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_OPENIMAGES.items(): 29 | register_coco_instances( 30 | key, 31 | _get_builtin_metadata(), 32 | os.path.join("datasets", json_file) if "://" not in json_file else json_file, 33 | os.path.join("datasets", image_root), 34 | ) -------------------------------------------------------------------------------- /regionspot/data/v3det.py: -------------------------------------------------------------------------------- 1 | from detectron2.data.datasets.register_coco import register_coco_instances 2 | import os 3 | 4 | from .v3det_categories import categories 5 | def _get_builtin_metadata(categories): 6 | id_to_name = {x['id']: x['name'] for x in categories} 7 | thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))} 8 | thing_classes = [id_to_name[k] for k in sorted(id_to_name)] 9 | 10 | return { 11 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 12 | "thing_classes": thing_classes} 13 | 14 | def _get_builtin_metadata(): 15 | id_to_name = {x['id']: x['name'] for x in categories} 16 | thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(categories))} 17 | thing_classes = [id_to_name[k] for k in sorted(id_to_name)] 18 | return { 19 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 20 | "thing_classes": thing_classes} 21 | 22 | 23 | _PREDEFINED_SPLITS_V3DET = { 24 | "v3det_train": ("v3det/V3Det/", "v3det/v3det_2023_v1_train.json"), 25 | "v3det_val": ("v3det/V3Det/", "v3det/v3det_2023_v1_val.json"), 26 | } 27 | 28 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_V3DET.items(): 29 | register_coco_instances( 30 | key, 31 | _get_builtin_metadata(), 32 | os.path.join("datasets", json_file) if "://" not in json_file else json_file, 33 | os.path.join("datasets", image_root), 34 | ) -------------------------------------------------------------------------------- /regionspot/detector.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from .modeling.regionspot import build_regionspot_model 3 | import torch.cuda.amp as amp 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from einops import rearrange 10 | import json 11 | from detectron2.modeling import META_ARCH_REGISTRY 12 | from .util.postprocessing import segmentation_postprocess 13 | 14 | from detectron2.structures import Boxes, Instances 15 | from .util.preprocessing import prepare_prompt_infer, prepare_prompt_train 16 | 17 | __all__ = ["RegionSpot"] 18 | 19 | 20 | 21 | @META_ARCH_REGISTRY.register() 22 | class RegionSpot(nn.Module): 23 | """ 24 | Implement RegionSpot 25 | """ 26 | def __init__(self, cfg): 27 | super().__init__() 28 | self.device = torch.device(cfg.MODEL.DEVICE) 29 | self.clip_type = cfg.MODEL.CLIP_TYPE 30 | self.inference_box_type = cfg.MODEL.BOX_TYPE 31 | self.clip_input_size = cfg.MODEL.CLIP_INPUT_SIZE 32 | self.clip_target_size = (self.clip_input_size, self.clip_input_size) 33 | self.model, _ = build_regionspot_model(clip_type = self.clip_type, is_training=cfg.MODEL.TRAINING, image_size=self.clip_input_size) 34 | self.model.to(self.device) 35 | if self.inference_box_type != 'GT': 36 | path = 
'./datasets/glip_results/nms_results_glip_tiny_model_o365_goldg_cc_sbu_lvis_val.json' 37 | with open(path, 'r') as file: 38 | self.pred_results = json.load(file) 39 | else: 40 | self.pred_results = None 41 | 42 | @torch.no_grad() 43 | def foward_inference(self, batched_inputs, do_postprocess=True): 44 | with amp.autocast(enabled=True): 45 | with torch.no_grad(): 46 | logits_per_image, pred_mask = self.model.forward_eval(batched_inputs, multimask_output=False) 47 | 48 | image_sizes = [x["original_size"] for x in batched_inputs] 49 | if self.inference_box_type == 'GT': 50 | boxes = torch.stack([x["instances"].gt_boxes.tensor for x in batched_inputs], dim=0) #n, n_box, n_token, 256 51 | if len(boxes[0]) == 0: 52 | boxes=torch.tensor([[[0,0, image_sizes[0][0], image_sizes[0][1]]]]) 53 | else: 54 | boxes = torch.stack([x["pred_boxes"] for x in batched_inputs], dim=0) #n, n_box, n_token, 256 55 | scores = torch.stack([x["scores"] for x in batched_inputs], dim=0) 56 | 57 | 58 | box_cls = logits_per_image 59 | box_pred = boxes 60 | if self.inference_box_type == 'GT': 61 | results = self.inference_gt_box(box_cls, box_pred, pred_mask, image_sizes) 62 | else: 63 | results = self.inference_pred_box(box_cls, box_pred, scores, pred_mask, image_sizes) 64 | if do_postprocess: 65 | processed_results = [] 66 | for results_per_image, input_per_image, image_size in zip(results, batched_inputs, image_sizes): 67 | height = input_per_image.get("height", image_size[0]) 68 | width = input_per_image.get("width", image_size[1]) 69 | r = segmentation_postprocess(results_per_image, height, width) 70 | processed_results.append({"instances": r}) 71 | return processed_results 72 | else: 73 | return results 74 | 75 | def foward_train(self, batched_inputs): 76 | with amp.autocast(enabled=True): 77 | outputs = self.model.forward_train(batched_inputs) 78 | loss = {'loss': outputs} 79 | return loss 80 | 81 | def forward(self, batched_inputs, do_postprocess=True): 82 | if not self.training: 83 | # Prepare Prompt. 
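            # pred_results holds the GLIP detection boxes loaded in __init__ (None when BOX_TYPE is 'GT'),
            # and target_size matches the CLIP input resolution set by MODEL.CLIP_INPUT_SIZE.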
84 | batched_inputs = prepare_prompt_infer(batched_inputs, pred_results = self.pred_results, target_size=self.clip_target_size) 85 | 86 | results = self.foward_inference(batched_inputs) 87 | return results 88 | 89 | if self.training: 90 | batched_inputs = prepare_prompt_train(batched_inputs, target_size=self.clip_target_size) 91 | loss_dict = self.foward_train(batched_inputs) 92 | return loss_dict 93 | 94 | 95 | 96 | def inference_gt_box(self, box_cls, box_pred, pred_mask, image_sizes=None): 97 | 98 | device = box_cls.device # assuming all tensors are on the same device 99 | results = [] 100 | 101 | for logits, boxes, masks, img_size in zip(box_cls, box_pred, pred_mask, image_sizes): 102 | # Calculate probabilities and flatten them 103 | probs = F.softmax(logits, dim=-1) 104 | probs_flattened = probs.view(-1) 105 | 106 | # Determine number of top predictions to consider 107 | top_num = min(900, len(probs_flattened)) 108 | 109 | # Get top-k values and indices 110 | topk_probs, topk_indices = torch.topk(probs_flattened, top_num) 111 | 112 | # Decode the top-k indices to get corresponding labels and boxes 113 | topk_labels = topk_indices % logits.shape[1] 114 | topk_boxes_indices = topk_indices // logits.shape[1] 115 | 116 | # Ensure boxes, masks and topk_boxes_indices are on the same device 117 | topk_boxes_indices = topk_boxes_indices.to(device) 118 | boxes = boxes.to(device) 119 | masks = masks.to(device) 120 | 121 | # Retrieve predictions using the top-k indices 122 | boxes_for_topk = boxes[topk_boxes_indices] 123 | masks_for_topk = masks[topk_boxes_indices] 124 | scores_for_topk = topk_probs # Modify accordingly if you have another score tensor 125 | # Create Instances object for top-k predictions 126 | result = Instances(img_size) 127 | result.pred_boxes = Boxes(boxes_for_topk) 128 | result.scores = scores_for_topk 129 | result.pred_classes = topk_labels 130 | result.pred_masks = masks_for_topk # Added masks to the result 131 | results.append(result) 132 | 133 | return results 134 | 135 | def inference_pred_box(self, box_cls, box_pred, box_score, masks, image_sizes=None): 136 | 137 | results = [] 138 | 139 | for i, (logits, box_pred_i, box_score_i, mask_i, img_size) in enumerate(zip(box_cls, box_pred, box_score, masks, image_sizes)): 140 | 141 | logits = logits.cuda() 142 | box_pred_i = box_pred_i.cuda() 143 | box_score_i = box_score_i.cuda() 144 | 145 | # Calculate probabilities and flatten them 146 | probs = F.softmax(logits, dim=-1) 147 | probs_flattened = probs.view(-1) 148 | 149 | # Determine number of top predictions to consider 150 | top_num = min(900, len(probs_flattened)) 151 | 152 | # Get top-k values and indices 153 | topk_probs, topk_indices = torch.topk(probs_flattened, top_num) 154 | 155 | # Decode the top-k indices to get corresponding labels and boxes 156 | topk_labels = topk_indices % logits.shape[1] 157 | topk_boxes_indices = topk_indices // logits.shape[1] 158 | 159 | # Retrieve predictions using the top-k indices 160 | boxes = box_pred_i[topk_boxes_indices] 161 | masks = mask_i[topk_boxes_indices] 162 | scores = box_score_i[topk_boxes_indices] * topk_probs 163 | 164 | # Construct result for the current image 165 | result = Instances(img_size) 166 | result.pred_boxes = Boxes(boxes) 167 | result.scores = scores 168 | result.pred_classes = topk_labels 169 | result.pred_masks = masks 170 | results.append(result) 171 | 172 | return results 173 | 174 | 175 | -------------------------------------------------------------------------------- 
/regionspot/modeling/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /regionspot/modeling/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | import torch.distributed as dist 14 | # from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if torch.__version__.split(".") < ["1", "7", "1"]: 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | # _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip"), sha_check=True): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if (not sha_check) or (sha_check and hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256): 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | 
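                # write the chunk just read and advance the progress bar by its byte count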
output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if sha_check and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | lambda image: image.convert("RGB"), 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False, image_size=None, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | Returns 105 | ------- 106 | model : torch.nn.Module 107 | The CLIP model 108 | 109 | preprocess : Callable[[PIL.Image], torch.Tensor] 110 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 111 | """ 112 | # if name in _MODELS: 113 | # if ((dist.is_initialized() or dist.is_available()) and int(dist.get_rank()) % torch.cuda.device_count() == 0) or not dist.is_available(): 114 | # model_path = _download(_MODELS[name]) 115 | # dist.barrier() 116 | # model_path = _download(_MODELS[name]) 117 | # elif os.path.isfile(name): 118 | # model_path = name 119 | # else: 120 | # raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | if name in _MODELS: 122 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 123 | elif os.path.isfile(name): 124 | model_path = name 125 | else: 126 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 127 | 128 | with open(model_path, 'rb') as opened_file: 129 | try: 130 | # loading JIT archive 131 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 132 | state_dict = None 133 | except RuntimeError: 134 | # loading saved state dict 135 | if jit: 136 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 137 | jit = False 138 | state_dict = torch.load(opened_file, map_location="cpu") 139 | 140 | if not jit: 141 | model = build_model(state_dict or model.state_dict(), image_size).to(device) 142 | if str(device) == "cpu": 143 | model.float() 144 | return model, _transform(model.visual.input_resolution) 145 | 146 | # patch the device names 147 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 148 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 149 | 150 | def patch_device(module): 151 | try: 152 | graphs = [module.graph] if hasattr(module, "graph") else [] 153 | except RuntimeError: 154 | graphs = [] 155 | 156 | if hasattr(module, "forward1"): 157 | graphs.append(module.forward1.graph) 158 | 159 | for graph in graphs: 160 | for node in graph.findAllNodes("prim::Constant"): 161 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 162 | node.copyAttributes(device_node) 163 | 164 | model.apply(patch_device) 165 | patch_device(model.encode_image) 166 | patch_device(model.encode_text) 167 | 168 | # patch dtype to float32 on CPU 169 | if str(device) == "cpu": 170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 172 | float_node = float_input.node() 173 | 174 | def patch_float(module): 175 | try: 176 | graphs = [module.graph] if hasattr(module, "graph") else [] 177 | except RuntimeError: 178 | graphs = [] 179 | 180 | if hasattr(module, "forward1"): 181 | graphs.append(module.forward1.graph) 182 | 183 | for graph in graphs: 184 | for node in graph.findAllNodes("aten::to"): 185 | inputs = list(node.inputs()) 186 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 187 | if inputs[i].node()["value"] == 5: 188 | inputs[i].node().copyAttributes(float_node) 189 | 190 | model.apply(patch_float) 191 | patch_float(model.encode_image) 192 | patch_float(model.encode_text) 193 | 194 | model.float() 195 | 196 | return model, _transform(model.input_resolution.item()) 197 | 198 | 199 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 200 | """ 201 | Returns the tokenized representation of given input string(s) 202 | 203 | Parameters 204 | ---------- 205 | texts : Union[str, List[str]] 206 | An input string or a list of input strings to tokenize 207 | 208 | context_length : int 209 | The context length to use; all CLIP models use 77 as the context length 210 | 211 | truncate: bool 212 | Whether to truncate the text in case its encoding is longer than the context length 213 | 214 | Returns 215 | ------- 216 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 217 | """ 218 | if isinstance(texts, str): 219 | texts = [texts] 220 | 221 | sot_token = _tokenizer.encoder["<|startoftext|>"] 222 | eot_token = _tokenizer.encoder["<|endoftext|>"] 223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | 226 | for i, tokens in enumerate(all_tokens): 227 | if len(tokens) > context_length: 228 | if truncate: 229 | tokens = tokens[:context_length] 230 | tokens[-1] = eot_token 231 | else: 232 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 233 | 
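        # copy the token ids into row i; positions past len(tokens) keep their zero padding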
result[i, :len(tokens)] = torch.tensor(tokens) 234 | 235 | return result 236 | -------------------------------------------------------------------------------- /regionspot/modeling/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | 
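            # locate the next occurrence of `first` at or after position i; if it is absent,
            # copy the remainder of the word unchanged and stop scanning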
try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /regionspot/modeling/clip/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 
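    # note: np.float was removed in NumPy 1.24; np.float64 (or plain float) is the
    # drop-in replacement for the dtype above on newer NumPy versions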
| omega /= embed_dim / 2. 58 | omega = 1. / 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /regionspot/modeling/clip/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Modified from github.com/facebook/SLIP 8 | from collections import OrderedDict 9 | 10 | import torch 11 | from torch import nn 12 | 13 | from functools import partial 14 | from timm.models.layers import DropPath 15 | from timm.models.vision_transformer import PatchEmbed, Mlp 16 | 17 | class Attention(nn.Module): 18 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 19 | super().__init__() 20 | self.num_heads = num_heads 21 | head_dim = dim // num_heads 22 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 23 | self.scale = qk_scale or head_dim ** -0.5 24 | 25 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 26 | self.attn_drop = nn.Dropout(attn_drop) 27 | self.proj = nn.Linear(dim, dim) 28 | self.proj_drop = nn.Dropout(proj_drop) 29 | 30 | def forward(self, x): 31 | B, N, C = x.shape 32 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 33 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 34 | 35 | attn = (q @ k.transpose(-2, -1)) * self.scale 36 | attn = attn.softmax(dim=-1) 37 | attn = self.attn_drop(attn) 38 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 39 | x = self.proj(x) 40 | x = self.proj_drop(x) 41 | return x 42 | 43 | 44 | class Block(nn.Module): 45 | 46 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 47 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 48 | super().__init__() 49 | self.norm1 = norm_layer(dim) 50 | self.attn = Attention( 51 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 52 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 53 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 54 | self.norm2 = norm_layer(dim) 55 | mlp_hidden_dim = int(dim * mlp_ratio) 56 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 57 | 58 | def forward(self, x): 59 | attn_x = self.attn(self.norm1(x)) 60 | x = x + self.drop_path(attn_x) 61 | x = x + self.drop_path(self.mlp(self.norm2(x))) 62 | return x 63 | 64 | 65 | class VisionTransformer(nn.Module): 66 | """ Vision Transformer 67 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 68 | - https://arxiv.org/abs/2010.11929 69 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 70 | - https://arxiv.org/abs/2012.12877 71 | """ 72 | 73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 74 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 75 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 76 | act_layer=None, weight_init=''): 77 | """ 78 | Args: 79 | img_size (int, tuple): input image size 80 | patch_size (int, tuple): patch size 81 | in_chans (int): number of input channels 82 | num_classes (int): number of classes for classification head 83 | embed_dim (int): embedding dimension 84 | depth (int): depth of transformer 85 | num_heads (int): number of attention heads 86 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 87 | qkv_bias (bool): enable bias for qkv if True 88 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 89 | distilled (bool): model includes a distillation token and head as in DeiT models 90 | drop_rate (float): dropout rate 91 | attn_drop_rate (float): attention dropout rate 92 | drop_path_rate (float): stochastic depth rate 93 | embed_layer (nn.Module): patch embedding layer 94 | norm_layer: (nn.Module): normalization layer 95 | weight_init: (str): weight init scheme 96 | """ 97 | super().__init__() 98 | self.num_classes = num_classes 99 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 100 | self.num_tokens = 2 if distilled else 1 101 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 102 | act_layer = act_layer or nn.GELU 103 | 104 | self.patch_embed = embed_layer( 105 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 106 | num_patches = self.patch_embed.num_patches 107 | 108 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 109 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 111 | self.pos_drop = nn.Dropout(p=drop_rate) 112 | 113 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 114 | self.blocks = nn.Sequential(*[ 115 | Block( 116 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 117 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 118 | for i in range(depth)]) 119 | self.norm = norm_layer(embed_dim) 120 | 121 | # Representation layer 122 | if representation_size and not distilled: 123 | self.num_features = representation_size 124 | self.pre_logits = nn.Sequential(OrderedDict([ 125 | ('fc', nn.Linear(embed_dim, representation_size)), 126 | ('act', nn.Tanh()) 127 | ])) 128 | else: 129 | self.pre_logits = 
nn.Identity() 130 | 131 | # Classifier head(s) 132 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 133 | self.head_dist = None 134 | if distilled: 135 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 136 | 137 | @torch.jit.ignore 138 | def no_weight_decay(self): 139 | return {'pos_embed', 'cls_token', 'dist_token'} 140 | 141 | def get_classifier(self): 142 | if self.dist_token is None: 143 | return self.head 144 | else: 145 | return self.head, self.head_dist 146 | 147 | def reset_classifier(self, num_classes, global_pool=''): 148 | self.num_classes = num_classes 149 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 150 | if self.num_tokens == 2: 151 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 152 | 153 | def forward_featuremap(self, x): 154 | x = self.patch_embed(x) 155 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 156 | if self.dist_token is None: 157 | x = torch.cat((cls_token, x), dim=1) 158 | else: 159 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 160 | x = self.pos_drop(x + self.pos_embed) 161 | # apply Transformer blocks 162 | for blk_idx, blk in enumerate(self.blocks): 163 | x = blk(x) 164 | return x 165 | 166 | def forward_features(self, x): 167 | x = self.forward_featuremap(x) 168 | x = self.norm(x) 169 | if self.dist_token is None: 170 | return self.pre_logits(x[:, 0]) 171 | else: 172 | return x[:, 0], x[:, 1] 173 | 174 | def forward(self, x): 175 | x = self.forward_features(x) 176 | if self.head_dist is not None: 177 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 178 | if self.training and not torch.jit.is_scripting(): 179 | # during inference, return the average of both classifier predictions 180 | return x, x_dist 181 | else: 182 | return (x + x_dist) / 2 183 | else: 184 | x = self.head(x) 185 | return x 186 | -------------------------------------------------------------------------------- /regionspot/modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | 8 | def _get_activation_fn(activation): 9 | """Return an activation function given a string""" 10 | if activation == "relu": 11 | return F.relu 12 | if activation == "gelu": 13 | return F.gelu 14 | if activation == "glu": 15 | return F.glu 16 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 17 | 18 | def _get_clones(module, N): 19 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 20 | 21 | class TransformerDecoderLayer(nn.Module): 22 | 23 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 24 | activation="relu", normalize_before=False): 25 | super().__init__() 26 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 27 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True) 28 | # Implementation of Feedforward model 29 | self.linear1 = nn.Linear(d_model, dim_feedforward) 30 | self.dropout = nn.Dropout(dropout) 31 | self.linear2 = nn.Linear(dim_feedforward, d_model) 32 | 33 | self.norm1 = nn.LayerNorm(d_model) 34 | self.norm2 = nn.LayerNorm(d_model) 35 | self.norm3 = 
nn.LayerNorm(d_model) 36 | self.dropout1 = nn.Dropout(dropout) 37 | self.dropout2 = nn.Dropout(dropout) 38 | self.dropout3 = nn.Dropout(dropout) 39 | 40 | self.activation = _get_activation_fn(activation) 41 | self.normalize_before = normalize_before 42 | 43 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 44 | return tensor if pos is None else tensor + pos 45 | 46 | def forward_post(self, tgt, memory, 47 | tgt_mask: Optional[Tensor] = None, 48 | memory_mask: Optional[Tensor] = None, 49 | tgt_key_padding_mask: Optional[Tensor] = None, 50 | memory_key_padding_mask: Optional[Tensor] = None, 51 | pos: Optional[Tensor] = None, 52 | query_pos: Optional[Tensor] = None): 53 | q = k = self.with_pos_embed(tgt, query_pos) 54 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 55 | key_padding_mask=tgt_key_padding_mask)[0] 56 | 57 | tgt = tgt + self.dropout1(tgt2) 58 | tgt = self.norm1(tgt) 59 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 60 | key=self.with_pos_embed(memory, pos), 61 | value=memory, attn_mask=memory_mask, 62 | key_padding_mask=memory_key_padding_mask)[0] 63 | tgt = tgt + self.dropout2(tgt2) 64 | tgt = self.norm2(tgt) 65 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 66 | 67 | tgt = tgt + self.dropout3(tgt2) 68 | tgt = self.norm3(tgt) 69 | 70 | return tgt 71 | 72 | def forward_pre(self, tgt, memory, 73 | tgt_mask: Optional[Tensor] = None, 74 | memory_mask: Optional[Tensor] = None, 75 | tgt_key_padding_mask: Optional[Tensor] = None, 76 | memory_key_padding_mask: Optional[Tensor] = None, 77 | pos: Optional[Tensor] = None, 78 | query_pos: Optional[Tensor] = None): 79 | tgt2 = self.norm1(tgt) 80 | q = k = self.with_pos_embed(tgt2, query_pos) 81 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 82 | key_padding_mask=tgt_key_padding_mask)[0] 83 | tgt = tgt + self.dropout1(tgt2) 84 | tgt2 = self.norm2(tgt) 85 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 86 | key=self.with_pos_embed(memory, pos), 87 | value=memory, attn_mask=memory_mask, 88 | key_padding_mask=memory_key_padding_mask)[0] 89 | tgt = tgt + self.dropout2(tgt2) 90 | tgt2 = self.norm3(tgt) 91 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 92 | tgt = tgt + self.dropout3(tgt2) 93 | return tgt 94 | 95 | def forward(self, tgt, memory, 96 | tgt_mask: Optional[Tensor] = None, 97 | memory_mask: Optional[Tensor] = None, 98 | tgt_key_padding_mask: Optional[Tensor] = None, 99 | memory_key_padding_mask: Optional[Tensor] = None, 100 | pos: Optional[Tensor] = None, 101 | query_pos: Optional[Tensor] = None): 102 | if self.normalize_before: 103 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 104 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 105 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 106 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 107 | class TransformerDecoder(nn.Module): 108 | 109 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 110 | super().__init__() 111 | self.layers = _get_clones(decoder_layer, num_layers) 112 | self.num_layers = num_layers 113 | self.norm = norm 114 | self.return_intermediate = return_intermediate 115 | 116 | def forward(self, tgt, memory, 117 | tgt_mask: Optional[Tensor] = None, 118 | memory_mask: Optional[Tensor] = None, 119 | tgt_key_padding_mask: Optional[Tensor] = None, 120 | memory_key_padding_mask: Optional[Tensor] = None, 121 | pos: Optional[Tensor] = None, 122 | 
query_pos: Optional[Tensor] = None): 123 | output = tgt 124 | 125 | intermediate = [] 126 | 127 | for layer in self.layers: 128 | output = layer(output, memory, tgt_mask=tgt_mask, 129 | memory_mask=memory_mask, 130 | tgt_key_padding_mask=tgt_key_padding_mask, 131 | memory_key_padding_mask=memory_key_padding_mask, 132 | pos=pos, query_pos=query_pos) 133 | if self.return_intermediate: 134 | intermediate.append(self.norm(output)) 135 | 136 | if self.norm is not None: 137 | output = self.norm(output) 138 | if self.return_intermediate: 139 | intermediate.pop() 140 | intermediate.append(output) 141 | 142 | if self.return_intermediate: 143 | return torch.stack(intermediate) 144 | return output 145 | 146 | def build_decoder( 147 | d_model=256, 148 | nhead=8, 149 | num_decoder_layers=3, 150 | dim_feedforward=2048, 151 | dropout=0.1, 152 | activation="relu", 153 | normalize_before=False, 154 | return_intermediate_dec=False, 155 | ): 156 | 157 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 158 | dropout, activation, normalize_before) 159 | decoder_norm = nn.LayerNorm(d_model) 160 | decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 161 | return_intermediate=return_intermediate_dec) 162 | return decoder 163 | 164 | -------------------------------------------------------------------------------- /regionspot/modeling/regionspot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from typing import Any, Dict, List, Tuple 6 | from .segment_anything.utils.transforms import ResizeLongestSide 7 | from .segment_anything.build_sam import sam_model_registry 8 | from .decoder import build_decoder 9 | from . 
import constants 10 | from einops import rearrange 11 | 12 | from .segment_anything.modeling.prompt_engineering import prompt_engineering, get_prompt_templates 13 | from .clip import load as load_clip 14 | import clip 15 | 16 | 17 | class RegionSpot(nn.Module): 18 | TEXT_FEATS_MAP = { 19 | 'coco': 'text_feats_coco', 20 | 'objects365': 'text_feats_objects365', 21 | 'v3det': 'text_feats_v3det', 22 | 'lvis': 'text_feats_lvis', 23 | 'openimages': 'text_feats_openimages' 24 | } 25 | 26 | def __init__(self, sam_checkpoint='./sam_checkpoints/sam_vit_b_01ec64.pth', 27 | clip_type='CLIP_400M_Large', is_training=True, custom_vocabulary=None, image_size=224): 28 | super().__init__() 29 | 30 | self.sam = sam_model_registry['vit_b'](checkpoint=sam_checkpoint) 31 | self._freeze_module(self.sam) 32 | 33 | self.clip_model, self.text_dim, self.clip_dim = self._load_clip_model(clip_type, image_size) 34 | self.clip_model.eval() 35 | self._freeze_module(self.clip_model) 36 | self.logit_scale = self.clip_model.logit_scale.exp() 37 | 38 | self.to_clip = nn.Linear(256, self.clip_dim) 39 | self.ln_clip = nn.LayerNorm(self.clip_dim, elementwise_affine=False) 40 | self.projector = nn.Linear(self.clip_dim, self.text_dim) 41 | self.decoder = build_decoder(d_model=self.clip_dim) 42 | 43 | # Dynamically set attributes based on the datasets in the map 44 | if is_training: 45 | datasets_to_load = ['objects365', 'v3det', 'openimages'] 46 | for dataset in datasets_to_load: 47 | setattr(self, self.TEXT_FEATS_MAP[dataset], self.get_text_feat(dataset)) 48 | else: 49 | dataset_name = 'custom' if custom_vocabulary else 'lvis' 50 | self.text_feats = self.get_text_feat(dataset_name, custom_class=custom_vocabulary) 51 | 52 | @staticmethod 53 | def _freeze_module(module): 54 | for param in module.parameters(): 55 | param.requires_grad = False 56 | 57 | def _load_clip_model(self, clip_type, image_size): 58 | clip_model_map = { 59 | 'CLIP_400M': ("ViT-B/16", 512, 768), 60 | 'CLIP_400M_Large': ("ViT-L/14", 768, 1024), 61 | 'CLIP_400M_Large_336': ("ViT-L/14@336px", 768, 1024) 62 | } 63 | model_type, text_dim, clip_dim = clip_model_map[clip_type] 64 | clip_model, _ = load_clip(model_type, image_size=image_size) 65 | return clip_model, text_dim, clip_dim 66 | 67 | @torch.no_grad() 68 | def get_text_feat(self, dataset_name: str, custom_class=None) -> torch.Tensor: 69 | dataset_map = { 70 | 'coco': constants.COCO_INSTANCE_CLASSES, 71 | 'objects365': constants.OBJECTS365V1, 72 | 'v3det': constants.V3DET, 73 | 'lvis': constants.LVIS_CATEGORIES, 74 | 'openimages': constants.OPENIMAGE, 75 | 'custom': custom_class 76 | } 77 | 78 | # Error handling for custom dataset without custom classes provided 79 | if dataset_name == 'custom' and custom_class is None: 80 | raise ValueError("For custom datasets, you must provide the 'custom_class' parameter.") 81 | 82 | class_names = dataset_map.get(dataset_name, []) 83 | 84 | def clean_class_name(clss: str) -> str: 85 | """Clean class names for prompt templates.""" 86 | return clss.replace('-other', '').replace('-merged', '').replace('-stuff', '') 87 | 88 | def extract_mean_emb(text: str) -> torch.Tensor: 89 | """Extract mean embeddings from text using the clip model.""" 90 | tokens = clip.tokenize(text).cuda() 91 | 92 | if len(tokens) > 10000: 93 | split_idx = len(tokens) // 2 94 | text_features = torch.cat([ 95 | self.clip_model.encode_text(tokens[:split_idx]), 96 | self.clip_model.encode_text(tokens[split_idx:])], 97 | dim=0) 98 | else: 99 | text_features = self.clip_model.encode_text(tokens) 100 | 101 | 
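            # Mean-pool over all prompt templates to get a single text embedding per
            # class; `keepdims=True` followed by `[0]` simply squeezes that dimension
            # back out. The token batch is split in half above only to bound peak
            # memory for very large vocabularies.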
return torch.mean(text_features, 0, keepdims=True)[0] 102 | 103 | templates = get_prompt_templates() 104 | clss_embeddings = [] 105 | for clss in class_names: 106 | txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates] 107 | clss_embeddings.append(extract_mean_emb(txts)) 108 | 109 | text_emb = torch.stack(clss_embeddings, dim=0) 110 | text_emb /= text_emb.norm(dim=-1, keepdim=True) 111 | 112 | return text_emb 113 | 114 | def sigmoid_focal_loss(self, inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2, reduction=True): 115 | """Compute the sigmoid focal loss.""" 116 | prob = inputs.sigmoid() 117 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 118 | p_t = prob * targets + (1 - prob) * (1 - targets) 119 | loss = ce_loss * ((1 - p_t) ** gamma) 120 | 121 | if alpha >= 0: 122 | loss = (alpha * targets + (1 - alpha) * (1 - targets)) * loss 123 | 124 | return loss.mean(1).sum() / num_boxes 125 | 126 | def get_logits(self, region_features, text_features, logit_scale): 127 | """Compute logits for region and text features.""" 128 | region_features = region_features / (region_features.norm(dim=-1, keepdim=True) + 1e-7) 129 | logits_per_image = logit_scale * region_features @ text_features.unsqueeze(0).transpose(1, 2) 130 | logits_per_text = logit_scale * text_features.unsqueeze(0) @ region_features.transpose(1, 2) 131 | return logits_per_image, logits_per_text 132 | 133 | def ce_loss(self, region_features, label, logit_scale, dataset_name, focal_alpha=0.25): 134 | """Compute the cross-entropy loss.""" 135 | b, n_box, d = region_features.shape 136 | text_feats = getattr(self, self.TEXT_FEATS_MAP[dataset_name]) 137 | 138 | logits_per_image, _ = self.get_logits(region_features, text_feats, logit_scale) 139 | 140 | target_classes_onehot = torch.zeros(logits_per_image.shape, dtype=logits_per_image.dtype, device=logits_per_image.device) 141 | label = label.long() 142 | target_classes_onehot.scatter_(2, label.unsqueeze(-1), 1) 143 | 144 | loss_ce = self.sigmoid_focal_loss(logits_per_image, target_classes_onehot, n_box, alpha=focal_alpha, gamma=2) * logits_per_image.shape[1] 145 | 146 | return loss_ce 147 | 148 | def forward_train(self, batched_input: List[Dict[str, Any]]) -> List[Dict[str, torch.Tensor]]: 149 | """Training forward pass.""" 150 | resized_image = torch.stack([x["resized_image"] for x in batched_input], dim=0) 151 | 152 | with torch.no_grad(): 153 | clip_feat = self.clip_model.encode_image_featuremap(resized_image).detach() 154 | 155 | 156 | masks_token = torch.stack([x["mask_tokens"] for x in batched_input], dim=0).squeeze(2) 157 | dataset_name = batched_input[0]["dataset_name"] 158 | masks_token = self.to_clip(masks_token) 159 | 160 | semantic_token = self.projector(self.decoder(masks_token, clip_feat)) 161 | label = torch.stack([x["label"] for x in batched_input], dim=0) 162 | 163 | return self.ce_loss(semantic_token, label, self.logit_scale, dataset_name) 164 | 165 | def forward_eval(self, batched_input: List[Dict[str, Any]], multimask_output=False) -> List[Dict[str, torch.Tensor]]: 166 | """Inference forward pass.""" 167 | sam_output = self.sam(batched_input, multimask_output=multimask_output) 168 | masks_token = torch.stack([x["masks_token"] for x in sam_output], dim=0).squeeze(2) 169 | pred_mask = torch.stack([x["masks"] for x in sam_output], dim=0) 170 | resized_image = torch.stack([x["resized_image"] for x in batched_input], dim=0) 171 | 172 | with torch.no_grad(): 173 | 
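            # CLIP is frozen, so its image feature map is computed without gradients;
            # the decoder is switched to eval mode for inference.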
self.decoder.eval() 174 | clip_feat = self.clip_model.encode_image_featuremap(resized_image).detach() 175 | 176 | masks_token = self.to_clip(masks_token) 177 | 178 | semantic_token = self.projector(self.decoder(masks_token, clip_feat)) 179 | 180 | logits_per_image, _ = self.get_logits(semantic_token, self.text_feats, self.logit_scale) 181 | 182 | return logits_per_image, pred_mask 183 | 184 | def forward_inference(self, clip_feat, masks_token, resized_image,) -> List[Dict[str, torch.Tensor]]: 185 | """Inference forward pass.""" 186 | # if masks_token.shape 187 | masks_token = masks_token[None,:] 188 | if masks_token.shape[2] == 1: 189 | masks_token = masks_token.squeeze(2) 190 | else: 191 | masks_token = masks_token.permute(2, 1, 0, 3).squeeze(2) 192 | clip_feat = clip_feat.repeat(3, 1, 1) 193 | with torch.no_grad(): 194 | self.decoder.eval() 195 | masks_token = self.to_clip(masks_token) 196 | semantic_token = self.projector(self.decoder(masks_token, clip_feat)) 197 | 198 | logits_per_image, _ = self.get_logits(semantic_token, self.text_feats, self.logit_scale) 199 | if logits_per_image.shape[0] == 3: 200 | logits_per_image = logits_per_image.permute(1, 0, 2) 201 | return logits_per_image 202 | 203 | 204 | 205 | def build_regionspot_model(clip_type='CLIP_400M_Large', is_training=True, pretrain_ckpt=None, image_size=224, custom_vocabulary=None): 206 | model = RegionSpot(clip_type=clip_type, is_training=is_training, image_size=image_size, custom_vocabulary=custom_vocabulary) 207 | if pretrain_ckpt: 208 | checkpoint = torch.load(pretrain_ckpt, map_location='cpu')['model'] 209 | 210 | # Remove the 'model.' prefix 211 | new_checkpoint = {} 212 | for key in checkpoint.keys(): 213 | if key.startswith('model.'): 214 | new_key = key[len('model.'):] 215 | new_checkpoint[new_key] = checkpoint[key] 216 | else: 217 | new_checkpoint[key] = checkpoint[key] 218 | 219 | # Load the modified state dict 220 | msg = model.load_state_dict(new_checkpoint, strict=False) 221 | else: 222 | msg= 'training stage' 223 | return model, msg 224 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
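# Factory functions for the SAM variants (ViT-H, ViT-L, ViT-B), exposed through
# `sam_model_registry`. RegionSpot builds the ViT-B variant with the checkpoint
# prepared in DATA.md, e.g. (as in regionspot.py):
#
#     sam = sam_model_registry['vit_b'](checkpoint='./sam_checkpoints/sam_vit_b_01ec64.pth')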
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
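# Re-exports the SAM building blocks (Sam, ImageEncoderViT, MaskDecoder,
# PromptEncoder, TwoWayTransformer) that build_sam.py assembles into a complete model.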
6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/common.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/common.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/common.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/common.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/image_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/image_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/image_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/mask_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/mask_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/mask_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/prompt_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/prompt_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/prompt_encoder.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/prompt_engineering.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/prompt_engineering.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/sam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/sam.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/sam.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/sam.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/transformer.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/modeling/__pycache__/transformer.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | return self.lin2(self.act(self.lin1(x))) 27 | 28 | 29 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 30 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 31 | class LayerNorm2d(nn.Module): 32 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 33 | super().__init__() 34 | self.weight = nn.Parameter(torch.ones(num_channels)) 35 | self.bias = nn.Parameter(torch.zeros(num_channels)) 36 | self.eps = eps 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | u = x.mean(1, keepdim=True) 40 | s = (x - u).pow(2).mean(1, keepdim=True) 41 | x = (x - u) / torch.sqrt(s + self.eps) 42 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 43 | return x 44 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 
30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred, mask_token = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | mask_token = mask_token[:, mask_slice, :] 108 | 109 | iou_pred = iou_pred[:, mask_slice] 110 | 111 | # Prepare output 112 | return masks, iou_pred, mask_token 113 | 114 | def predict_masks( 115 | self, 116 | image_embeddings: torch.Tensor, 117 | image_pe: torch.Tensor, 118 | sparse_prompt_embeddings: torch.Tensor, 119 | dense_prompt_embeddings: torch.Tensor, 120 | ) -> Tuple[torch.Tensor, torch.Tensor]: 121 | """Predicts masks. 
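        The mask tokens are returned alongside the mask logits and IoU predictions so
        that RegionSpot can project them into the CLIP feature space and classify each
        region against CLIP text embeddings.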
See 'forward' for more details.""" 122 | # Concatenate output tokens 123 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 124 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 125 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 126 | 127 | # Expand per-image data in batch direction to be per-mask 128 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 129 | src = src + dense_prompt_embeddings 130 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 131 | b, c, h, w = src.shape 132 | 133 | # Run the transformer 134 | hs, src = self.transformer(src, pos_src, tokens) 135 | iou_token_out = hs[:, 0, :] 136 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 137 | # Upscale mask embeddings and predict masks using the mask tokens 138 | src = src.transpose(1, 2).view(b, c, h, w) 139 | upscaled_embedding = self.output_upscaling(src) 140 | hyper_in_list: List[torch.Tensor] = [] 141 | for i in range(self.num_mask_tokens): 142 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 143 | hyper_in = torch.stack(hyper_in_list, dim=1) 144 | b, c, h, w = upscaled_embedding.shape 145 | if len(hyper_in) == 0: 146 | print("here") 147 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 148 | # Generate mask quality predictions 149 | iou_pred = self.iou_prediction_head(iou_token_out) 150 | 151 | return masks, iou_pred, mask_tokens_out 152 | 153 | 154 | # Lightly adapted from 155 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 156 | class MLP(nn.Module): 157 | def __init__( 158 | self, 159 | input_dim: int, 160 | hidden_dim: int, 161 | output_dim: int, 162 | num_layers: int, 163 | sigmoid_output: bool = False, 164 | ) -> None: 165 | super().__init__() 166 | self.num_layers = num_layers 167 | h = [hidden_dim] * (num_layers - 1) 168 | self.layers = nn.ModuleList( 169 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 170 | ) 171 | self.sigmoid_output = sigmoid_output 172 | 173 | def forward(self, x): 174 | for i, layer in enumerate(self.layers): 175 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 176 | if self.sigmoid_output: 177 | x = F.sigmoid(x) 178 | return x -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch import nn 10 | 11 | from typing import Any, Optional, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 
32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] 47 | self.point_embeddings = nn.ModuleList(point_embeddings) 48 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 49 | 50 | self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) 51 | self.mask_downscaling = nn.Sequential( 52 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 53 | LayerNorm2d(mask_in_chans // 4), 54 | activation(), 55 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 56 | LayerNorm2d(mask_in_chans), 57 | activation(), 58 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 59 | ) 60 | self.no_mask_embed = nn.Embedding(1, embed_dim) 61 | 62 | def get_dense_pe(self) -> torch.Tensor: 63 | """ 64 | Returns the positional encoding used to encode point prompts, 65 | applied to a dense set of points the shape of the image encoding. 66 | 67 | Returns: 68 | torch.Tensor: Positional encoding with shape 69 | 1x(embed_dim)x(embedding_h)x(embedding_w) 70 | """ 71 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 72 | 73 | def _embed_points( 74 | self, 75 | points: torch.Tensor, 76 | labels: torch.Tensor, 77 | pad: bool, 78 | ) -> torch.Tensor: 79 | """Embeds point prompts.""" 80 | points = points + 0.5 # Shift to center of pixel 81 | if pad: 82 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 83 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 84 | points = torch.cat([points, padding_point], dim=1) 85 | labels = torch.cat([labels, padding_label], dim=1) 86 | point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) 87 | point_embedding[labels == -1] = 0.0 88 | point_embedding[labels == -1] += self.not_a_point_embed.weight 89 | point_embedding[labels == 0] += self.point_embeddings[0].weight 90 | point_embedding[labels == 1] += self.point_embeddings[1].weight 91 | return point_embedding 92 | 93 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 94 | """Embeds box prompts.""" 95 | boxes = boxes + 0.5 # Shift to center of pixel 96 | coords = boxes.reshape(-1, 2, 2) 97 | corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) 98 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 99 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 100 | return corner_embedding 101 | 102 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 103 | """Embeds mask inputs.""" 104 | mask_embedding = self.mask_downscaling(masks) 105 | return mask_embedding 106 | 107 | def _get_batch_size( 108 | self, 109 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 110 | boxes: Optional[torch.Tensor], 111 | masks: Optional[torch.Tensor], 112 | ) -> int: 113 | """ 114 | Gets the batch size of the output given the batch size of the input prompts. 
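        Defaults to a batch size of 1 when no point, box, or mask prompts are given.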
115 | """ 116 | if points is not None: 117 | return points[0].shape[0] 118 | elif boxes is not None: 119 | return boxes.shape[0] 120 | elif masks is not None: 121 | return masks.shape[0] 122 | else: 123 | return 1 124 | 125 | def _get_device(self) -> torch.device: 126 | return self.point_embeddings[0].weight.device 127 | 128 | def forward( 129 | self, 130 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 131 | boxes: Optional[torch.Tensor], 132 | masks: Optional[torch.Tensor], 133 | ) -> Tuple[torch.Tensor, torch.Tensor]: 134 | """ 135 | Embeds different types of prompts, returning both sparse and dense 136 | embeddings. 137 | 138 | Arguments: 139 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 140 | and labels to embed. 141 | boxes (torch.Tensor or none): boxes to embed 142 | masks (torch.Tensor or none): masks to embed 143 | 144 | Returns: 145 | torch.Tensor: sparse embeddings for the points and boxes, with shape 146 | BxNx(embed_dim), where N is determined by the number of input points 147 | and boxes. 148 | torch.Tensor: dense embeddings for the masks, in the shape 149 | Bx(embed_dim)x(embed_H)x(embed_W) 150 | """ 151 | bs = self._get_batch_size(points, boxes, masks) 152 | sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) 153 | if points is not None: 154 | coords, labels = points 155 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 156 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 157 | if boxes is not None: 158 | box_embeddings = self._embed_boxes(boxes) 159 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 160 | 161 | if masks is not None: 162 | dense_embeddings = self._embed_masks(masks) 163 | else: 164 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 165 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 166 | ) 167 | 168 | return sparse_embeddings, dense_embeddings 169 | 170 | 171 | class PositionEmbeddingRandom(nn.Module): 172 | """ 173 | Positional encoding using random spatial frequencies. 174 | """ 175 | 176 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 177 | super().__init__() 178 | if scale is None or scale <= 0.0: 179 | scale = 1.0 180 | self.register_buffer( 181 | "positional_encoding_gaussian_matrix", 182 | scale * torch.randn((2, num_pos_feats)), 183 | ) 184 | 185 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 186 | """Positionally encode points that are normalized to [0,1].""" 187 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 188 | coords = 2 * coords - 1 189 | coords = coords @ self.positional_encoding_gaussian_matrix 190 | coords = 2 * np.pi * coords 191 | # outputs d_1 x ... 
x d_n x C shape 192 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 193 | 194 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 195 | """Generate positional encoding for a grid of the specified size.""" 196 | h, w = size 197 | device: Any = self.positional_encoding_gaussian_matrix.device 198 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 199 | y_embed = grid.cumsum(dim=0) - 0.5 200 | x_embed = grid.cumsum(dim=1) - 0.5 201 | y_embed = y_embed / h 202 | x_embed = x_embed / w 203 | 204 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 205 | return pe.permute(2, 0, 1) # C x H x W 206 | 207 | def forward_with_coords( 208 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 209 | ) -> torch.Tensor: 210 | """Positionally encode points that are not normalized to [0,1].""" 211 | coords = coords_input.clone() 212 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 213 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 214 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 215 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/prompt_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_prompt_templates(): 5 | prompt_templates = [ 6 | '{}.', 7 | 'a photo of a {}.', 8 | 'a bad photo of a {}.', 9 | 'a photo of many {}.', 10 | 'a sculpture of a {}.', 11 | 'a photo of the hard to see {}.', 12 | 'a low resolution photo of the {}.', 13 | 'a rendering of a {}.', 14 | 'graffiti of a {}.', 15 | 'a bad photo of the {}.', 16 | 'a cropped photo of the {}.', 17 | 'a tattoo of a {}.', 18 | 'the embroidered {}.', 19 | 'a photo of a hard to see {}.', 20 | 'a bright photo of a {}.', 21 | 'a photo of a clean {}.', 22 | 'a photo of a dirty {}.', 23 | 'a dark photo of the {}.', 24 | 'a drawing of a {}.', 25 | 'a photo of my {}.', 26 | 'the plastic {}.', 27 | 'a photo of the cool {}.', 28 | 'a close-up photo of a {}.', 29 | 'a black and white photo of the {}.', 30 | 'a painting of the {}.', 31 | 'a painting of a {}.', 32 | 'a pixelated photo of the {}.', 33 | 'a sculpture of the {}.', 34 | 'a bright photo of the {}.', 35 | 'a cropped photo of a {}.', 36 | 'a plastic {}.', 37 | 'a photo of the dirty {}.', 38 | 'a jpeg corrupted photo of a {}.', 39 | 'a blurry photo of the {}.', 40 | 'a photo of the {}.', 41 | 'a good photo of the {}.', 42 | 'a rendering of the {}.', 43 | 'a {} in a video game.', 44 | 'a photo of one {}.', 45 | 'a doodle of a {}.', 46 | 'a close-up photo of the {}.', 47 | 'the origami {}.', 48 | 'the {} in a video game.', 49 | 'a sketch of a {}.', 50 | 'a doodle of the {}.', 51 | 'a origami {}.', 52 | 'a low resolution photo of a {}.', 53 | 'the toy {}.', 54 | 'a rendition of the {}.', 55 | 'a photo of the clean {}.', 56 | 'a photo of a large {}.', 57 | 'a rendition of a {}.', 58 | 'a photo of a nice {}.', 59 | 'a photo of a weird {}.', 60 | 'a blurry photo of a {}.', 61 | 'a cartoon {}.', 62 | 'art of a {}.', 63 | 'a sketch of the {}.', 64 | 'a embroidered {}.', 65 | 'a pixelated photo of a {}.', 66 | 'itap of the {}.', 67 | 'a jpeg corrupted photo of the {}.', 68 | 'a good photo of a {}.', 69 | 'a plushie {}.', 70 | 'a photo of the nice {}.', 71 | 'a photo of the small {}.', 72 | 'a photo of the weird {}.', 73 | 'the cartoon {}.', 74 | 'art of the {}.', 75 | 'a drawing of the {}.', 76 | 'a photo of the large {}.', 77 | 'a black and white photo of a {}.', 78 
| 'the plushie {}.', 79 | 'a dark photo of a {}.', 80 | 'itap of a {}.', 81 | 'graffiti of the {}.', 82 | 'a toy {}.', 83 | 'itap of my {}.', 84 | 'a photo of a cool {}.', 85 | 'a photo of a small {}.', 86 | 'a tattoo of the {}.', 87 | ] 88 | return prompt_templates 89 | 90 | def prompt_engineering(classnames, topk=1, suffix='.'): 91 | prompt_templates = get_prompt_templates() 92 | temp_idx = np.random.randint(min(len(prompt_templates), topk)) 93 | 94 | if isinstance(classnames, list): 95 | classname = random.choice(classnames) 96 | else: 97 | classname = classnames 98 | 99 | return prompt_templates[temp_idx].replace('.', suffix).format(classname.replace(',', '').replace('+', ' ')) -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @torch.no_grad() 54 | def forward( 55 | self, 56 | batched_input: List[Dict[str, Any]], 57 | multimask_output: bool, 58 | ) -> List[Dict[str, torch.Tensor]]: 59 | """ 60 | Predicts masks end-to-end from provided images and prompts. 61 | If prompts are not known in advance, using SamPredictor is 62 | recommended over calling the model directly. 63 | 64 | Arguments: 65 | batched_input (list(dict)): A list over input images, each a 66 | dictionary with the following keys. A prompt key can be 67 | excluded if it is not present. 68 | 'image': The image as a torch tensor in 3xHxW format, 69 | already transformed for input to the model. 70 | 'original_size': (tuple(int, int)) The original size of 71 | the image before transformation, as (H, W). 
72 | 'point_coords': (torch.Tensor) Batched point prompts for 73 | this image, with shape BxNx2. Already transformed to the 74 | input frame of the model. 75 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 76 | with shape BxN. 77 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 78 | Already transformed to the input frame of the model. 79 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 80 | in the form Bx1xHxW. 81 | multimask_output (bool): Whether the model should predict multiple 82 | disambiguating masks, or return a single mask. 83 | 84 | Returns: 85 | (list(dict)): A list over input images, where each element is 86 | as dictionary with the following keys. 87 | 'masks': (torch.Tensor) Batched binary mask predictions, 88 | with shape BxCxHxW, where B is the number of input prompts, 89 | C is determined by multimask_output, and (H, W) is the 90 | original size of the image. 91 | 'iou_predictions': (torch.Tensor) The model's predictions 92 | of mask quality, in shape BxC. 93 | 'low_res_logits': (torch.Tensor) Low resolution logits with 94 | shape BxCxHxW, where H=W=256. Can be passed as mask input 95 | to subsequent iterations of prediction. 96 | """ 97 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 98 | 99 | image_embeddings = self.image_encoder(input_images) 100 | outputs = [] 101 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 102 | if "point_coords" in image_record: 103 | points = (image_record["point_coords"], image_record["point_labels"]) 104 | else: 105 | points = None 106 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 107 | points=points, 108 | boxes=image_record.get("boxes", None), 109 | masks=image_record.get("mask_inputs", None), 110 | ) 111 | low_res_masks, iou_predictions, mask_token = self.mask_decoder( 112 | image_embeddings=curr_embedding.unsqueeze(0), 113 | image_pe=self.prompt_encoder.get_dense_pe(), 114 | sparse_prompt_embeddings=sparse_embeddings, 115 | dense_prompt_embeddings=dense_embeddings, 116 | multimask_output=multimask_output, 117 | ) 118 | 119 | masks = self.postprocess_masks( 120 | low_res_masks, 121 | input_size=image_record["image"].shape[-2:], 122 | original_size=image_record["original_size"], 123 | ) 124 | masks = masks > self.mask_threshold 125 | outputs.append( 126 | { 127 | "masks": masks, 128 | "masks_token": mask_token, 129 | "iou_predictions": iou_predictions, 130 | "low_res_logits": low_res_masks, 131 | } 132 | ) 133 | return outputs 134 | def postprocess_masks( 135 | self, 136 | masks: torch.Tensor, 137 | input_size: Tuple[int, ...], 138 | original_size: Tuple[int, ...], 139 | ) -> torch.Tensor: 140 | """ 141 | Remove padding and upscale masks to the original image size. 142 | 143 | Arguments: 144 | masks (torch.Tensor): Batched masks from the mask_decoder, 145 | in BxCxHxW format. 146 | input_size (tuple(int, int)): The size of the image input to the 147 | model, in (H, W) format. Used to remove padding. 148 | original_size (tuple(int, int)): The original size of the image 149 | before resizing for input to the model, in (H, W) format. 150 | 151 | Returns: 152 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 153 | is given by original_size. 
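
        Masks are first upsampled to the square, padded encoder resolution, cropped to
        the unpadded input_size, and finally resized to original_size.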
154 | """ 155 | masks = F.interpolate( 156 | masks, 157 | (self.image_encoder.img_size, self.image_encoder.img_size), 158 | mode="bilinear", 159 | align_corners=False, 160 | ) 161 | masks = masks[..., : input_size[0], : input_size[1]] 162 | masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) 163 | return masks 164 | 165 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 166 | """Normalize pixel values and pad to a square input.""" 167 | # Normalize colors 168 | x = (x - self.pixel_mean) / self.pixel_std 169 | 170 | # Pad 171 | h, w = x.shape[-2:] 172 | padh = self.image_encoder.img_size - h 173 | padw = self.image_encoder.img_size - w 174 | x = F.pad(x, (0, padw, 0, padh)) 175 | return x 176 | 177 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 
76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, 
v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/modeling/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import numpy as np 4 | from torch.nn import functional as F 5 | 6 | def load_class_freq( 7 | path='/opt/tiger/SAM_Adapter/datasets/lvis/lvis_v1_val.json', freq_weight=1.0): 8 | cat_info = json.load(open(path, 'r')) 9 | cat_info = torch.tensor( 10 | [c['image_count'] for c in sorted(cat_info, key=lambda x: x['id'])]) 11 | freq_weight = cat_info.float() ** freq_weight 12 | return freq_weight 13 | 14 | def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None): 15 | appeared = torch.unique(gt_classes) # C' 16 | prob = appeared.new_ones(C + 1).float() 17 | prob[-1] = 0 18 | if len(appeared) < num_sample_cats: 19 | if weight is not None: 20 | prob[:C] = weight.float().clone() 21 | prob[appeared] = 0 22 | more_appeared = torch.multinomial( 23 | prob, num_sample_cats - len(appeared), 24 | replacement=False) 25 | appeared = torch.cat([appeared, more_appeared]) 26 | return appeared -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .modeling import Sam 11 | 12 | from typing import Optional, Tuple 13 | 14 | from .utils.transforms import ResizeLongestSide 15 | 16 | 17 | class SamPredictor: 18 | def __init__( 19 | self, 20 | sam_model: Sam, 21 | ) -> None: 22 | """ 23 | Uses SAM to calculate the image embedding for an image, and then 24 | allow repeated, efficient mask prediction given prompts. 25 | 26 | Arguments: 27 | sam_model (Sam): The model to use for mask prediction. 28 | """ 29 | super().__init__() 30 | self.model = sam_model 31 | self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) 32 | self.reset_image() 33 | 34 | def set_image( 35 | self, 36 | image: np.ndarray, 37 | image_format: str = "RGB", 38 | ) -> None: 39 | """ 40 | Calculates the image embeddings for the provided image, allowing 41 | masks to be predicted with the 'predict' method. 42 | 43 | Arguments: 44 | image (np.ndarray): The image for calculating masks. Expects an 45 | image in HWC uint8 format, with pixel values in [0, 255]. 46 | image_format (str): The color format of the image, in ['RGB', 'BGR']. 47 | """ 48 | assert image_format in [ 49 | "RGB", 50 | "BGR", 51 | ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." 52 | if image_format != self.model.image_format: 53 | image = image[..., ::-1] 54 | 55 | # Transform the image to the form expected by the model 56 | input_image = self.transform.apply_image(image) 57 | input_image_torch = torch.as_tensor(input_image, device=self.device) 58 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 59 | 60 | self.set_torch_image(input_image_torch, image.shape[:2]) 61 | 62 | @torch.no_grad() 63 | def set_torch_image( 64 | self, 65 | transformed_image: torch.Tensor, 66 | original_image_size: Tuple[int, ...], 67 | ) -> None: 68 | """ 69 | Calculates the image embeddings for the provided image, allowing 70 | masks to be predicted with the 'predict' method. Expects the input 71 | image to be already transformed to the format expected by the model. 72 | 73 | Arguments: 74 | transformed_image (torch.Tensor): The input image, with shape 75 | 1x3xHxW, which has been transformed with ResizeLongestSide. 76 | original_image_size (tuple(int, int)): The size of the image 77 | before transformation, in (H, W) format. 78 | """ 79 | assert ( 80 | len(transformed_image.shape) == 4 81 | and transformed_image.shape[1] == 3 82 | and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size 83 | ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." 84 | self.reset_image() 85 | 86 | self.original_size = original_image_size 87 | self.input_size = tuple(transformed_image.shape[-2:]) 88 | input_image = self.model.preprocess(transformed_image) 89 | self.features = self.model.image_encoder(input_image) 90 | self.is_image_set = True 91 | 92 | def predict( 93 | self, 94 | point_coords: Optional[np.ndarray] = None, 95 | point_labels: Optional[np.ndarray] = None, 96 | box: Optional[np.ndarray] = None, 97 | mask_input: Optional[np.ndarray] = None, 98 | multimask_output: bool = True, 99 | return_logits: bool = False, 100 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 101 | """ 102 | Predict masks for the given input prompts, using the currently set image. 
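        A minimal illustrative call (editor's sketch; variable names are placeholders):
        after `predictor.set_image(image)`, a box prompt can be used as
        `masks, quality, low_res = predictor.predict(box=np.array([x0, y0, x1, y1]))`.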
103 | 104 | Arguments: 105 | point_coords (np.ndarray or None): A Nx2 array of point prompts to the 106 | model. Each point is in (X,Y) in pixels. 107 | point_labels (np.ndarray or None): A length N array of labels for the 108 | point prompts. 1 indicates a foreground point and 0 indicates a 109 | background point. 110 | box (np.ndarray or None): A length 4 array given a box prompt to the 111 | model, in XYXY format. 112 | mask_input (np.ndarray): A low resolution mask input to the model, typically 113 | coming from a previous prediction iteration. Has form 1xHxW, where 114 | for SAM, H=W=256. 115 | multimask_output (bool): If true, the model will return three masks. 116 | For ambiguous input prompts (such as a single click), this will often 117 | produce better masks than a single prediction. If only a single 118 | mask is needed, the model's predicted quality score can be used 119 | to select the best mask. For non-ambiguous prompts, such as multiple 120 | input prompts, multimask_output=False can give better results. 121 | return_logits (bool): If true, returns un-thresholded masks logits 122 | instead of a binary mask. 123 | 124 | Returns: 125 | (np.ndarray): The output masks in CxHxW format, where C is the 126 | number of masks, and (H, W) is the original image size. 127 | (np.ndarray): An array of length C containing the model's 128 | predictions for the quality of each mask. 129 | (np.ndarray): An array of shape CxHxW, where C is the number 130 | of masks and H=W=256. These low resolution logits can be passed to 131 | a subsequent iteration as mask input. 132 | """ 133 | if not self.is_image_set: 134 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 135 | 136 | # Transform input prompts 137 | coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None 138 | if point_coords is not None: 139 | assert ( 140 | point_labels is not None 141 | ), "point_labels must be supplied if point_coords is supplied." 
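            # Editor's note: point_coords is an (N, 2) array of (x, y) pixel coordinates in
            # the original image frame and point_labels a length-N array of 1 (foreground) /
            # 0 (background); both are mapped to the resized input frame and given a batch
            # dimension below.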
142 | point_coords = self.transform.apply_coords(point_coords, self.original_size) 143 | coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) 144 | labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) 145 | coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] 146 | if box is not None: 147 | box = self.transform.apply_boxes(box, self.original_size) 148 | box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) 149 | box_torch = box_torch[None, :] 150 | if mask_input is not None: 151 | mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) 152 | mask_input_torch = mask_input_torch[None, :, :, :] 153 | 154 | masks, iou_predictions, low_res_masks = self.predict_torch( 155 | coords_torch, 156 | labels_torch, 157 | box_torch, 158 | mask_input_torch, 159 | multimask_output, 160 | return_logits=return_logits, 161 | ) 162 | masks_np = masks[0].detach().cpu().numpy() 163 | iou_predictions_np = iou_predictions[0].detach().cpu().numpy() 164 | low_res_masks_np = low_res_masks[0].detach().cpu().numpy() 165 | return masks_np, iou_predictions_np, low_res_masks_np 166 | 167 | @torch.no_grad() 168 | def predict_torch( 169 | self, 170 | point_coords: Optional[torch.Tensor], 171 | point_labels: Optional[torch.Tensor], 172 | boxes: Optional[torch.Tensor] = None, 173 | mask_input: Optional[torch.Tensor] = None, 174 | multimask_output: bool = True, 175 | return_logits: bool = False, 176 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 177 | """ 178 | Predict masks for the given input prompts, using the currently set image. 179 | Input prompts are batched torch tensors and are expected to already be 180 | transformed to the input frame using ResizeLongestSide. 181 | 182 | Arguments: 183 | point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the 184 | model. Each point is in (X,Y) in pixels. 185 | point_labels (torch.Tensor or None): A BxN array of labels for the 186 | point prompts. 1 indicates a foreground point and 0 indicates a 187 | background point. 188 | boxes (np.ndarray or None): A Bx4 array given a box prompt to the 189 | model, in XYXY format. 190 | mask_input (np.ndarray): A low resolution mask input to the model, typically 191 | coming from a previous prediction iteration. Has form Bx1xHxW, where 192 | for SAM, H=W=256. Masks returned by a previous iteration of the 193 | predict method do not need further transformation. 194 | multimask_output (bool): If true, the model will return three masks. 195 | For ambiguous input prompts (such as a single click), this will often 196 | produce better masks than a single prediction. If only a single 197 | mask is needed, the model's predicted quality score can be used 198 | to select the best mask. For non-ambiguous prompts, such as multiple 199 | input prompts, multimask_output=False can give better results. 200 | return_logits (bool): If true, returns un-thresholded masks logits 201 | instead of a binary mask. 202 | 203 | Returns: 204 | (torch.Tensor): The output masks in BxCxHxW format, where C is the 205 | number of masks, and (H, W) is the original image size. 206 | (torch.Tensor): An array of shape BxC containing the model's 207 | predictions for the quality of each mask. 208 | (torch.Tensor): An array of shape BxCxHxW, where C is the number 209 | of masks and H=W=256. These low res logits can be passed to 210 | a subsequent iteration as mask input. 
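        Example (illustrative sketch added by the editor; `predictor` is a SamPredictor on
        which set_image has been called with the HWC array `image`):
            boxes = torch.tensor([[100., 100., 400., 300.]], device=predictor.device)
            boxes = predictor.transform.apply_boxes_torch(boxes, image.shape[:2])
            masks, iou_preds, low_res = predictor.predict_torch(
                point_coords=None, point_labels=None, boxes=boxes, multimask_output=False
            )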
211 | """ 212 | if not self.is_image_set: 213 | raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") 214 | 215 | if point_coords is not None: 216 | points = (point_coords, point_labels) 217 | else: 218 | points = None 219 | 220 | # Embed prompts 221 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 222 | points=points, 223 | boxes=boxes, 224 | masks=mask_input, 225 | ) 226 | 227 | # Predict masks 228 | low_res_masks, iou_predictions = self.model.mask_decoder( 229 | image_embeddings=self.features, 230 | image_pe=self.model.prompt_encoder.get_dense_pe(), 231 | sparse_prompt_embeddings=sparse_embeddings, 232 | dense_prompt_embeddings=dense_embeddings, 233 | multimask_output=multimask_output, 234 | ) 235 | # print("low_res_masks", low_res_masks.shape) 236 | # Upscale the masks to the original image resolution 237 | masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) 238 | # print("masks", masks.shape) 239 | 240 | if not return_logits: 241 | masks = masks > self.model.mask_threshold 242 | 243 | return masks, iou_predictions, low_res_masks 244 | 245 | def get_image_embedding(self) -> torch.Tensor: 246 | """ 247 | Returns the image embeddings for the currently set image, with 248 | shape 1xCxHxW, where C is the embedding dimension and (H,W) are 249 | the embedding spatial dimension of SAM (typically C=256, H=W=64). 250 | """ 251 | if not self.is_image_set: 252 | raise RuntimeError( 253 | "An image must be set with .set_image(...) to generate an embedding." 254 | ) 255 | assert self.features is not None, "Features must exist if an image has been set." 256 | return self.features 257 | 258 | @property 259 | def device(self) -> torch.device: 260 | return self.model.device 261 | 262 | def reset_image(self) -> None: 263 | """Resets the currently set image.""" 264 | self.is_image_set = False 265 | self.features = None 266 | self.orig_h = None 267 | self.orig_w = None 268 | self.input_h = None 269 | self.input_w = None 270 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__pycache__/amg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/utils/__pycache__/amg.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__pycache__/amg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/utils/__pycache__/amg.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/utils/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/__pycache__/transforms.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/modeling/segment_anything/utils/__pycache__/transforms.cpython-39.pyc -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
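    Editor's note (not in the upstream docstring): the exported graph covers only the
    prompt encoder and mask decoder; image embeddings are produced separately by the
    image encoder and passed in through the `image_embeddings` argument of `forward`.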
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
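        # Editor's note: score_reweight is [1000, 0, ..., 0], so the (num_points - 2.5)
        # factor below pushes the argmax toward output index 0 (the single-mask token)
        # when more than two prompt points are given, and toward the multimask outputs
        # otherwise, avoiding data-dependent control flow in the traced graph.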
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /regionspot/modeling/segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 
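        Example (illustrative, added by the editor): with original_size=(600, 800) and
        target_length=1024 the resized shape is (768, 1024), so a point at (x=400, y=300)
        maps to (512.0, 384.0).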
37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /regionspot/test_time_augmentation.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parallel import DistributedDataParallel 5 | 6 | from detectron2.modeling import GeneralizedRCNNWithTTA, DatasetMapperTTA 7 | from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image 8 | from detectron2.structures import Instances, Boxes 9 | 10 | 11 | class RegionSpotWithTTA(GeneralizedRCNNWithTTA): 12 | 13 | def __init__(self, cfg, model, tta_mapper=None, batch_size=3): 14 | """ 15 | Args: 16 | cfg (CfgNode): 17 | model ( RegionSpot): a RegionSpot to apply TTA on. 18 | tta_mapper (callable): takes a dataset dict and returns a list of 19 | augmented versions of the dataset dict. Defaults to 20 | `DatasetMapperTTA(cfg)`. 
21 | batch_size (int): batch the augmented images into this batch size for inference. 22 | """ 23 | # fix the issue: cannot assign module before Module.__init__() call 24 | nn.Module.__init__(self) 25 | if isinstance(model, DistributedDataParallel): 26 | model = model.module 27 | 28 | self.cfg = cfg.clone() 29 | self.model = model 30 | 31 | if tta_mapper is None: 32 | tta_mapper = DatasetMapperTTA(cfg) 33 | self.tta_mapper = tta_mapper 34 | self.batch_size = batch_size 35 | 36 | # cvpods tta. 37 | self.enable_cvpods_tta = cfg.TEST.AUG.CVPODS_TTA 38 | self.enable_scale_filter = cfg.TEST.AUG.SCALE_FILTER 39 | self.scale_ranges = cfg.TEST.AUG.SCALE_RANGES 40 | self.max_detection = cfg.MODEL.RegionSpot.NUM_PROPOSALS 41 | 42 | def _batch_inference(self, batched_inputs, detected_instances=None): 43 | """ 44 | Execute inference on a list of inputs, 45 | using batch size = self.batch_size, instead of the length of the list. 46 | 47 | """ 48 | if detected_instances is None: 49 | detected_instances = [None] * len(batched_inputs) 50 | 51 | factors = 2 if self.tta_mapper.flip else 1 52 | if self.enable_scale_filter: 53 | assert len(batched_inputs) == len(self.scale_ranges) * factors 54 | 55 | outputs = [] 56 | inputs, instances = [], [] 57 | for idx, input, instance in zip(count(), batched_inputs, detected_instances): 58 | inputs.append(input) 59 | instances.append(instance) 60 | if self.enable_cvpods_tta: 61 | output = self.model.forward(inputs, do_postprocess=False)[0] 62 | if self.enable_scale_filter: 63 | pred_boxes = output.get("pred_boxes") 64 | keep = self.filter_boxes(pred_boxes.tensor, *self.scale_ranges[idx // factors]) 65 | output = Instances( 66 | image_size=output.image_size, 67 | pred_boxes=Boxes(pred_boxes.tensor[keep]), 68 | pred_classes=output.pred_classes[keep], 69 | scores=output.scores[keep]) 70 | outputs.extend([output]) 71 | else: 72 | 73 | if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: 74 | outputs.extend( 75 | self.model.forward( 76 | inputs, 77 | do_postprocess=False, 78 | ) 79 | ) 80 | inputs, instances = [], [] 81 | return outputs 82 | 83 | @staticmethod 84 | def filter_boxes(boxes, min_scale, max_scale): 85 | """ 86 | boxes: (N, 4) shape 87 | """ 88 | # assert boxes.mode == "xyxy" 89 | w = boxes[:, 2] - boxes[:, 0] 90 | h = boxes[:, 3] - boxes[:, 1] 91 | keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale) 92 | return keep 93 | 94 | def _inference_one_image(self, input): 95 | """ 96 | Args: 97 | input (dict): one dataset dict with "image" field being a CHW tensor 98 | 99 | Returns: 100 | dict: one output dict 101 | """ 102 | orig_shape = (input["height"], input["width"]) 103 | augmented_inputs, tfms = self._get_augmented_inputs(input) 104 | # Detect boxes from all augmented versions 105 | all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms) 106 | # merge all detected boxes to obtain final predictions for boxes 107 | if self.enable_cvpods_tta: 108 | merged_instances = self._merge_detections_cvpods_tta(all_boxes, all_scores, all_classes, orig_shape) 109 | else: 110 | merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape) 111 | 112 | return {"instances": merged_instances} 113 | 114 | def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw): 115 | # select from the union of all results 116 | num_boxes = len(all_boxes) 117 | num_classes = self.cfg.MODEL. 
RegionSpot.NUM_CLASSES 118 | # +1 because fast_rcnn_inference expects background scores as well 119 | all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device) 120 | for idx, cls, score in zip(count(), all_classes, all_scores): 121 | all_scores_2d[idx, cls] = score 122 | 123 | merged_instances, _ = fast_rcnn_inference_single_image( 124 | all_boxes, 125 | all_scores_2d, 126 | shape_hw, 127 | 1e-8, 128 | self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, 129 | self.cfg.TEST.DETECTIONS_PER_IMAGE, 130 | ) 131 | 132 | return merged_instances 133 | 134 | def _merge_detections_cvpods_tta(self, all_boxes, all_scores, all_classes, shape_hw): 135 | all_scores = torch.tensor(all_scores).to(all_boxes.device) 136 | all_classes = torch.tensor(all_classes).to(all_boxes.device) 137 | 138 | all_boxes, all_scores, all_classes = self.merge_result_from_multi_scales( 139 | all_boxes, all_scores, all_classes, 140 | nms_type="soft_vote", vote_thresh=0.65, 141 | max_detection=self.max_detection 142 | ) 143 | 144 | all_boxes = Boxes(all_boxes) 145 | all_boxes.clip(shape_hw) 146 | 147 | result = Instances(shape_hw) 148 | result.pred_boxes = all_boxes 149 | result.scores = all_scores 150 | result.pred_classes = all_classes.long() 151 | return result 152 | 153 | def merge_result_from_multi_scales( 154 | self, boxes, scores, labels, nms_type="soft-vote", vote_thresh=0.65, max_detection=100 155 | ): 156 | boxes, scores, labels = self.batched_vote_nms( 157 | boxes, scores, labels, nms_type, vote_thresh 158 | ) 159 | 160 | number_of_detections = boxes.shape[0] 161 | # Limit to max_per_image detections **over all classes** 162 | if number_of_detections > max_detection > 0: 163 | boxes = boxes[:max_detection] 164 | scores = scores[:max_detection] 165 | labels = labels[:max_detection] 166 | 167 | return boxes, scores, labels 168 | 169 | def batched_vote_nms(self, boxes, scores, labels, vote_type, vote_thresh=0.65): 170 | # apply per class level nms, add max_coordinates on boxes first, then remove it. 171 | labels = labels.float() 172 | max_coordinates = boxes.max() + 1 173 | offsets = labels.reshape(-1, 1) * max_coordinates 174 | boxes = boxes + offsets 175 | 176 | boxes, scores, labels = self.bbox_vote(boxes, scores, labels, vote_thresh, vote_type) 177 | boxes -= labels.reshape(-1, 1) * max_coordinates 178 | 179 | return boxes, scores, labels 180 | 181 | def bbox_vote(self, boxes, scores, labels, vote_thresh, vote_type="softvote"): 182 | assert boxes.shape[0] == scores.shape[0] == labels.shape[0] 183 | det = torch.cat((boxes, scores.reshape(-1, 1), labels.reshape(-1, 1)), dim=1) 184 | 185 | vote_results = torch.zeros(0, 6, device=det.device) 186 | if det.numel() == 0: 187 | return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5] 188 | 189 | order = scores.argsort(descending=True) 190 | det = det[order] 191 | 192 | while det.shape[0] > 0: 193 | # IOU 194 | area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1]) 195 | xx1 = torch.max(det[0, 0], det[:, 0]) 196 | yy1 = torch.max(det[0, 1], det[:, 1]) 197 | xx2 = torch.min(det[0, 2], det[:, 2]) 198 | yy2 = torch.min(det[0, 3], det[:, 3]) 199 | w = torch.clamp(xx2 - xx1, min=0.) 200 | h = torch.clamp(yy2 - yy1, min=0.) 
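            # Editor's note: this computes the IoU of the current highest-scoring box
            # (det[0]) against every remaining box; cross-class overlap is suppressed
            # because batched_vote_nms offset the boxes by label * max_coordinates.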
201 | inter = w * h 202 | iou = inter / (area[0] + area[:] - inter) 203 | 204 | # get needed merge det and delete these det 205 | merge_index = torch.where(iou >= vote_thresh)[0] 206 | vote_det = det[merge_index, :] 207 | det = det[iou < vote_thresh] 208 | 209 | if merge_index.shape[0] <= 1: 210 | vote_results = torch.cat((vote_results, vote_det), dim=0) 211 | else: 212 | if vote_type == "soft_vote": 213 | vote_det_iou = iou[merge_index] 214 | det_accu_sum = self.get_soft_dets_sum(vote_det, vote_det_iou) 215 | elif vote_type == "vote": 216 | det_accu_sum = self.get_dets_sum(vote_det) 217 | vote_results = torch.cat((vote_results, det_accu_sum), dim=0) 218 | 219 | order = vote_results[:, 4].argsort(descending=True) 220 | vote_results = vote_results[order, :] 221 | 222 | return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5] 223 | 224 | @staticmethod 225 | def get_dets_sum(vote_det): 226 | vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4) 227 | max_score = vote_det[:, 4].max() 228 | det_accu_sum = torch.zeros((1, 6), device=vote_det.device) 229 | det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4]) 230 | det_accu_sum[:, 4] = max_score 231 | det_accu_sum[:, 5] = vote_det[0, 5] 232 | return det_accu_sum 233 | 234 | @staticmethod 235 | def get_soft_dets_sum(vote_det, vote_det_iou): 236 | soft_vote_det = vote_det.detach().clone() 237 | soft_vote_det[:, 4] *= (1 - vote_det_iou) 238 | 239 | INFERENCE_TH = 0.05 240 | soft_index = torch.where(soft_vote_det[:, 4] >= INFERENCE_TH)[0] 241 | soft_vote_det = soft_vote_det[soft_index, :] 242 | 243 | vote_det[:, :4] *= vote_det[:, 4:5].repeat(1, 4) 244 | max_score = vote_det[:, 4].max() 245 | det_accu_sum = torch.zeros((1, 6), device=vote_det.device) 246 | det_accu_sum[:, :4] = torch.sum(vote_det[:, :4], dim=0) / torch.sum(vote_det[:, 4]) 247 | det_accu_sum[:, 4] = max_score 248 | det_accu_sum[:, 5] = vote_det[0, 5] 249 | 250 | if soft_vote_det.shape[0] > 0: 251 | det_accu_sum = torch.cat((det_accu_sum, soft_vote_det), dim=0) 252 | return det_accu_sum 253 | -------------------------------------------------------------------------------- /regionspot/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Surrey-UP-Lab/RegionSpot/47a40632dba1723b9d45eef6aefc87f7ef605ad2/regionspot/util/__init__.py -------------------------------------------------------------------------------- /regionspot/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
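Illustration (editor's addition): box_cxcywh_to_xyxy maps (cx, cy, w, h) = (2, 3, 4, 6)
to (x0, y0, x1, y1) = (0, 0, 4, 6), and box_xyxy_to_cxcywh inverts it.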
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /regionspot/util/colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def colormap(rgb=False): 5 | color_list = np.array( 6 | [ 7 | 0.000, 0.447, 0.741, 8 | 0.850, 0.325, 0.098, 9 | 0.929, 0.694, 0.125, 10 | 0.494, 0.184, 0.556, 11 | 0.466, 0.674, 0.188, 12 | 0.301, 0.745, 0.933, 13 | 0.635, 0.078, 0.184, 14 | 0.300, 0.300, 0.300, 15 | 0.600, 0.600, 0.600, 16 | 1.000, 0.000, 0.000, 17 | 1.000, 0.500, 0.000, 18 | 0.749, 0.749, 0.000, 19 | 0.000, 1.000, 0.000, 20 | 0.000, 0.000, 1.000, 21 | 0.667, 0.000, 1.000, 22 | 0.333, 0.333, 0.000, 23 | 0.333, 0.667, 0.000, 24 | 0.333, 1.000, 0.000, 25 | 0.667, 0.333, 0.000, 26 | 0.667, 0.667, 0.000, 27 | 0.667, 1.000, 0.000, 28 | 1.000, 0.333, 0.000, 29 | 1.000, 0.667, 0.000, 30 | 1.000, 1.000, 0.000, 31 | 0.000, 0.333, 0.500, 32 | 0.000, 0.667, 0.500, 33 | 0.000, 1.000, 0.500, 34 | 0.333, 0.000, 0.500, 35 | 0.333, 0.333, 0.500, 36 | 0.333, 0.667, 0.500, 37 | 0.333, 1.000, 0.500, 38 | 0.667, 0.000, 0.500, 39 | 0.667, 0.333, 0.500, 40 | 0.667, 0.667, 0.500, 41 | 0.667, 1.000, 0.500, 42 | 1.000, 0.000, 0.500, 43 | 1.000, 0.333, 0.500, 44 | 1.000, 0.667, 0.500, 45 | 1.000, 1.000, 0.500, 46 | 0.000, 0.333, 1.000, 47 | 0.000, 0.667, 1.000, 48 | 0.000, 1.000, 1.000, 49 | 0.333, 0.000, 1.000, 50 | 0.333, 0.333, 1.000, 51 | 0.333, 0.667, 1.000, 52 | 0.333, 1.000, 1.000, 53 | 0.667, 0.000, 1.000, 54 | 0.667, 0.333, 1.000, 55 | 0.667, 0.667, 1.000, 56 | 0.667, 1.000, 1.000, 57 | 1.000, 0.000, 1.000, 58 | 1.000, 0.333, 1.000, 59 | 1.000, 0.667, 1.000, 60 | 0.167, 0.000, 0.000, 61 | 0.333, 0.000, 0.000, 62 | 0.500, 0.000, 0.000, 63 | 0.667, 0.000, 0.000, 64 | 0.833, 0.000, 0.000, 65 | 1.000, 0.000, 0.000, 66 | 0.000, 0.167, 0.000, 67 | 0.000, 0.333, 0.000, 68 | 0.000, 0.500, 0.000, 69 | 0.000, 0.667, 0.000, 70 | 0.000, 0.833, 0.000, 71 | 0.000, 1.000, 0.000, 72 | 0.000, 0.000, 0.167, 73 | 0.000, 0.000, 0.333, 74 | 0.000, 0.000, 0.500, 75 | 0.000, 0.000, 0.667, 76 | 0.000, 0.000, 0.833, 77 | 0.000, 0.000, 1.000, 78 | 0.000, 0.000, 0.000, 79 | 0.143, 0.143, 0.143, 80 | 0.286, 0.286, 0.286, 81 | 0.429, 0.429, 0.429, 82 | 0.571, 0.571, 0.571, 83 | 0.714, 0.714, 0.714, 84 | 0.857, 0.857, 0.857, 85 | 1.000, 1.000, 1.000 86 | ] 87 | ).astype(np.float32) 88 | color_list = color_list.reshape((-1, 3)) * 255 89 | if not rgb: 90 | color_list = color_list[:, ::-1] 91 | return color_list 92 | 93 | 94 | def category(): 95 | 96 | category = [ 97 | "person", 98 | "bicycle", 99 | "car", 100 | "motorbike", 101 | "aeroplane", 102 | "bus", 103 | "train", 104 | "truck", 105 | "boat", 106 | "traffic light", 107 | "fire hydrant", 108 | "stop sign", 109 | "parking meter", 110 | "bench", 111 | "bird", 112 | "cat", 113 | "dog", 114 | 
"horse", 115 | "sheep", 116 | "cow", 117 | "elephant", 118 | "bear", 119 | "zebra", 120 | "giraffe", 121 | "backpack", 122 | "umbrella", 123 | "handbag", 124 | "tie", 125 | "suitcase", 126 | "frisbee", 127 | "skis", 128 | "snowboard", 129 | "sports ball", 130 | "kite", 131 | "baseball bat", 132 | "baseball glove", 133 | "skateboard", 134 | "surfboard", 135 | "tennis racket", 136 | "bottle", 137 | "wine glass", 138 | "cup", 139 | "fork", 140 | "knife", 141 | "spoon", 142 | "bowl", 143 | "banana", 144 | "apple", 145 | "sandwich", 146 | "orange", 147 | "broccoli", 148 | "carrot", 149 | "hot dog", 150 | "pizza", 151 | "donut", 152 | "cake", 153 | "chair", 154 | "sofa", 155 | "pottedplant", 156 | "bed", 157 | "diningtable", 158 | "toilet", 159 | "tvmonitor", 160 | "laptop", 161 | "mouse", 162 | "remote", 163 | "keyboard", 164 | "cell phone", 165 | "microwave", 166 | "oven", 167 | "toaster", 168 | "sink", 169 | "refrigerator", 170 | "book", 171 | "clock", 172 | "vase", 173 | "scissors", 174 | "teddy bear", 175 | "hair drier", 176 | "toothbrush"] 177 | 178 | return category -------------------------------------------------------------------------------- /regionspot/util/model_ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | 4 | import copy 5 | import math 6 | import itertools 7 | import logging 8 | from typing import Dict, Any 9 | from contextlib import contextmanager 10 | 11 | import torch 12 | from detectron2.engine.train_loop import HookBase 13 | from detectron2.checkpoint import DetectionCheckpointer 14 | 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class EMADetectionCheckpointer(DetectionCheckpointer): 20 | def resume_or_load(self, path: str, *, resume: bool = True) -> Dict[str, Any]: 21 | """ 22 | If `resume` is True, this method attempts to resume from the last 23 | checkpoint, if exists. Otherwise, load checkpoint from the given path. 24 | This is useful when restarting an interrupted training job. 25 | 26 | Args: 27 | path (str): path to the checkpoint. 28 | resume (bool): if True, resume from the last checkpoint if it exists 29 | and load the model together with all the checkpointables. Otherwise 30 | only load the model without loading any checkpointables. 31 | 32 | Returns: 33 | same as :meth:`load`. 
34 | """ 35 | if resume and self.has_checkpoint(): 36 | path = self.get_checkpoint_file() 37 | return self.load(path) 38 | else: 39 | # workaround `self.load` 40 | return self.load(path, checkpointables=None) # modify 41 | 42 | 43 | class EMAState(object): 44 | def __init__(self): 45 | self.state = {} 46 | 47 | @classmethod 48 | def FromModel(cls, model: torch.nn.Module, device: str = ""): 49 | ret = cls() 50 | ret.save_from(model, device) 51 | return ret 52 | 53 | def save_from(self, model: torch.nn.Module, device: str = ""): 54 | """Save model state from `model` to this object""" 55 | for name, val in self.get_model_state_iterator(model): 56 | val = val.detach().clone() 57 | self.state[name] = val.to(device) if device else val 58 | 59 | def apply_to(self, model: torch.nn.Module): 60 | """Apply state to `model` from this object""" 61 | with torch.no_grad(): 62 | for name, val in self.get_model_state_iterator(model): 63 | assert ( 64 | name in self.state 65 | ), f"Name {name} not existed, available names {self.state.keys()}" 66 | val.copy_(self.state[name]) 67 | 68 | @contextmanager 69 | def apply_and_restore(self, model): 70 | old_state = EMAState.FromModel(model, self.device) 71 | self.apply_to(model) 72 | yield old_state 73 | old_state.apply_to(model) 74 | 75 | def get_ema_model(self, model): 76 | ret = copy.deepcopy(model) 77 | self.apply_to(ret) 78 | return ret 79 | 80 | @property 81 | def device(self): 82 | if not self.has_inited(): 83 | return None 84 | return next(iter(self.state.values())).device 85 | 86 | def to(self, device): 87 | for name in self.state: 88 | self.state[name] = self.state[name].to(device) 89 | return self 90 | 91 | def has_inited(self): 92 | return self.state 93 | 94 | def clear(self): 95 | self.state.clear() 96 | return self 97 | 98 | def get_model_state_iterator(self, model): 99 | param_iter = model.named_parameters() 100 | buffer_iter = model.named_buffers() 101 | return itertools.chain(param_iter, buffer_iter) 102 | 103 | def state_dict(self): 104 | return self.state 105 | 106 | def load_state_dict(self, state_dict, strict: bool = True): 107 | self.clear() 108 | for x, y in state_dict.items(): 109 | self.state[x] = y 110 | return torch.nn.modules.module._IncompatibleKeys( 111 | missing_keys=[], unexpected_keys=[] 112 | ) 113 | 114 | def __repr__(self): 115 | ret = f"EMAState(state=[{','.join(self.state.keys())}])" 116 | return ret 117 | 118 | 119 | class EMAUpdater(object): 120 | """Model Exponential Moving Average 121 | Keep a moving average of everything in the model state_dict (parameters and 122 | buffers). This is intended to allow functionality like 123 | https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 124 | 125 | Note: It's very important to set EMA for ALL network parameters (instead of 126 | parameters that require gradient), including batch-norm moving average mean 127 | and variance. This leads to significant improvement in accuracy. 128 | For example, for EfficientNetB3, with default setting (no mixup, lr exponential 129 | decay) without bn_sync, the EMA accuracy with EMA on params that requires 130 | gradient is 79.87%, while the corresponding accuracy with EMA on all params 131 | is 80.61%. 132 | 133 | Also, bn sync should be switched on for EMA. 
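    Illustrative usage (editor's sketch, not part of the original docstring):
        state = EMAState()
        updater = EMAUpdater(state, decay=0.999, device="cuda")
        updater.init_state(model)      # snapshot current parameters and buffers
        for batch in loader:           # `loader` and the train step are assumed
            ...                        # one optimizer step on `model`
            updater.update(model)      # ema <- ema * decay + value * (1 - decay)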
134 | """ 135 | 136 | def __init__(self, state: EMAState, decay: float = 0.999, device: str = "", yolox: bool = False): 137 | self.decay = decay 138 | self.device = device 139 | 140 | self.state = state 141 | self.updates = 0 142 | self.yolox = yolox 143 | if yolox: 144 | decay = 0.9998 145 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) 146 | 147 | def init_state(self, model): 148 | self.state.clear() 149 | self.state.save_from(model, self.device) 150 | 151 | def update(self, model): 152 | with torch.no_grad(): 153 | self.updates += 1 154 | d = self.decay(self.updates) if self.yolox else self.decay 155 | for name, val in self.state.get_model_state_iterator(model): 156 | ema_val = self.state.state[name] 157 | if self.device: 158 | val = val.to(self.device) 159 | ema_val.copy_(ema_val * d + val * (1.0 - d)) 160 | 161 | 162 | def add_model_ema_configs(_C): 163 | _C.MODEL_EMA = type(_C)() 164 | _C.MODEL_EMA.ENABLED = False 165 | _C.MODEL_EMA.DECAY = 0.999 166 | # use the same as MODEL.DEVICE when empty 167 | _C.MODEL_EMA.DEVICE = "" 168 | # When True, loading the ema weight to the model when eval_only=True in build_model() 169 | _C.MODEL_EMA.USE_EMA_WEIGHTS_FOR_EVAL_ONLY = False 170 | # when True, use YOLOX EMA: https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/ema.py#L22 171 | _C.MODEL_EMA.YOLOX = False 172 | 173 | 174 | def _remove_ddp(model): 175 | from torch.nn.parallel import DistributedDataParallel 176 | 177 | if isinstance(model, DistributedDataParallel): 178 | return model.module 179 | return model 180 | 181 | 182 | def may_build_model_ema(cfg, model): 183 | if not cfg.MODEL_EMA.ENABLED: 184 | return 185 | model = _remove_ddp(model) 186 | assert not hasattr( 187 | model, "ema_state" 188 | ), "Name `ema_state` is reserved for model ema." 
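    # Editor's note: the EMAState attached below starts empty; EMAHook.before_train()
    # populates it from the live model (via EMAUpdater.init_state()) unless it was
    # already loaded from a checkpoint.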
189 | model.ema_state = EMAState() 190 | logger.info("Using Model EMA.") 191 | 192 | 193 | def may_get_ema_checkpointer(cfg, model): 194 | if not cfg.MODEL_EMA.ENABLED: 195 | return {} 196 | model = _remove_ddp(model) 197 | return {"ema_state": model.ema_state} 198 | 199 | 200 | def get_model_ema_state(model): 201 | """Return the ema state stored in `model`""" 202 | model = _remove_ddp(model) 203 | assert hasattr(model, "ema_state") 204 | ema = model.ema_state 205 | return ema 206 | 207 | 208 | def apply_model_ema(model, state=None, save_current=False): 209 | """Apply ema stored in `model` to model and returns a function to restore 210 | the weights are applied 211 | """ 212 | model = _remove_ddp(model) 213 | 214 | if state is None: 215 | state = get_model_ema_state(model) 216 | 217 | if save_current: 218 | # save current model state 219 | old_state = EMAState.FromModel(model, state.device) 220 | state.apply_to(model) 221 | 222 | if save_current: 223 | return old_state 224 | return None 225 | 226 | 227 | @contextmanager 228 | def apply_model_ema_and_restore(model, state=None): 229 | """Apply ema stored in `model` to model and returns a function to restore 230 | the weights are applied 231 | """ 232 | model = _remove_ddp(model) 233 | 234 | if state is None: 235 | state = get_model_ema_state(model) 236 | 237 | old_state = EMAState.FromModel(model, state.device) 238 | state.apply_to(model) 239 | yield old_state 240 | old_state.apply_to(model) 241 | 242 | 243 | class EMAHook(HookBase): 244 | def __init__(self, cfg, model): 245 | model = _remove_ddp(model) 246 | assert cfg.MODEL_EMA.ENABLED 247 | assert hasattr( 248 | model, "ema_state" 249 | ), "Call `may_build_model_ema` first to initilaize the model ema" 250 | self.model = model 251 | self.ema = self.model.ema_state 252 | self.device = cfg.MODEL_EMA.DEVICE or cfg.MODEL.DEVICE 253 | self.ema_updater = EMAUpdater( 254 | self.model.ema_state, decay=cfg.MODEL_EMA.DECAY, device=self.device, yolox=cfg.MODEL_EMA.YOLOX 255 | ) 256 | 257 | def before_train(self): 258 | if self.ema.has_inited(): 259 | self.ema.to(self.device) 260 | else: 261 | self.ema_updater.init_state(self.model) 262 | 263 | def after_train(self): 264 | pass 265 | 266 | def before_step(self): 267 | pass 268 | 269 | def after_step(self): 270 | if not self.model.train: 271 | return 272 | self.ema_updater.update(self.model) 273 | -------------------------------------------------------------------------------- /regionspot/util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 3 | """ 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | 10 | from pathlib import Path, PurePath 11 | 12 | 13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 14 | ''' 15 | Function to plot specific fields from training log(s). Plots both training and test results. 16 | 17 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 18 | - fields = which results to plot from each log file - plots both training and test for each field. 19 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 20 | - log_name = optional, name of log file if different than default 'log.txt'. 21 | 22 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 
23 | - solid lines are training results, dashed lines are test results. 24 | 25 | ''' 26 | func_name = "plot_utils.py::plot_logs" 27 | 28 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 29 | # convert single Path to list to avoid 'not iterable' error 30 | 31 | if not isinstance(logs, list): 32 | if isinstance(logs, PurePath): 33 | logs = [logs] 34 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 35 | else: 36 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 37 | Expect list[Path] or single Path obj, received {type(logs)}") 38 | 39 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir 40 | for i, dir in enumerate(logs): 41 | if not isinstance(dir, PurePath): 42 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 43 | if not dir.exists(): 44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 45 | # verify log_name exists 46 | fn = Path(dir / log_name) 47 | if not fn.exists(): 48 | print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?") 49 | print(f"--> full path of missing log file: {fn}") 50 | return 51 | 52 | # load log file(s) and plot 53 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 54 | 55 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 56 | 57 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 58 | for j, field in enumerate(fields): 59 | if field == 'mAP': 60 | coco_eval = pd.DataFrame( 61 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] 62 | ).ewm(com=ewm_col).mean() 63 | axs[j].plot(coco_eval, c=color) 64 | else: 65 | df.interpolate().ewm(com=ewm_col).mean().plot( 66 | y=[f'train_{field}', f'test_{field}'], 67 | ax=axs[j], 68 | color=[color] * 2, 69 | style=['-', '--'] 70 | ) 71 | for ax, field in zip(axs, fields): 72 | ax.legend([Path(p).name for p in logs]) 73 | ax.set_title(field) 74 | 75 | 76 | def plot_precision_recall(files, naming_scheme='iter'): 77 | if naming_scheme == 'exp_id': 78 | # name becomes exp_id 79 | names = [f.parts[-3] for f in files] 80 | elif naming_scheme == 'iter': 81 | names = [f.stem for f in files] 82 | else: 83 | raise ValueError(f'not supported {naming_scheme}') 84 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 85 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 86 | data = torch.load(f) 87 | # precision is n_iou, n_points, n_cat, n_area, max_det 88 | precision = data['precision'] 89 | recall = data['params'].recThrs 90 | scores = data['scores'] 91 | # take precision for all classes, all areas and 100 detections 92 | precision = precision[0, :, :, 0, -1].mean(1) 93 | scores = scores[0, :, :, 0, -1].mean(1) 94 | prec = precision.mean() 95 | rec = data['recall'][0, :, 0, -1].mean() 96 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 97 | f'score={scores.mean():0.3f}, ' + 98 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 99 | ) 100 | axs[0].plot(recall, precision, c=color) 101 | axs[1].plot(recall, scores, c=color) 102 | 103 | axs[0].set_title('Precision / Recall') 104 | axs[0].legend(names) 105 | axs[1].set_title('Scores / Recall') 106 | axs[1].legend(names) 107 | return fig, axs 108 | -------------------------------------------------------------------------------- /regionspot/util/postprocessing.py: 
-------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from detectron2.structures import Instances 6 | 7 | def segmentation_postprocess( 8 | results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 9 | ): 10 | 11 | if isinstance(output_width, torch.Tensor): 12 | # This shape might (but not necessarily) be tensors during tracing. 13 | # Converts integer tensors to float temporaries to ensure true 14 | # division is performed when computing scale_x and scale_y. 15 | output_width_tmp = output_width.float() 16 | output_height_tmp = output_height.float() 17 | new_size = torch.stack([output_height, output_width]) 18 | else: 19 | new_size = (output_height, output_width) 20 | output_width_tmp = output_width 21 | output_height_tmp = output_height 22 | 23 | scale_x, scale_y = ( 24 | output_width_tmp / results.image_size[1], 25 | output_height_tmp / results.image_size[0], 26 | ) 27 | results = Instances(new_size, **results.get_fields()) 28 | 29 | if results.has("pred_boxes"): 30 | output_boxes = results.pred_boxes 31 | elif results.has("proposal_boxes"): 32 | output_boxes = results.proposal_boxes 33 | else: 34 | output_boxes = None 35 | assert output_boxes is not None, "Predictions must contain boxes!" 36 | 37 | output_boxes.scale(scale_x, scale_y) 38 | output_boxes.clip(results.image_size) 39 | 40 | results = results[output_boxes.nonempty()] 41 | 42 | if results.has("pred_masks"): 43 | # import pdb;pdb.set_trace() 44 | mask = F.interpolate(results.pred_masks.float(), size=(output_height, output_width), mode='nearest') 45 | # import pdb;pdb.set_trace() 46 | mask = mask.squeeze(1).byte() 47 | results.pred_masks = mask 48 | 49 | # import pdb;pdb.set_trace() 50 | # results.pred_masks [N, output-height, output-width] 51 | 52 | 53 | return results -------------------------------------------------------------------------------- /regionspot/util/preprocessing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import torchvision.transforms.functional as F 5 | from regionspot.modeling.segment_anything.utils.transforms import ResizeLongestSide 6 | 7 | NORM_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073]).unsqueeze(1).unsqueeze(2) 8 | NORM_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711]).unsqueeze(1).unsqueeze(2) 9 | 10 | 11 | def resize_box(after_image_size, befor_image_size, boxes, size=800, max_size=1333): 12 | # size can be min_size (scalar) or (w, h) tuple 13 | #size 14 | # 15 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 16 | w, h = image_size 17 | if max_size is not None: 18 | min_original_size = float(min((w, h))) 19 | max_original_size = float(max((w, h))) 20 | if max_original_size / min_original_size * size > max_size: 21 | size = int(round(max_size * min_original_size / max_original_size)) 22 | 23 | if (w <= h and w == size) or (h <= w and h == size): 24 | return (h, w) 25 | 26 | if w < h: 27 | ow = size 28 | oh = int(size * h / w) 29 | else: 30 | oh = size 31 | ow = int(size * w / h) 32 | 33 | return (oh, ow) 34 | 35 | def get_size(image_size, size, max_size=None): 36 | if isinstance(size, (list, tuple)): 37 | return size[::-1] 38 | else: 39 | return get_size_with_aspect_ratio(image_size, size, max_size) 40 | 41 | size = get_size(befor_image_size, size, max_size) 42 | 43 | 44 | 45 | ratios = tuple(float(s) / float(s_orig) for s, 
s_orig in zip(after_image_size, befor_image_size)) 46 | ratio_width, ratio_height = ratios 47 | # ratio_width, ratio_height = 1, 1 48 | 49 | scaled_boxes = boxes * torch.as_tensor( 50 | [ratio_width, ratio_height, ratio_width, ratio_height] 51 | ) 52 | 53 | return scaled_boxes 54 | 55 | def resize_and_normalize(image, target_size=(224, 224)): 56 | resized_image = F.resize(image, target_size) 57 | device = resized_image.device 58 | return (resized_image - NORM_MEAN.to(device)) / NORM_STD.to(device) 59 | 60 | 61 | def get_pred_boxes(pred_results, image_id): 62 | scores = torch.tensor(pred_results[image_id]['scores']) 63 | labels = torch.tensor(pred_results[image_id]['labels']) 64 | boxes = torch.tensor(pred_results[image_id]['boxes']) 65 | 66 | return scores, labels, boxes 67 | 68 | 69 | def prepare_prompt_infer(batched_inputs, num_proposals=None, pred_results=None, target_size=(224,224)): 70 | boxes_type = 'GT' 71 | if pred_results is not None: 72 | boxes_type = 'PRED_BOX' 73 | for x in batched_inputs: 74 | curr_image = x["image"] 75 | x["curr_image"] = curr_image.clone() 76 | image_id = x["image_id"] 77 | image = curr_image.permute(1, 2, 0).to(torch.uint8) 78 | curr_size = (image.shape[0], image.shape[1]) 79 | 80 | resized_image = resize_and_normalize(curr_image.cuda() / 255, target_size=target_size) 81 | x["image"] = torch.as_tensor(ResizeLongestSide(1024).apply_image(np.array(image.cpu())), dtype=torch.float).permute(2, 0, 1).cuda() 82 | raw_size = (x['height'], x['width']) 83 | 84 | if boxes_type != 'GT': 85 | scores, gt_label, boxes_prompt = get_pred_boxes(pred_results, str(image_id)) 86 | boxes_prompt = resize_box(curr_size, raw_size, boxes_prompt) 87 | x['pred_boxes'] = boxes_prompt 88 | x['scores'] = scores 89 | else: 90 | boxes_prompt = x["instances"].gt_boxes.tensor.cpu() 91 | if len(boxes_prompt) == 0: 92 | boxes_prompt = torch.tensor([[0, 0, *curr_size]]) 93 | boxes_prompt = ResizeLongestSide(1024).apply_boxes(np.array(boxes_prompt), curr_size) 94 | x['boxes'] = torch.as_tensor(boxes_prompt, dtype=torch.float).cuda() 95 | x['resized_image'] = resized_image 96 | x['original_size'] = curr_size 97 | return batched_inputs 98 | 99 | 100 | def prepare_prompt_train(batched_inputs, target_size=(224,224)): 101 | max_boxes = max(len(x["extra_info"]['mask_tokens']) for x in batched_inputs) 102 | num_proposals = max(max_boxes, 1) 103 | 104 | for x in batched_inputs: 105 | raw_image = x["image"] 106 | image = (x["image"].permute(1, 2, 0)).to(torch.uint8) 107 | curr_size = (image.shape[0], image.shape[1]) 108 | resized_image = resize_and_normalize(raw_image.cuda() / 255, target_size=target_size) 109 | input_image = ResizeLongestSide(1024).apply_image(np.array(image.cpu())) 110 | input_image_torch = torch.as_tensor(input_image, dtype=torch.float).permute(2, 0, 1).cuda() 111 | x["image"] = input_image_torch 112 | mask_tokens = x["extra_info"]['mask_tokens'].clone().detach().cuda() 113 | labels = torch.tensor(x["extra_info"]['classes']).cuda() 114 | 115 | if x['dataset_name'] == 'coco': 116 | try: 117 | # Convert labels using the coco_new_dict 118 | labels = [constants.coco_new_dict[label.item()] for label in labels] 119 | labels = torch.tensor(labels).cuda() 120 | except: 121 | pass 122 | else: 123 | # Decrement each label by 1 unless it's zero 124 | new_labels = [label.item() - 1 if label.item() != 0 else 0 for label in labels] 125 | labels = torch.tensor(new_labels).cuda() 126 | 127 | num_gt = len(mask_tokens) 128 | num_repeat = num_proposals // num_gt 129 | repeat_tensor = [num_repeat] * 
(num_gt - num_proposals % num_gt) + [num_repeat + 1] * (num_proposals % num_gt) 130 | repeat_tensor = torch.tensor(repeat_tensor).cuda() 131 | mask_tokens = torch.repeat_interleave(mask_tokens, repeat_tensor, dim=0) 132 | labels = torch.repeat_interleave(labels, repeat_tensor, dim=0) 133 | 134 | x['resized_image'] = resized_image 135 | x['label'] = labels 136 | x['mask_tokens'] = mask_tokens 137 | x['original_size'] = curr_size 138 | 139 | return batched_inputs 140 | -------------------------------------------------------------------------------- /regionspot/util/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 
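For example (a worked illustration; the 480x640 image size and the two points are arbitrary placeholders): with `target_length=1024` the longest side 640 maps to 1024, so both axes are scaled by 1.6:

```python
import torch

transform = ResizeLongestSide(1024)
coords = torch.tensor([[100.0, 50.0], [300.0, 200.0]])   # (x, y) points in a 480x640 image
print(transform.apply_coords_torch(coords, (480, 640)))  # tensor([[160.,  80.],
                                                         #         [480., 320.]])
```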
73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /tools/re_save_ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | pretrain_ckpt = './pretrained_model/model_final.pth' 4 | checkpoint = torch.load(pretrain_ckpt, map_location='cpu') 5 | 6 | # Remove specific keys from the top-level dictionary 7 | top_level_keys_to_remove = ['trainer', 'iteration'] 8 | for key in top_level_keys_to_remove: 9 | if key in checkpoint: 10 | del checkpoint[key] 11 | 12 | # Remove keys that start with 'clip_model' and 'sam' from the checkpoint's 'model' dictionary 13 | model_keys_to_remove = ['model.clip_model', 'model.sam'] 14 | for key in list(checkpoint['model'].keys()): # Use list to copy keys 15 | if any(key.startswith(to_remove) for to_remove in model_keys_to_remove): 16 | print(key) 17 | del checkpoint['model'][key] 18 | 19 | # Save the modified checkpoint back to a file 20 | modified_ckpt_path = './pretrained_model/model_final_modified.pth' 21 | torch.save(checkpoint, modified_ckpt_path) 22 | print(checkpoint['model'].keys()) 23 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | This script is a simplified version of the training script in detectron2/tools. 
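A minimal single-process driver is sketched below for orientation only; the config path is just one of the shipped configs, and real multi-GPU runs should go through `launch()` in `__main__` at the bottom of this file.

```python
from detectron2.engine import default_argument_parser

args = default_argument_parser().parse_args(["--config-file", "configs/objects365_bb.yaml"])
cfg = setup(args)                  # defined below: RegionSpot + EMA defaults, then the YAML, then CLI opts
trainer = Trainer(cfg)             # defined below
trainer.resume_or_load(resume=args.resume)
trainer.train()
```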
5 | """ 6 | import os 7 | import itertools 8 | import weakref 9 | from typing import Any, Dict, List, Set 10 | import logging 11 | from collections import OrderedDict 12 | 13 | import torch 14 | from fvcore.nn.precise_bn import get_bn_modules 15 | 16 | import detectron2.utils.comm as comm 17 | from detectron2.utils.logger import setup_logger 18 | from detectron2.checkpoint import DetectionCheckpointer 19 | from detectron2.config import get_cfg 20 | from detectron2.data import build_detection_train_loader 21 | from regionspot import build_custom_train_loader 22 | 23 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, create_ddp_model, \ 24 | AMPTrainer, SimpleTrainer, hooks 25 | from detectron2.evaluation import COCOEvaluator, LVISEvaluator, verify_results 26 | from detectron2.solver.build import maybe_add_gradient_clipping 27 | from detectron2.modeling import build_model 28 | from regionspot.data import objects365 29 | from regionspot.data import openimages 30 | from regionspot.data import v3det 31 | 32 | 33 | from regionspot import RegionSpotDatasetMapper, add_regionspot_config, RegionSpotWithTTA 34 | from regionspot.util.model_ema import add_model_ema_configs, may_build_model_ema, may_get_ema_checkpointer, EMAHook, \ 35 | apply_model_ema_and_restore, EMADetectionCheckpointer 36 | 37 | 38 | class Trainer(DefaultTrainer): 39 | """ Extension of the Trainer class adapted to RegionSpot. """ 40 | 41 | def __init__(self, cfg): 42 | """ 43 | Args: 44 | cfg (CfgNode): 45 | """ 46 | super(DefaultTrainer, self).__init__() # call grandfather's `__init__` while avoid father's `__init()` 47 | logger = logging.getLogger("detectron2") 48 | if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 49 | setup_logger() 50 | cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) 51 | # Assume these objects must be constructed in this order. 52 | model = self.build_model(cfg) 53 | optimizer = self.build_optimizer(cfg, model) 54 | data_loader = self.build_train_loader(cfg) 55 | 56 | model = create_ddp_model(model, broadcast_buffers=False) 57 | self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( 58 | model, data_loader, optimizer 59 | ) 60 | 61 | self.scheduler = self.build_lr_scheduler(cfg, optimizer) 62 | 63 | ########## EMA ############ 64 | kwargs = { 65 | 'trainer': weakref.proxy(self), 66 | } 67 | kwargs.update(may_get_ema_checkpointer(cfg, model)) 68 | self.checkpointer = DetectionCheckpointer( 69 | # Assume you want to save checkpoints together with logs/statistics 70 | model, 71 | cfg.OUTPUT_DIR, 72 | **kwargs, 73 | # trainer=weakref.proxy(self), 74 | ) 75 | self.start_iter = 0 76 | self.max_iter = cfg.SOLVER.MAX_ITER 77 | self.cfg = cfg 78 | 79 | self.register_hooks(self.build_hooks()) 80 | 81 | @classmethod 82 | def build_model(cls, cfg): 83 | """ 84 | Returns: 85 | torch.nn.Module: 86 | 87 | It now calls :func:`detectron2.modeling.build_model`. 88 | Overwrite it if you'd like a different model. 89 | """ 90 | model = build_model(cfg) 91 | logger = logging.getLogger(__name__) 92 | logger.info("Model:\n{}".format(model)) 93 | # setup EMA 94 | may_build_model_ema(cfg, model) 95 | return model 96 | 97 | @classmethod 98 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 99 | """ 100 | Create evaluator(s) for a given dataset. 101 | This uses the special metadata "evaluator_type" associated with each builtin dataset. 
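With the datasets used in this project the branch below simply picks `LVISEvaluator` when the test set name contains "lvis" and `COCOEvaluator` otherwise, e.g. (an illustration; `lvis_v1_val` is one of detectron2's registered builtin names, and it assumes the corresponding annotation files are already in place):

```python
evaluator = Trainer.build_evaluator(cfg, "lvis_v1_val")   # -> LVISEvaluator writing to OUTPUT_DIR/inference
```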
102 | For your own dataset, you can simply create an evaluator manually in your 103 | script and do not have to worry about the hacky if-else logic here. 104 | """ 105 | if output_folder is None: 106 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 107 | if 'lvis' in dataset_name: 108 | return LVISEvaluator(dataset_name, cfg, True, output_folder) 109 | else: 110 | return COCOEvaluator(dataset_name, cfg, True, output_folder) 111 | 112 | @classmethod 113 | def build_train_loader(cls, cfg): 114 | mapper = RegionSpotDatasetMapper(cfg, is_train=True) 115 | if cfg.DATALOADER.SAMPLER_TRAIN in ['TrainingSampler', 'RepeatFactorTrainingSampler']: 116 | data_loader = build_detection_train_loader(cfg, mapper=mapper) 117 | else: 118 | data_loader = build_custom_train_loader(cfg, mapper=mapper) 119 | return data_loader 120 | @classmethod 121 | def build_optimizer(cls, cfg, model): 122 | params: List[Dict[str, Any]] = [] 123 | memo: Set[torch.nn.parameter.Parameter] = set() 124 | for key, value in model.named_parameters(recurse=True): 125 | if not value.requires_grad: 126 | continue 127 | # Avoid duplicating parameters 128 | if value in memo: 129 | continue 130 | memo.add(value) 131 | lr = cfg.SOLVER.BASE_LR 132 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 133 | if "backbone" in key: 134 | lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER 135 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 136 | 137 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 138 | # detectron2 doesn't have full model gradient clipping now 139 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 140 | enable = ( 141 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 142 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 143 | and clip_norm_val > 0.0 144 | ) 145 | 146 | class FullModelGradientClippingOptimizer(optim): 147 | def step(self, closure=None): 148 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 149 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 150 | super().step(closure=closure) 151 | 152 | return FullModelGradientClippingOptimizer if enable else optim 153 | 154 | optimizer_type = cfg.SOLVER.OPTIMIZER 155 | if optimizer_type == "SGD": 156 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 157 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 158 | ) 159 | elif optimizer_type == "ADAMW": 160 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 161 | params, cfg.SOLVER.BASE_LR 162 | ) 163 | else: 164 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 165 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 166 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 167 | return optimizer 168 | 169 | @classmethod 170 | def ema_test(cls, cfg, model, evaluators=None): 171 | # model with ema weights 172 | logger = logging.getLogger("detectron2.trainer") 173 | if cfg.MODEL_EMA.ENABLED: 174 | logger.info("Run evaluation with EMA.") 175 | with apply_model_ema_and_restore(model): 176 | results = cls.test(cfg, model, evaluators=evaluators) 177 | else: 178 | results = cls.test(cfg, model, evaluators=evaluators) 179 | return results 180 | 181 | @classmethod 182 | def test_with_TTA(cls, cfg, model): 183 | logger = logging.getLogger("detectron2.trainer") 184 | logger.info("Running inference with test-time augmentation ...") 185 | model = RegionSpotWithTTA(cfg, model) 186 | evaluators = [ 187 | cls.build_evaluator( 188 | cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, 
"inference_TTA") 189 | ) 190 | for name in cfg.DATASETS.TEST 191 | ] 192 | if cfg.MODEL_EMA.ENABLED: 193 | cls.ema_test(cfg, model, evaluators) 194 | else: 195 | res = cls.test(cfg, model, evaluators) 196 | res = OrderedDict({k + "_TTA": v for k, v in res.items()}) 197 | return res 198 | 199 | def build_hooks(self): 200 | """ 201 | Build a list of default hooks, including timing, evaluation, 202 | checkpointing, lr scheduling, precise BN, writing events. 203 | 204 | Returns: 205 | list[HookBase]: 206 | """ 207 | cfg = self.cfg.clone() 208 | cfg.defrost() 209 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN 210 | 211 | ret = [ 212 | hooks.IterationTimer(), 213 | EMAHook(self.cfg, self.model) if cfg.MODEL_EMA.ENABLED else None, # EMA hook 214 | hooks.LRScheduler(), 215 | hooks.PreciseBN( 216 | # Run at the same freq as (but before) evaluation. 217 | cfg.TEST.EVAL_PERIOD, 218 | self.model, 219 | # Build a new data loader to not affect training 220 | self.build_train_loader(cfg), 221 | cfg.TEST.PRECISE_BN.NUM_ITER, 222 | ) 223 | if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) 224 | else None, 225 | ] 226 | 227 | # Do PreciseBN before checkpointer, because it updates the model and need to 228 | # be saved by checkpointer. 229 | # This is not always the best: if checkpointing has a different frequency, 230 | # some checkpoints may have more precise statistics than others. 231 | if comm.is_main_process(): 232 | ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) 233 | 234 | def test_and_save_results(): 235 | self._last_eval_results = self.test(self.cfg, self.model) 236 | return self._last_eval_results 237 | 238 | # Do evaluation after checkpointer, because then if it fails, 239 | # we can use the saved checkpoint to debug. 240 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) 241 | 242 | if comm.is_main_process(): 243 | # Here the default print/log frequency of each writer is used. 244 | # run writers in the end, so that evaluation metrics are written 245 | ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) 246 | return ret 247 | 248 | 249 | def setup(args): 250 | """ 251 | Create configs and perform basic setups. 
252 | """ 253 | cfg = get_cfg() 254 | add_regionspot_config(cfg) 255 | add_model_ema_configs(cfg) 256 | cfg.merge_from_file(args.config_file) 257 | cfg.merge_from_list(args.opts) 258 | cfg.freeze() 259 | default_setup(cfg, args) 260 | return cfg 261 | 262 | 263 | def main(args): 264 | cfg = setup(args) 265 | 266 | if args.eval_only: 267 | model = Trainer.build_model(cfg) 268 | kwargs = may_get_ema_checkpointer(cfg, model) 269 | if cfg.MODEL_EMA.ENABLED: 270 | EMADetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS, 271 | resume=args.resume) 272 | else: 273 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, **kwargs).resume_or_load(cfg.MODEL.WEIGHTS, 274 | resume=args.resume) 275 | res = Trainer.ema_test(cfg, model) 276 | if cfg.TEST.AUG.ENABLED: 277 | res.update(Trainer.test_with_TTA(cfg, model)) 278 | if comm.is_main_process(): 279 | verify_results(cfg, res) 280 | return res 281 | 282 | trainer = Trainer(cfg) 283 | trainer.resume_or_load(resume=args.resume) 284 | return trainer.train() 285 | 286 | 287 | if __name__ == "__main__": 288 | args = default_argument_parser().parse_args() 289 | print("Command Line Args:", args) 290 | launch( 291 | main, 292 | args.num_gpus, 293 | num_machines=args.num_machines, 294 | machine_rank=args.machine_rank, 295 | dist_url=args.dist_url, 296 | args=(args,), 297 | ) 298 | --------------------------------------------------------------------------------
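A quick sanity check for `tools/re_save_ckpt.py` above, offered as a hedged sketch: the path and key prefixes are the ones hard-coded in that script, and it assumes the trimmed checkpoint has already been written.

```python
import torch

ckpt = torch.load("./pretrained_model/model_final_modified.pth", map_location="cpu")

# The script strips the trainer/iteration state and every tensor whose key starts
# with "model.clip_model" or "model.sam"; verify nothing slipped through.
assert "trainer" not in ckpt and "iteration" not in ckpt
assert not any(k.startswith(("model.clip_model", "model.sam")) for k in ckpt["model"])
print(f"{len(ckpt['model'])} tensors kept in the trimmed checkpoint")
```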