├── .gitignore ├── LICENSE.txt ├── README.md ├── __init__.py ├── diagram.png ├── examples ├── image_1.jpg ├── image_2.jpg ├── image_3.jpg ├── inference_demo.py └── output_image_1_refcoco.png ├── requirements.txt ├── scripts ├── create_all_masks_dataset.py ├── create_zero_shot_masks_dataset.py ├── test_unsupervised_masks_dataset.py └── train │ ├── __init__.py │ ├── test_model.py │ ├── train_model.py │ └── train_test_args.py ├── setup.py └── ssc_ris ├── correct ├── __init__.py ├── loss.py └── loss_utils.py ├── refer_dataset ├── __init__.py ├── bert │ ├── __init__.py │ ├── activations.py │ ├── configuration_bert.py │ ├── configuration_utils.py │ ├── file_utils.py │ ├── generation_utils.py │ ├── modeling_bert.py │ ├── modeling_utils.py │ ├── tokenization_bert.py │ ├── tokenization_utils.py │ └── tokenization_utils_base.py ├── dataset.py ├── refer │ ├── LICENSE │ ├── Makefile │ ├── README.md │ ├── evaluation │ │ ├── __init__.py │ │ ├── bleu │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ ├── bleu.py │ │ │ └── bleu_scorer.py │ │ ├── cider │ │ │ ├── __init__.py │ │ │ ├── cider.py │ │ │ └── cider_scorer.py │ │ ├── meteor │ │ │ ├── __init__.py │ │ │ └── meteor.py │ │ ├── readme.txt │ │ ├── refEvaluation.py │ │ ├── rouge │ │ │ ├── __init__.py │ │ │ └── rouge.py │ │ └── tokenizer │ │ │ ├── __init__.py │ │ │ ├── ptbtokenizer.py │ │ │ └── stanford-corenlp-3.4.1.jar │ ├── external │ │ ├── README.md │ │ ├── __init__.py │ │ ├── _mask.pyx │ │ ├── mask.py │ │ ├── maskApi.c │ │ └── maskApi.h │ ├── refer.py │ └── setup.py ├── unsupervised_dataset.py └── utils.py ├── segment ├── __init__.py ├── _utils.py └── segment_fns.py ├── select ├── __init__.py ├── clip_sim.py ├── utils.py └── visual_prompting_fns.py └── utils ├── __init__.py ├── _utils.py ├── lavt_lib ├── _utils.py ├── backbone.py ├── lavt_utils.py ├── mask_predictor.py ├── mmcv_custom │ ├── __init__.py │ └── checkpoint.py └── segmentation.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *_stage_masks 3 | .vscode 4 | *.pth 5 | *.pb 6 | ssc_ris/refer_dataset/refer/data 7 | ssc_ris/refer_dataset/refer/test 8 | ssc_ris/refer_dataset/refer/*.ipynb 9 | *.egg-info -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # *Segment*, *Select*, *Correct*: A Framework for Weakly-Supervised Referring Image Segmentation 2 | 3 | ![Three-Stage Framework](diagram.png) 4 | 5 | This is the official repo for the paper [*Segment*, *Select*, *Correct*: A Framework for Weakly-Supervised Referring Image Segmentation](https://arxiv.org/abs/2310.13479) or S+S+C for short (`ssc` in code). This repository is licensed under a GNU general public license (find more in LICENSE.txt). 6 | 7 | 8 | ## Getting Started 9 | 10 | To perform inference or run Stages 1, 2 or 3, you must first perform the basic setup. Start by cloning this GitHub repo: 11 | ``` 12 | git clone https://github.com/fgirbal/segment-select-correct.git 13 | cd segment-select-correct 14 | ``` 15 | 16 | You can now install the package and its requirements. Given the outside package dependencies, it is highly recommended that you do this in a separate virtual/conda environment. 
To create and activate a new Conda environment that supports this, execute: 17 | ``` 18 | conda create -n ssc_ris python=3.7 -y 19 | conda activate ssc_ris 20 | ``` 21 | 22 | The main package installation should be done with (omit `-e` if you do not want to edit the package, though that might cause some file system issues when following the steps below): 23 | ``` 24 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html 25 | pip install -r requirements.txt 26 | pip install -e . 27 | ``` 28 | 29 | Follow the appropriate instructions below to run inference or run Stages 1, 2, or 3. 30 | 31 | ## Inference Demo on Example Images 32 | 33 | To test our pre-trained models (`S+S+C` in the paper) by running inference on images, you can download their weights [here](https://drive.google.com/drive/folders/1FHbcVz-HfseheGcRo0SiD3Bq0FDGMqVz), or, for the individual models: 34 | 35 | | [RefCOCO](https://drive.google.com/file/d/1ZDXktKvNdai-2IPqhzeu5gC9Y1oKilsW/view) | [RefCOCO+](https://drive.google.com/file/d/1vQX2q11C2YK3OQMYFYiCIzzsrUYE5744/view) | [RefCOCOg (U)](https://drive.google.com/file/d/1YB_dDMhpTL541BobN-jKd1ie6enKNHY_/view) | [RefCOCOg (G)](https://drive.google.com/file/d/1ASvev63wODZLQ4xxg7v6z4lebW2opV_G/view) | 36 | |---|---|---|---| 37 | 38 | To perform inference on example images from the `examples` folder (or others on your machine), simply run the following from the main directory, for example: 39 | ``` 40 | python examples/inference_demo.py --model-checkpoint [PATH_TO_MODEL_CHECKPOINT] --sentence "man in blue" --input-image examples/image_1.jpg 41 | ``` 42 | which, given the RefCOCO corrected model from above, should generate the following output: 43 | ![Three-Stage Framework](examples/output_image_1_refcoco.png) 44 | 45 | ## Running Stages 1, 2 and 3 46 | 47 | To use this package with RefCOCO, RefCOCO+ or RefCOCOg, you must: 48 | - Follow the instructions from [the LAVT repository](https://github.com/yz93/LAVT-RIS/tree/main/refer) to set up subdirectories and download annotations. The API itself is inside the `ssc` package (`ssc_ris.refer_dataset`), so you can set up the data there or anywhere else on your machine. 49 | - [Download images from COCO](https://cocodataset.org/#download) (the link entitled 2014 Train images [83K/13GB]) and extract them from the ZIP to a folder `[REFER_DATASET_ROOT]/images/mscoco/images`, where `[REFER_DATASET_ROOT]` points to the `data` folder of REFER as per the instructions in the previous point. 50 | 51 | If you want to run the `Segment` step (Stage 1) of our framework to generate all the instance segmentation masks, you must also download the spaCy dictionary, install GroundingDINO and download the relevant weights. Follow steps 1., 2. and 3. to do that. 52 | 53 | If you want to train models from the `Correct` step (Stage 3), then you must download the Swin transformer pre-trained weights. Follow step 4. to do that. 54 | 55 | ### 1. (Stage 1) Download spaCy dictionary 56 | ``` 57 | python -m spacy download en_core_web_md 58 | ``` 59 | 60 | ### 2. (Stage 1) Installing GroundingDINO and downloading the weights: 61 | 62 | Clone and install GroundingDINO by running: 63 | ``` 64 | git clone https://github.com/IDEA-Research/GroundingDINO.git 65 | cd GroundingDINO 66 | 67 | # for reproducibility of our results, check out this commit 68 | git checkout 57535c5a79791cb76e36fdb64975271354f10251 69 | pip install -e .
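# optional sanity check (a suggestion, not part of the original instructions): the Stage 1
# script imports the package as `groundingdino` and resolves its config and weights paths
# via groundingdino.__file__, so this import should succeed after installation
python -c "import groundingdino; print(groundingdino.__file__)"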
70 | ``` 71 | 72 | Download the relevant weights and return to the `scripts` folder: 73 | ``` 74 | mkdir weights 75 | cd weights 76 | wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth 77 | cd ../.. 78 | ``` 79 | 80 | ### 3. (Stage 1) Downloading the SAM weights: 81 | 82 | Making sure you are within the `scripts` folder, execute: 83 | ``` 84 | mkdir sam_weights 85 | cd sam_weights 86 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 87 | ``` 88 | 89 | ### 4. (Stage 3) Download Swin transformer weights 90 | Inside the `scripts` folder, execute: 91 | ``` 92 | mkdir train/swin_pretrained_weights 93 | cd train/swin_pretrained_weights 94 | wget https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth 95 | ``` 96 | 97 | ## 2. *Segment*: Generating all the masks 98 | 99 | To run the first stage of our framework for the RefCOCO training set, move to the `scripts` folder and execute: 100 | ``` 101 | python create_all_masks_dataset.py -n example -d refcoco --dataset-root [REFER_DATASET_ROOT] --most-likely-noun --project --context-projections -f 102 | ``` 103 | where `[REFER_DATASET_ROOT]` is the pointer to the `data` folder of REFER. This will create an `example` folder inside `segment_stage_masks` which can be used in the next step. 104 | 105 | For help with this script, execute `python create_all_masks_dataset.py --help`. 106 | 107 | ## 3. *Select*: Zero-shot instance choice 108 | 109 | To run the second stage of our framework for the RefCOCO training set on the `example` masks generated in 2.2., still inside the `scripts` folder, execute: 110 | ``` 111 | python create_zero_shot_masks_dataset.py --original-name example --new-name example_selected --dataset-root [REFER_DATASET_ROOT] --mask-choice reverse_blur -f 112 | ``` 113 | This will create an `example_selected` folder inside `select_stage_masks` which can be used for training. 114 | 115 | For help with this script, execute `python create_zero_shot_masks_dataset.py --help`. 116 | 117 | ## 4. Testing unsupervised masks 118 | 119 | To test the quality of the masks generated in 2.2. (or 2.3.), run the following script: 120 | ``` 121 | python test_unsupervised_masks_dataset.py -n example -s segment --dataset-root [REFER_DATASET_ROOT] --mask-choice reverse_blur -f 122 | ``` 123 | This will test all of the masks generated in 2.2. using the reverse blur zero-shot choice criterion. Other options for `mask-choice` include `random` (to choose a random mask) or `best` (to pick the one that maximizes mean intersection over union). Testing the select stage masks can be done by replacing the name with the select stage experiment name (e.g., `example_selected`) and passing `-s select` to identify the stage. Note that `mask-choice` won't influence the outcome in that case, since each example only has one mask (the one previously chosen by the select stage mechanism). 124 | 125 | For help with this script, execute `python test_unsupervised_masks_dataset.py --help`. 126 | 127 | ## 5.
Pre-training or constrained greedy matching 128 | 129 | To pre-train a model using the zero-shot selected masks from `example_selected` in 2.2., run the following script from inside the `scripts/train` folder: 130 | ``` 131 | python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train_model.py \ 132 | --dataset refcoco --model_id refcoco \ 133 | --batch-size 7 --batch-size-unfolded-limit 15 --lr 0.00005 --wd 1e-2 --epochs 40 --img_size 480 \ 134 | --swin_type base --pretrained_swin_weights swin_pretrained_weights/swin_base_patch4_window12_384_22k.pth \ 135 | --refer_data_root [REFER_DATASET_ROOT] --pseudo_masks_root ../select_stage_masks/example_selected/instance_masks/ \ 136 | --model-experiment-name test_experiment --one-sentence --loss-mode random_ce 137 | ``` 138 | Note that the total batch size in this case will be 60. 139 | 140 | Once this model is trained (the output will be inside the `models` folder and in this case it will be called `test_example_best_refcoco.pth`), we can do 40 epochs of constrained greedy matching on the `segment_stage_masks` of `example` by running: 141 | ``` 142 | python -m torch.distributed.launch --nproc_per_node 4 --master_port 12345 train_model.py \ 143 | --dataset refcoco --model_id refcoco \ 144 | --batch-size 7 --batch-size-unfolded-limit 15 --lr 0.00005 --wd 1e-2 --epochs 40 --img_size 480 \ 145 | --swin_type base --pretrained_swin_weights swin_pretrained_weights/swin_base_patch4_window12_384_22k.pth \ 146 | --refer_data_root [REFER_DATASET_ROOT] --pseudo_masks_root ../segment_stage_masks/example/instance_masks/ \ 147 | --model-experiment-name test_experiment_greedy --one-sentence --loss-mode greedy_ce \ 148 | --init-from models/refcoco/test_example_best_refcoco.pth 149 | ``` 150 | The only changes between this command and the previous one are that this one includes an `--init-from` option to initialize the model, the `--pseudo_masks_root` now points to the segment stage masks instead of the select stage ones, and `--loss-mode` is now `greedy_ce` instead of `random_ce` (which effectively uses the single mask chosen by the zero-shot selection). 151 | 152 | For more help with this script, execute `python train_model.py --help`. 153 | 154 | ## 6. Testing a trained model 155 | 156 | To test a trained model on the validation split of RefCOCO, run the following script from inside the `scripts/train` folder: 157 | ``` 158 | python test_model.py --dataset refcoco --split val --model_id refcoco \ 159 | --workers 4 --swin_type base --img_size 480 \ 160 | --refer_data_root [REFER_DATASET_ROOT] --ddp_trained_weights \ 161 | --window12 --resume models/refcoco/test_example_best_refcoco.pth 162 | ``` 163 | for the pre-trained model from 2.5., or change `--resume` to `models/refcoco/test_example_greedy_best_refcoco.pth` to test the one trained with constrained greedy matching. 164 | 165 | For more help with this script, execute `python test_model.py --help`. 166 | 167 | ## Citation and Acknowledgements 168 | 169 | If you use `ssc` in your work, please cite the following: 170 | ``` 171 | @misc{eiras2023segment, 172 | title={Segment, Select, Correct: A Framework for Weakly-Supervised Referring Segmentation}, 173 | author={Francisco Eiras and Kemal Oksuz and Adel Bibi and Philip H. S. Torr and Puneet K.
Dokania}, 174 | year={2023}, 175 | eprint={2310.13479}, 176 | archivePrefix={arXiv}, 177 | primaryClass={cs.CV} 178 | } 179 | ``` 180 | 181 | This work was supported by the EPSRC Centre for Doctoral Training in Autonomous Intelligent Machines and Systems [EP/S024050/1], by Five AI Limited, by the UKRI grant: Turing AI Fellowship EP/W002981/1, and by the Royal Academy of Engineering under the Research Chair and Senior Research Fellowships scheme. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/__init__.py -------------------------------------------------------------------------------- /diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/diagram.png -------------------------------------------------------------------------------- /examples/image_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/examples/image_1.jpg -------------------------------------------------------------------------------- /examples/image_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/examples/image_2.jpg -------------------------------------------------------------------------------- /examples/image_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/examples/image_3.jpg -------------------------------------------------------------------------------- /examples/inference_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from PIL import Image 3 | 4 | import torch 5 | import matplotlib.pyplot as plt 6 | 7 | from ssc_ris.refer_dataset.bert.modeling_bert import BertModel 8 | from ssc_ris.refer_dataset.bert.tokenization_bert import BertTokenizer 9 | from ssc_ris.refer_dataset.utils import get_transform 10 | from ssc_ris.utils.lavt_lib import segmentation 11 | 12 | 13 | def get_parser(): 14 | parser = argparse.ArgumentParser(description="LAVT training and testing") 15 | parser.add_argument( 16 | "--model-checkpoint", required=True, help="model from checkpoint" 17 | ) 18 | parser.add_argument("--input-image", required=True, help="input image") 19 | parser.add_argument("--sentence", required=True, help="RIS sentence") 20 | parser.add_argument("--device", default="cuda:0", help="device") 21 | 22 | # model parameters 23 | parser.add_argument( 24 | "--fusion_drop", default=0.0, type=float, help="dropout rate for PWAMs" 25 | ) 26 | parser.add_argument( 27 | "--mha", 28 | default="", 29 | help="If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4," 30 | "where a, b, c, and d refer to the numbers of heads in stage-1," 31 | "stage-2, stage-3, and stage-4 PWAMs", 32 | ) 33 | parser.add_argument( 34 | "--swin_type", 35 | default="base", 36 | help="tiny, small, base, or large variants of the Swin Transformer", 37 | ) 38 | parser.add_argument( 39 | "--window12", 40 | 
action="store_true", 41 | help="only needs specified when testing," 42 | "when training, window size is inferred from pre-trained weights file name" 43 | "(containing 'window12'). Initialize Swin with window size 12 instead of the default 7.", 44 | ) 45 | 46 | return parser 47 | 48 | 49 | def plot_side_by_side(img, mask, sentence): 50 | _, axs = plt.subplots(1, 2, figsize=(8, 4)) 51 | 52 | axs[0].imshow(img) 53 | axs[0].set_axis_off() 54 | 55 | axs[1].imshow(img) 56 | axs[1].imshow(mask, alpha=0.6) 57 | axs[1].set_title(sentence) 58 | axs[1].set_axis_off() 59 | 60 | plt.tight_layout() 61 | 62 | 63 | if __name__ == "__main__": 64 | parser = get_parser() 65 | args = parser.parse_args() 66 | 67 | model_checkpoint = args.model_checkpoint 68 | sentence_raw = args.sentence 69 | img_path = args.input_image 70 | device = args.device 71 | 72 | img_size = 480 73 | bert_model_name = "bert-base-uncased" 74 | args.window12 = True 75 | 76 | # load the model 77 | single_model = segmentation.__dict__["lavt"](pretrained="", args=args) 78 | checkpoint = torch.load(model_checkpoint, map_location="cpu") 79 | single_model.load_state_dict(checkpoint["model"]) 80 | model = single_model.to(device) 81 | 82 | model_class = BertModel 83 | single_bert_model = model_class.from_pretrained(bert_model_name) 84 | single_bert_model.pooler = None 85 | 86 | single_bert_model.load_state_dict(checkpoint["bert_model"]) 87 | bert_model = single_bert_model.to(device) 88 | 89 | tokenizer = BertTokenizer.from_pretrained(bert_model_name) 90 | attention_mask = [0] * 20 91 | padded_input_ids = [0] * 20 92 | input_ids = tokenizer.encode(text=sentence_raw, add_special_tokens=True) 93 | 94 | # truncation of tokens 95 | input_ids = input_ids[:20] 96 | padded_input_ids[: len(input_ids)] = input_ids 97 | attention_mask[: len(input_ids)] = [1] * len(input_ids) 98 | 99 | sentence = torch.tensor(padded_input_ids).unsqueeze(0).to(device) 100 | attention = torch.tensor(attention_mask).unsqueeze(0).to(device) 101 | 102 | # load image and perform the input transformations 103 | orig_img = Image.open(img_path).convert("RGB") 104 | transform = get_transform(img_size=img_size) 105 | img, _ = transform(orig_img, orig_img) 106 | img = img.unsqueeze(0).to(device) 107 | 108 | # inference 109 | last_hidden_states = bert_model(sentence, attention_mask=attention)[0] 110 | embedding = last_hidden_states.permute(0, 2, 1) 111 | output = model(img, embedding, l_mask=attention.unsqueeze(-1)) 112 | output_mask = output.cpu().argmax(1).data.numpy() 113 | 114 | # plot and save output 115 | plot_side_by_side( 116 | orig_img.resize((img_size, img_size)), output_mask[0], sentence_raw 117 | ) 118 | plt.savefig("example_output.png") 119 | -------------------------------------------------------------------------------- /examples/output_image_1_refcoco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/examples/output_image_1_refcoco.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | filelock 2 | pycocotools 3 | timm 4 | nltk 5 | spacy 6 | numpy 7 | opencv-python 8 | matplotlib 9 | scipy 10 | scikit-image 11 | wandb 12 | tokenizers==0.8.1rc1 13 | mmcv-full==1.3.12 14 | mmsegmentation==0.17.0 15 | ftfy 16 | regex 17 | git+https://github.com/facebookresearch/segment-anything.git 18 | 
git+https://github.com/openai/CLIP.git 19 | -------------------------------------------------------------------------------- /scripts/create_all_masks_dataset.py: -------------------------------------------------------------------------------- 1 | # Segment: Stage 1 of the three-stage framework presented 2 | import os 3 | import json 4 | import inspect 5 | import pickle 6 | import argparse 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import torch 11 | import clip 12 | import groundingdino 13 | from groundingdino.util.inference import Model as GroundingDino 14 | from segment_anything import sam_model_registry, SamPredictor 15 | 16 | from ssc_ris.refer_dataset.utils import get_dataset, save_binary_object_masks 17 | from ssc_ris.segment import get_nouns_and_noun_phrases_nltk, get_nouns_and_noun_phrases_spacy, segment_from_image_and_nouns, project_to_dataset_classes 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('-n', '--name', type=str, required=True) 24 | parser.add_argument('-f', '--full', action='store_true') 25 | 26 | # dataset arguments 27 | parser.add_argument('-d', '--dataset', type=str, default="refcoco") 28 | parser.add_argument('--dataset-root', type=str, default="refer/data") 29 | parser.add_argument('--split', type=str, default="train") 30 | parser.add_argument('--split-by', type=str, default="unc", choices=["unc", "umd", "google"]) 31 | parser.add_argument('-s', '--start-index', type=int, default=0) 32 | 33 | # text projection and segmentation 34 | parser.add_argument('--bg-threshold', type=float, default=0.95) 35 | parser.add_argument('--project', action='store_true') 36 | parser.add_argument('--keep-top-k-matches', type=int, default=1) 37 | parser.add_argument('--context-projections', action='store_true') 38 | parser.add_argument('--noun-extraction', type=str, choices=["nltk", "spacy"], default="spacy") 39 | parser.add_argument( 40 | '--one-query-per-noun', 41 | action='store_true', 42 | help="if true, will call the segmentation model once for each noun in the sentences; leads to more false positives" 43 | ) 44 | parser.add_argument( 45 | '--most-likely-noun', 46 | action='store_true', 47 | help="query only using the most likely noun in the set of sentences" 48 | ) 49 | 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | if __name__ == "__main__": 55 | args = parse_args() 56 | 57 | experiment_name = args.name 58 | dataset = args.dataset 59 | dataset_root = args.dataset_root 60 | data_split = args.split 61 | img_size = 480 62 | bg_threshold = args.bg_threshold 63 | keep_top_k_matches = args.keep_top_k_matches 64 | project_to_COCO_classes = args.project 65 | with_context_projections = args.context_projections 66 | noun_extraction = args.noun_extraction 67 | test_on_small_subset = not args.full 68 | 69 | context_string = "a photo of a " 70 | COCO_obj_list = np.array([ 71 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 72 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 73 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 74 | 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 75 | 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 76 | 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 77 | 'cake', 'chair', 
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 78 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 79 | 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 80 | 'toothbrush' 81 | ]) 82 | 83 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 84 | if project_to_COCO_classes: 85 | clip_model, preprocess = clip.load("ViT-B/32", device=device) 86 | 87 | if not with_context_projections: 88 | if os.path.isfile("CLIP_COCO_obj_list_features.pb"): 89 | with open("CLIP_COCO_obj_list_features.pb", "rb") as f: 90 | CLIP_COCO_obj_list_features = pickle.load(f).to(device) 91 | else: 92 | with torch.no_grad(): 93 | CLIP_COCO_obj_list_features = torch.vstack([ 94 | clip_model.encode_text(clip.tokenize(obj_name).to(device)) for obj_name in COCO_obj_list 95 | ]) 96 | 97 | CLIP_COCO_obj_list_features = CLIP_COCO_obj_list_features.to(torch.float32) 98 | CLIP_COCO_obj_list_features /= CLIP_COCO_obj_list_features.norm(dim=-1, keepdim=True) 99 | 100 | with open("CLIP_COCO_obj_list_features.pb", "wb") as f: 101 | pickle.dump(CLIP_COCO_obj_list_features, f) 102 | else: 103 | if os.path.isfile("CLIP_COCO_obj_list_features_w_context.pb"): 104 | with open("CLIP_COCO_obj_list_features_w_context.pb", "rb") as f: 105 | CLIP_COCO_obj_list_features = pickle.load(f).to(device) 106 | else: 107 | with torch.no_grad(): 108 | CLIP_COCO_obj_list_features = torch.vstack([ 109 | clip_model.encode_text(clip.tokenize(context_string + obj_name).to(device)) for obj_name in COCO_obj_list 110 | ]) 111 | 112 | CLIP_COCO_obj_list_features = CLIP_COCO_obj_list_features.to(torch.float32) 113 | CLIP_COCO_obj_list_features /= CLIP_COCO_obj_list_features.norm(dim=-1, keepdim=True) 114 | 115 | with open("CLIP_COCO_obj_list_features_w_context.pb", "wb") as f: 116 | pickle.dump(CLIP_COCO_obj_list_features, f) 117 | 118 | print("loading GroundingDino...") 119 | grounding_dino_root = os.path.dirname(groundingdino.__file__) 120 | grounding_dino_model = GroundingDino( 121 | model_config_path=os.path.join(grounding_dino_root, "config/GroundingDINO_SwinT_OGC.py"), 122 | model_checkpoint_path=os.path.join(grounding_dino_root, "../weights/groundingdino_swint_ogc.pth") 123 | ) 124 | print("loaded GroundingDino") 125 | 126 | print("loading SAM...") 127 | sam = sam_model_registry["vit_h"]( 128 | checkpoint=os.path.join(grounding_dino_root, "../../sam_weights/sam_vit_h_4b8939.pth") 129 | ).to(device=device) 130 | sam_predictor = SamPredictor(sam) 131 | print("loaded SAM") 132 | 133 | experiment_folder = f"segment_stage_masks/{experiment_name}/" 134 | if data_split != "train": 135 | experiment_folder = f"segment_stage_masks_{data_split}/{experiment_name}/" 136 | 137 | output_mask_folder = experiment_folder 138 | if dataset != "refcocog": 139 | output_mask_folder += f"instance_masks/{dataset}" 140 | else: 141 | output_mask_folder += f"instance_masks/{dataset}_{args.split_by}" 142 | 143 | if not os.path.exists(output_mask_folder): 144 | os.makedirs(output_mask_folder) 145 | 146 | masks_viz_save_path = f"segment_stage_masks/{experiment_name}/masks_viz/{dataset}" 147 | if not os.path.exists(masks_viz_save_path): 148 | os.makedirs(masks_viz_save_path) 149 | 150 | old_dataset, num_classes = get_dataset( 151 | dataset, 152 | dataset_root, 153 | data_split, 154 | split_by=args.split_by 155 | ) 156 | coco_images_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), old_dataset.refer.IMAGE_DIR) 157 | 158 | # write the configuration file to be able to 
trace which parameters were used in generating this data 159 | with open(experiment_folder + f"generation_details_{dataset}_{args.split_by}.log", 'w') as f: 160 | json.dump(vars(args), f, indent=4) 161 | 162 | # testing on a small subset to determine whether the generated masks are good enough or not 163 | if test_on_small_subset: 164 | print("Testing on 1000 examples...") 165 | n_examples = 1000 166 | else: 167 | print("Creating the full dataset...") 168 | n_examples = len(old_dataset) 169 | 170 | if args.start_index != 0: 171 | print(f"Starting mask generation from index {args.start_index} of the dataset...") 172 | 173 | current_image = "" 174 | image_sentences = [] 175 | image_nouns = [] 176 | ref_ids = [] 177 | save_path = None 178 | for i in tqdm(range(n_examples)): 179 | if i < args.start_index: 180 | continue 181 | 182 | dataset_object = old_dataset[i] 183 | 184 | if current_image == "": 185 | # this is the first image ever, just write it down and move on 186 | current_image = dataset_object[-1]['file_name'] 187 | elif dataset_object[-1]['file_name'] != current_image: 188 | # it's a new image; process the info in the noun buffer and clear it to prepare for new ones 189 | if test_on_small_subset: 190 | save_path = os.path.join(masks_viz_save_path, f"mask_{i}.png") 191 | 192 | object_masks = segment_from_image_and_nouns( 193 | grounding_dino_model, 194 | sam_predictor, 195 | current_image, 196 | coco_images_directory, 197 | image_nouns, 198 | save_path, 199 | image_sentences, 200 | one_query_per_noun=args.one_query_per_noun, 201 | most_likely_noun=args.most_likely_noun 202 | ) 203 | save_binary_object_masks( 204 | object_masks, 205 | ref_ids, 206 | output_mask_folder 207 | ) 208 | 209 | current_image = dataset_object[-1]['file_name'] 210 | image_sentences = [] 211 | image_nouns = [] 212 | ref_ids = [] 213 | 214 | obj_nouns = [] 215 | for sentence in dataset_object[-1]['sentences_sent']: 216 | if noun_extraction == "nltk": 217 | nouns = get_nouns_and_noun_phrases_nltk(sentence) 218 | if len(nouns) == 0: 219 | nouns = get_nouns_and_noun_phrases_spacy(sentence) 220 | elif noun_extraction == "spacy": 221 | nouns = get_nouns_and_noun_phrases_spacy(sentence) 222 | if len(nouns) == 0: 223 | nouns = get_nouns_and_noun_phrases_nltk(sentence) 224 | 225 | if project_to_COCO_classes: 226 | nouns = project_to_dataset_classes( 227 | clip_model, 228 | nouns, 229 | dataset_object_list=COCO_obj_list, 230 | dataset_object_list_features=CLIP_COCO_obj_list_features, 231 | with_context=with_context_projections, 232 | top_k=keep_top_k_matches 233 | ) 234 | 235 | obj_nouns.append(nouns) 236 | 237 | image_sentences.append(dataset_object[-1]['sentences_sent']) 238 | image_nouns.append(obj_nouns) 239 | ref_ids.append(dataset_object[-1]['ref_id']) 240 | 241 | object_masks = segment_from_image_and_nouns( 242 | grounding_dino_model, 243 | sam_predictor, 244 | current_image, 245 | coco_images_directory, 246 | image_nouns, 247 | save_path=None, 248 | sentences=image_sentences, 249 | one_query_per_noun=args.one_query_per_noun, 250 | most_likely_noun=args.most_likely_noun 251 | ) 252 | save_binary_object_masks( 253 | object_masks, 254 | ref_ids, 255 | output_mask_folder 256 | ) 257 | -------------------------------------------------------------------------------- /scripts/create_zero_shot_masks_dataset.py: -------------------------------------------------------------------------------- 1 | # Select: Stage 2 of the three-stage framework presented 2 | import os 3 | import argparse 4 | from PIL import Image 5 | import 
random 6 | 7 | from tqdm import tqdm 8 | import numpy as np 9 | import clip 10 | 11 | from ssc_ris.refer_dataset.utils import get_dataset, save_binary_object_mask 12 | from ssc_ris.select.utils import separate_instance_masks 13 | from ssc_ris.select import vp_and_max_sim_clip_choice 14 | 15 | def best_mask_choice(masks, gt_mask): 16 | print("------------- WARNING: best mask being chosen w.r.t. GT, should only be used for testing -------------") 17 | 18 | best_mask_iou = -np.Inf 19 | best_mask = None 20 | for pseudo_mask_ in masks: 21 | intersect = (gt_mask & pseudo_mask_) 22 | intersect_area = intersect.sum().astype(np.float32) 23 | union = (gt_mask | pseudo_mask_) 24 | 25 | iou = intersect_area / union.sum() 26 | if iou > best_mask_iou: 27 | best_mask_iou = iou 28 | best_mask = pseudo_mask_ 29 | 30 | return best_mask 31 | 32 | 33 | def random_mask_choice(image, masks): 34 | return random.choice(masks) 35 | 36 | 37 | def first_mask_choice(image, masks): 38 | return masks[0] 39 | 40 | 41 | def parse_args(): 42 | parser = argparse.ArgumentParser() 43 | 44 | parser.add_argument('--original-name', type=str, required=True) 45 | parser.add_argument('--new-name', type=str, required=True) 46 | parser.add_argument('-f', '--full', action='store_true') 47 | 48 | parser.add_argument('-d', '--dataset', type=str, default="refcoco") 49 | parser.add_argument('--dataset-root', type=str, default="refer/data") 50 | parser.add_argument('--split', type=str, default="train") 51 | parser.add_argument('--split-by', type=str, default="unc", choices=["unc", "umd", "google"]) 52 | 53 | parser.add_argument( 54 | '--mask-choice', 55 | type=str, 56 | required=True, 57 | choices=[ 58 | "best", 59 | "random", 60 | "first", 61 | "red_ellipse", 62 | "rectangle", 63 | "red_dense_mask", 64 | "reverse_blur" 65 | ] 66 | ) 67 | 68 | args = parser.parse_args() 69 | return args 70 | 71 | 72 | if __name__ == "__main__": 73 | args = parse_args() 74 | 75 | original_name = args.original_name 76 | new_name = args.new_name 77 | dataset = args.dataset 78 | dataset_root = args.dataset_root 79 | data_split = args.split 80 | test_on_small_subset = not args.full 81 | img_size = 480 82 | 83 | original_pseudo_mask_folder = f"segment_stage_masks/{original_name}/" 84 | if data_split != "train": 85 | original_pseudo_mask_folder = f"segment_stage_masks_{data_split}/{original_name}/" 86 | 87 | new_chosen_mask_folder = f"select_stage_masks/{new_name}/" 88 | if data_split != "train": 89 | new_chosen_mask_folder = f"select_stage_masks_{data_split}/{new_name}/" 90 | 91 | if dataset != "refcocog": 92 | original_pseudo_mask_folder += f"instance_masks/{dataset}" 93 | new_chosen_mask_folder += f"instance_masks/{dataset}" 94 | else: 95 | original_pseudo_mask_folder += f"instance_masks/{dataset}_{args.split_by}" 96 | new_chosen_mask_folder += f"instance_masks/{dataset}_{args.split_by}" 97 | 98 | if not os.path.exists(new_chosen_mask_folder): 99 | os.makedirs(new_chosen_mask_folder) 100 | 101 | old_dataset, num_classes = get_dataset( 102 | dataset, 103 | dataset_root, 104 | data_split, 105 | split_by=args.split_by 106 | ) 107 | coco_images_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), old_dataset.refer.IMAGE_DIR) 108 | 109 | # testing on a small subset to determine whether the generated masks are good enough or not 110 | if test_on_small_subset: 111 | print("Testing on 1000 examples...") 112 | n_examples = 1000 113 | else: 114 | print("Creating the full dataset...") 115 | n_examples = len(old_dataset) 116 | 117 | if args.mask_choice 
not in ["best", "random", "first"]: 118 | print("loading CLIP...") 119 | device = "cuda:0" 120 | clip_model, preprocess = clip.load("ViT-L/14@336px", device=device) 121 | print("loaded") 122 | 123 | for i in tqdm(range(n_examples)): 124 | dataset_object = old_dataset[i] 125 | 126 | ref_id = dataset_object[-1]['ref_id'] 127 | gt_mask = np.array(dataset_object[1]) 128 | 129 | pseudo_mask_img_name = os.path.join(original_pseudo_mask_folder, f"pseudo_gt_mask_{ref_id}.png") 130 | instance_merged_pseudo_mask = np.array(Image.open(pseudo_mask_img_name).convert('L')) 131 | 132 | all_instance_masks = separate_instance_masks(instance_merged_pseudo_mask) 133 | 134 | if len(all_instance_masks) == 0: 135 | chosen_mask = instance_merged_pseudo_mask 136 | else: 137 | if args.mask_choice == "best": 138 | chosen_mask = best_mask_choice(all_instance_masks, gt_mask) 139 | elif args.mask_choice == "random": 140 | chosen_mask = random_mask_choice(dataset_object[0], all_instance_masks) 141 | elif args.mask_choice == "first": 142 | chosen_mask = first_mask_choice(dataset_object[0], all_instance_masks) 143 | else: 144 | chosen_mask = vp_and_max_sim_clip_choice( 145 | clip_model, 146 | preprocess, 147 | dataset_object[0], 148 | all_instance_masks, 149 | dataset_object[-1]['sentences_sent'], 150 | method=args.mask_choice 151 | ) 152 | 153 | # couldn't decide on a mask, return an empty one to ignore this example in training 154 | if chosen_mask is None: 155 | chosen_mask = np.zeros_like(instance_merged_pseudo_mask) 156 | 157 | save_binary_object_mask( 158 | chosen_mask, 159 | ref_id, 160 | new_chosen_mask_folder 161 | ) 162 | -------------------------------------------------------------------------------- /scripts/test_unsupervised_masks_dataset.py: -------------------------------------------------------------------------------- 1 | # Testing script for Stages 1 and 2 2 | import os 3 | import json 4 | import argparse 5 | import pickle 6 | from PIL import Image 7 | import random 8 | 9 | from scipy.ndimage import gaussian_filter 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | import matplotlib.cm as cm 13 | from matplotlib.colors import ListedColormap 14 | import numpy as np 15 | import torch 16 | import nltk 17 | import clip 18 | import cv2 19 | 20 | from ssc_ris.refer_dataset.utils import get_dataset 21 | from ssc_ris.select.utils import separate_instance_masks 22 | from ssc_ris.select import vp_and_max_sim_clip_choice 23 | 24 | 25 | colormap = np.array([ 26 | [ 0, 0, 0, 0], 27 | [ 245, 233, 66, 128] 28 | ], dtype=np.float32) 29 | colormap /= 255.0 30 | seg_colormap = ListedColormap(colormap) 31 | 32 | def plot_img_pseudo_gt(image, pseudo_mask, gt_mask, output_filename): 33 | fig, axs = plt.subplots(1, 2, figsize=(8, 4)) 34 | 35 | axs[0].imshow(image) 36 | axs[0].imshow(gt_mask, cmap=seg_colormap) 37 | axs[0].set_title('GT') 38 | axs[0].axis('off') 39 | 40 | axs[1].imshow(image) 41 | axs[1].imshow(pseudo_mask, cmap=seg_colormap) 42 | axs[1].set_title('Pseudo') 43 | axs[1].axis('off') 44 | 45 | plt.tight_layout() 46 | plt.savefig(output_filename) 47 | 48 | fig.clf() 49 | plt.close() 50 | 51 | 52 | def best_mask_choice(masks, gt_mask): 53 | print("------------- WARNING: best mask being chosen w.r.t. 
GT, should only be used for testing -------------") 54 | 55 | best_mask_iou = -np.Inf 56 | best_mask = None 57 | for pseudo_mask_ in masks: 58 | intersect = (gt_mask & pseudo_mask_) 59 | intersect_area = intersect.sum().astype(np.float32) 60 | union = (gt_mask | pseudo_mask_) 61 | 62 | iou = intersect_area / union.sum() 63 | if iou > best_mask_iou: 64 | best_mask_iou = iou 65 | best_mask = pseudo_mask_ 66 | 67 | return best_mask 68 | 69 | 70 | def random_mask_choice(image, masks): 71 | return random.choice(masks) 72 | 73 | 74 | def parse_args(): 75 | parser = argparse.ArgumentParser() 76 | 77 | parser.add_argument('-n', '--name', type=str, required=True) 78 | parser.add_argument('-s', '--stage', type=str, required=True, choices=["segment", "select"]) 79 | parser.add_argument('-d', '--dataset', type=str, default="refcoco") 80 | parser.add_argument('--dataset-root', type=str, default="refer/data") 81 | parser.add_argument('--split', type=str, default="train") 82 | parser.add_argument('--split-by', type=str, default="unc", choices=["unc", "umd", "google"]) 83 | 84 | parser.add_argument( 85 | '--mask-choice', 86 | type=str, 87 | default="best", 88 | choices=[ 89 | "best", 90 | "random", 91 | "red_ellipse", 92 | "red_ellipse_vit_32", 93 | "rectangle", 94 | "red_dense_mask", 95 | "reverse_blur", 96 | "reverse_blur_vit_32" 97 | ] 98 | ) 99 | 100 | args = parser.parse_args() 101 | return args 102 | 103 | 104 | if __name__ == "__main__": 105 | args = parse_args() 106 | 107 | experiment_name = args.name 108 | dataset = args.dataset 109 | dataset_root = args.dataset_root 110 | data_split = args.split 111 | img_size = 480 112 | 113 | if args.stage == "segment": 114 | trail = "segment_stage_masks" 115 | else: 116 | trail = "select_stage_masks" 117 | 118 | output_path = f"{trail}/{experiment_name}" 119 | if data_split != "train": 120 | output_path = f"{trail}_{data_split}/{experiment_name}" 121 | 122 | pseudo_mask_folder = f"{output_path}/" 123 | if dataset != "refcocog": 124 | pseudo_mask_folder += f"instance_masks/{dataset}" 125 | else: 126 | pseudo_mask_folder += f"instance_masks/{dataset}_{args.split_by}" 127 | 128 | old_dataset, num_classes = get_dataset( 129 | dataset, 130 | dataset_root, 131 | data_split, 132 | split_by=args.split_by 133 | ) 134 | coco_images_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), old_dataset.refer.IMAGE_DIR) 135 | 136 | n_masks_per_image = [] 137 | empty_masks = 0 138 | intersection = [] 139 | iou = [] 140 | cum_I, cum_U = 0, 0 141 | 142 | if data_split == "train": 143 | range_end = 1000 144 | else: 145 | range_end = len(old_dataset) 146 | 147 | if args.mask_choice not in ["best", "random"]: 148 | print("loading CLIP...") 149 | device = "cuda:0" 150 | if args.mask_choice == "reverse_blur_vit_32" or args.mask_choice == "red_ellipse_vit_32": 151 | clip_model, preprocess = clip.load("ViT-B/32", device=device) 152 | else: 153 | clip_model, preprocess = clip.load("ViT-L/14@336px", device=device) 154 | 155 | print("loaded") 156 | 157 | for i in tqdm(range(range_end)): 158 | dataset_object = old_dataset[i] 159 | 160 | ref_id = dataset_object[-1]['ref_id'] 161 | gt_mask = np.array(dataset_object[1]) 162 | 163 | pseudo_mask_img_name = os.path.join(pseudo_mask_folder, f"pseudo_gt_mask_{ref_id}.png") 164 | instance_merged_pseudo_mask = np.array(Image.open(pseudo_mask_img_name).convert('L')) 165 | 166 | all_instance_masks = separate_instance_masks(instance_merged_pseudo_mask) 167 | n_masks_per_image.append(len(all_instance_masks)) 168 | 169 | if 
len(all_instance_masks) == 0: 170 | intersection.append(0.0) 171 | iou.append(0.0) 172 | else: 173 | if args.mask_choice == "best": 174 | chosen_mask = best_mask_choice(all_instance_masks, gt_mask) 175 | elif args.mask_choice == "random": 176 | chosen_mask = random_mask_choice(dataset_object[0], all_instance_masks) 177 | else: 178 | chosen_mask = vp_and_max_sim_clip_choice( 179 | clip_model, 180 | preprocess, 181 | dataset_object[0], 182 | all_instance_masks, 183 | dataset_object[-1]['sentences_sent'], 184 | method=args.mask_choice 185 | ) 186 | 187 | # couldn't decide on a mask, return an empty one 188 | if chosen_mask is None: 189 | chosen_mask = np.zeros_like(instance_merged_pseudo_mask) 190 | 191 | intersect = (gt_mask & chosen_mask) 192 | intersect_area = intersect.sum().astype(np.float32) 193 | union = (gt_mask | chosen_mask) 194 | union_area = union.sum() 195 | 196 | intersection.append(intersect_area / gt_mask.sum()) 197 | iou.append(intersect_area / union_area) 198 | cum_I += intersect_area 199 | cum_U += union_area 200 | 201 | if instance_merged_pseudo_mask.sum() == 0: 202 | empty_masks += 1 203 | 204 | print('mIoU:', np.array(iou).mean()) 205 | print('oIoU:', cum_I / cum_U) 206 | 207 | n_masks_per_image = np.array(n_masks_per_image) 208 | print('------') 209 | print('mean # masks:', n_masks_per_image.mean()) 210 | print('max # masks:', n_masks_per_image.max()) 211 | 212 | data_stats = { 213 | 'mIoU': np.array(iou).mean(), 214 | 'oIoU': cum_I / cum_U, 215 | 'mean # masks': n_masks_per_image.mean(), 216 | 'max # masks': float(n_masks_per_image.max()) 217 | } 218 | 219 | with open(output_path + f"/test_results_{args.mask_choice}_{dataset}_{args.split_by}.json", 'w') as f: 220 | json.dump(data_stats, f, indent=4) 221 | -------------------------------------------------------------------------------- /scripts/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/scripts/train/__init__.py -------------------------------------------------------------------------------- /scripts/train/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | 8 | from ssc_ris.refer_dataset.bert.modeling_bert import BertModel 9 | from ssc_ris.refer_dataset.utils import get_dataset, get_transform 10 | from ssc_ris.utils.lavt_lib import segmentation 11 | import ssc_ris.utils.lavt_lib.lavt_utils as utils 12 | 13 | from train_test_args import get_parser 14 | 15 | 16 | def evaluate( 17 | model: torch.nn.Module, 18 | data_loader: torch.utils.data.DataLoader, 19 | bert_model: BertModel, 20 | device: torch.device, 21 | model_name: str, 22 | split_name: str 23 | ): 24 | """Evaluate a model on the data loader provided. Prints and writes results to file. 
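Computes the mean IoU, the overall IoU, and precision at IoU thresholds of 0.5, 0.6, 0.7, 0.8 and 0.9 over the given split.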
25 | 26 | Args: 27 | model (torch.nn.Module): model to be evaluated 28 | data_loader (torch.utils.data.DataLoader): val/test dataloader 29 | bert_model (BertModel): Bert model 30 | device (torch.device): device where to perform the computations 31 | model_name (str): model identifier for results 32 | split_name (str): dataset split identifier to file writing 33 | """ 34 | def computeIoU(pred_seg: np.ndarray, gd_seg: np.ndarray) -> Tuple[float, float]: 35 | I = np.sum(np.logical_and(pred_seg, gd_seg)) 36 | U = np.sum(np.logical_or(pred_seg, gd_seg)) 37 | 38 | return I, U 39 | 40 | model.eval() 41 | metric_logger = utils.MetricLogger(delimiter=" ") 42 | 43 | # evaluation variables 44 | cum_I, cum_U = 0, 0 45 | eval_seg_iou_list = [.5, .6, .7, .8, .9] 46 | seg_correct = np.zeros(len(eval_seg_iou_list), dtype=np.int32) 47 | seg_total = 0 48 | mean_IoU = [] 49 | header = 'Test:' 50 | 51 | with torch.no_grad(): 52 | for data in metric_logger.log_every(data_loader, 100, header): 53 | image, target, sentences, attentions = data 54 | image, target, sentences, attentions = image.to(device), target.to(device), \ 55 | sentences.to(device), attentions.to(device) 56 | sentences = sentences.squeeze(1) 57 | attentions = attentions.squeeze(1) 58 | target = target.cpu().data.numpy() 59 | for j in range(sentences.size(-1)): 60 | if bert_model is not None: 61 | last_hidden_states = bert_model(sentences[:, :, j], attention_mask=attentions[:, :, j])[0] 62 | embedding = last_hidden_states.permute(0, 2, 1) 63 | output = model(image, embedding, l_mask=attentions[:, :, j].unsqueeze(-1)) 64 | else: 65 | output = model(image, sentences[:, :, j], l_mask=attentions[:, :, j]) 66 | 67 | output = output.cpu() 68 | output_mask = output.argmax(1).data.numpy() 69 | I, U = computeIoU(output_mask, target) 70 | if U == 0: 71 | this_iou = 0.0 72 | else: 73 | this_iou = I*1.0/U 74 | mean_IoU.append(this_iou) 75 | cum_I += I 76 | cum_U += U 77 | for n_eval_iou in range(len(eval_seg_iou_list)): 78 | eval_seg_iou = eval_seg_iou_list[n_eval_iou] 79 | seg_correct[n_eval_iou] += (this_iou >= eval_seg_iou) 80 | seg_total += 1 81 | 82 | del image, target, sentences, attentions, output, output_mask 83 | if bert_model is not None: 84 | del last_hidden_states, embedding 85 | 86 | mean_IoU = np.array(mean_IoU) 87 | mIoU = np.mean(mean_IoU) 88 | results_str = 'Final results:\nMean IoU is %.2f\n' % (mIoU*100.) 89 | for n_eval_iou in range(len(eval_seg_iou_list)): 90 | results_str += ' precision@%s = %.2f\n' % \ 91 | (str(eval_seg_iou_list[n_eval_iou]), seg_correct[n_eval_iou] * 100. / seg_total) 92 | results_str += ' overall IoU = %.2f\n' % (cum_I * 100. 
/ cum_U) 93 | print(results_str) 94 | 95 | if not os.path.exists("results/"): 96 | os.makedirs("results/") 97 | 98 | with open(f"results/{model_name[:-4]}_{split_name}.txt", 'w') as f: 99 | f.write(results_str) 100 | 101 | 102 | def main(args): 103 | device = torch.device(args.device) 104 | 105 | # load the dataset and prep the dataloader 106 | dataset_test, _ = get_dataset( 107 | dataset=args.dataset, 108 | dataset_root=args.refer_data_root, 109 | data_split=args.split, 110 | transforms=get_transform(img_size=args.img_size), 111 | split_by=args.split_by, 112 | return_attributes=False, 113 | eval_model=True 114 | ) 115 | test_sampler = torch.utils.data.SequentialSampler(dataset_test) 116 | data_loader_test = torch.utils.data.DataLoader( 117 | dataset_test, 118 | batch_size=1, 119 | sampler=test_sampler, 120 | num_workers=args.workers 121 | ) 122 | 123 | # load the model 124 | single_model = segmentation.__dict__["lavt"](pretrained='',args=args) 125 | checkpoint = torch.load(args.resume, map_location='cpu') 126 | single_model.load_state_dict(checkpoint['model']) 127 | model = single_model.to(device) 128 | 129 | model_class = BertModel 130 | single_bert_model = model_class.from_pretrained(args.ck_bert) 131 | # work-around for a transformers bug; need to update to a newer version of transformers to remove these two lines 132 | if args.ddp_trained_weights: 133 | single_bert_model.pooler = None 134 | 135 | single_bert_model.load_state_dict(checkpoint['bert_model']) 136 | bert_model = single_bert_model.to(device) 137 | 138 | evaluate( 139 | model, 140 | data_loader_test, 141 | bert_model, 142 | device=device, 143 | model_name=os.path.basename(args.resume), 144 | split_name=args.split 145 | ) 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = get_parser() 150 | args = parser.parse_args() 151 | 152 | main(args) 153 | -------------------------------------------------------------------------------- /scripts/train/train_test_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_parser(): 5 | parser = argparse.ArgumentParser(description='LAVT training and testing') 6 | parser.add_argument('--amsgrad', action='store_true', 7 | help='if true, set amsgrad to True in an Adam or AdamW optimizer.') 8 | parser.add_argument('-b', '--batch-size', default=3, type=int) 9 | parser.add_argument('--batch-size-unfolded-limit', default=8, type=int) 10 | parser.add_argument('--bert_tokenizer', default='bert-base-uncased', help='BERT tokenizer') 11 | parser.add_argument('--ck_bert', default='bert-base-uncased', help='pre-trained BERT weights') 12 | parser.add_argument('--dataset', default='refcoco', help='refcoco, refcoco+, or refcocog') 13 | parser.add_argument('--one-sentence-per-item', action='store_true', help='if passed, a random sentence will be taken instead of returning all of the sentences') 14 | parser.add_argument('--ddp_trained_weights', action='store_true', 15 | help='Only needs specified when testing,' 16 | 'whether the weights to be loaded are from a DDP-trained model') 17 | parser.add_argument('--device', default='cuda:0', help='device') # only used when testing on a single machine 18 | parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run') 19 | parser.add_argument('--fusion_drop', default=0.0, type=float, help='dropout rate for PWAMs') 20 | parser.add_argument('--img_size', default=480, type=int, help='input image size') 21 | parser.add_argument("--local_rank", 
type=int, help='local rank for DistributedDataParallel') 22 | parser.add_argument('--lr', default=0.00005, type=float, help='the initial learning rate') 23 | parser.add_argument('--contrastive-alpha', default=0.01, type=float, help='weight of the contrastive term of loss') 24 | parser.add_argument('--mha', default='', help='If specified, should be in the format of a-b-c-d, e.g., 4-4-4-4,' 25 | 'where a, b, c, and d refer to the numbers of heads in stage-1,' 26 | 'stage-2, stage-3, and stage-4 PWAMs') 27 | parser.add_argument('--model_id', default='lavt', help='name to identify the model') 28 | parser.add_argument('--output-dir', default='./unsupervised/checkpoints/', help='path where to save checkpoint weights') 29 | parser.add_argument('--model-experiment-name', default='model', help='identifier of the model') 30 | parser.add_argument('--pin_mem', action='store_true', 31 | help='If true, pin memory when using the data loader.') 32 | parser.add_argument('--pretrained_swin_weights', default='', 33 | help='path to pre-trained Swin backbone weights') 34 | parser.add_argument('--print-freq', default=10, type=int, help='print frequency') 35 | parser.add_argument('--refer_data_root', default='./refer/data/', help='REFER dataset root directory') 36 | parser.add_argument('--pseudo_masks_root', default='./outputs/only_projected_classes_top_3/group_vit_masks/', help='Path to unsupervised masks') 37 | parser.add_argument('--resume', default='', help='resume from checkpoint') 38 | parser.add_argument('--init-from', default='', help='init-from from checkpoint') 39 | parser.add_argument('--split', default='test', help='only used when testing') 40 | parser.add_argument('--split-by', default='unc', help='change to umd or google when the dataset is G-Ref (RefCOCOg)') 41 | parser.add_argument('--swin_type', default='base', 42 | help='tiny, small, base, or large variants of the Swin Transformer') 43 | parser.add_argument('--wd', '--weight-decay', default=1e-2, type=float, metavar='W', help='weight decay', 44 | dest='weight_decay') 45 | parser.add_argument('--window12', action='store_true', 46 | help='only needs specified when testing,' 47 | 'when training, window size is inferred from pre-trained weights file name' 48 | '(containing \'window12\'). 
Initialize Swin with window size 12 instead of the default 7.') 49 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers') 50 | parser.add_argument('--groundtruth-masks', action='store_true', help='if passed, use the groundtruth masks for training instead of pseudo groundtruth ones') 51 | parser.add_argument( 52 | '--loss-mode', 53 | default='random_ce', 54 | help='choice of the loss mode for training this model', 55 | choices=["random_ce", "greedy_ce", "greedy_ce_contrastive"] 56 | ) 57 | 58 | return parser 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = get_parser() 63 | args_dict = parser.parse_args() 64 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import setuptools 3 | 4 | repository_dir = os.path.dirname(__file__) 5 | 6 | with open(os.path.join(repository_dir, "requirements.txt")) as fh: 7 | requirements = [line for line in fh.readlines()] 8 | 9 | setuptools.setup( 10 | name="ssc_ris", 11 | version="1.0.0", 12 | author="Francisco Eiras", 13 | author_email="francisco.girbal@gmail.com", 14 | license="MIT", 15 | python_requires=">=3.7", 16 | packages=setuptools.find_packages(), 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python :: 3.7", 20 | ], 21 | dependency_links=requirements, 22 | include_package_data=True, 23 | ) 24 | -------------------------------------------------------------------------------- /ssc_ris/correct/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/ssc_ris/correct/__init__.py -------------------------------------------------------------------------------- /ssc_ris/correct/loss.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from ssc_ris.select.utils import separate_instance_masks_torch as separate_instance_masks 7 | from .loss_utils import cross_entropy, contrastive_kl_loss, pseudo_intersection_over_union 8 | 9 | softmax = torch.nn.Softmax(dim=1) 10 | 11 | 12 | def random_assign_and_cross_entropy(inputs, outputs, targets, mask_idx_info, params={}): 13 | assigned_masks = [] 14 | output_loss = torch.Tensor([0.0]).cuda(non_blocking=True) 15 | 16 | for img_idx, (input, output, merged_target) in enumerate(zip(inputs, outputs, targets)): 17 | all_instance_masks = separate_instance_masks(merged_target) 18 | 19 | same_img_indices = torch.where((inputs.reshape((inputs.shape[0], -1)) == input.reshape(-1)).all(dim=1))[0] 20 | same_img_same_mask_indices = torch.where(mask_idx_info[same_img_indices] == mask_idx_info[img_idx])[0] + same_img_indices[0].cpu() 21 | same_img_diff_mask_indices = torch.where(mask_idx_info[same_img_indices] != mask_idx_info[img_idx])[0] + same_img_indices[0].cpu() 22 | 23 | if len(all_instance_masks) == 1: 24 | # there's only one mask, match it 25 | output_loss += cross_entropy(output.unsqueeze(0), all_instance_masks[0].unsqueeze(0).to(torch.long)) 26 | assigned_masks.append(0) 27 | else: 28 | # there're multiple masks, must assign one of these masks later 29 | 30 | # has this object been assigned a mask already? 
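# (same_img_same_mask_indices lists the batch items that share both this image and the referred target;
# if one of them appears earlier in the batch, its previously assigned mask is reused below)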
31 | if same_img_same_mask_indices.shape[0] > 0 and same_img_same_mask_indices[0] < img_idx: 32 | # it has, just use this one and continue 33 | already_assigned_mask = assigned_masks[same_img_same_mask_indices[0]] 34 | output_loss += cross_entropy(output.unsqueeze(0), all_instance_masks[already_assigned_mask].unsqueeze(0).to(torch.long)) 35 | assigned_masks.append(already_assigned_mask) 36 | continue 37 | 38 | # no mask has been assigned to this object yet, choose one 39 | 40 | # has a mask for a different object already been assigned? If so, exclude that one from the random draw 41 | exclude_masks = [] 42 | if same_img_diff_mask_indices.shape[0] > 0 and same_img_diff_mask_indices[0] < img_idx: 43 | for other_img_idx in same_img_diff_mask_indices: 44 | if other_img_idx < img_idx and assigned_masks[other_img_idx] not in exclude_masks: 45 | exclude_masks.append(assigned_masks[other_img_idx]) 46 | 47 | options = [i for i in range(0, len(all_instance_masks)) if i not in exclude_masks] 48 | if len(options) == 0: 49 | print("---- Not enough masks for objects in image :( ----") 50 | options = [i for i in range(0, len(all_instance_masks))] 51 | 52 | random_mask_idx = random.choice(options) 53 | output_loss += cross_entropy(output.unsqueeze(0), all_instance_masks[random_mask_idx].unsqueeze(0).to(torch.long)) 54 | assigned_masks.append(random_mask_idx) 55 | 56 | return output_loss / outputs.shape[0], {} 57 | 58 | 59 | def greedy_no_constraints_and_cross_entropy(inputs, outputs, targets, mask_idx_info, params={}): 60 | # no constraints on assigning the same mask to the same output, simply greedy on the output 61 | assigned_masks = {} 62 | all_masks = {} 63 | output_loss = torch.Tensor([0.0]).cuda(non_blocking=True) 64 | softmaxed_outputs = softmax(outputs) 65 | 66 | for img_idx, (input, output, merged_target) in enumerate(zip(inputs, outputs, targets)): 67 | all_instance_masks = separate_instance_masks(merged_target) 68 | all_masks[img_idx] = all_instance_masks 69 | 70 | if len(all_instance_masks) == 1: 71 | # there's only one mask, match it 72 | assigned_masks[img_idx] = 0 73 | output_loss += cross_entropy(output.unsqueeze(0), all_instance_masks[0].unsqueeze(0).to(torch.long)) 74 | else: 75 | mask_preferences_img = torch.tensor([ 76 | pseudo_intersection_over_union(softmaxed_outputs[img_idx], instance_mask) 77 | for instance_mask in all_instance_masks 78 | ]) 79 | assigned_masks[img_idx] = torch.argmax(mask_preferences_img) 80 | output_loss += cross_entropy( 81 | output.unsqueeze(0), 82 | all_instance_masks[assigned_masks[img_idx]].unsqueeze(0).to(torch.long) 83 | ) 84 | 85 | assigned_masks_list = torch.cat([ 86 | all_masks[idx][assigned_masks[idx]].unsqueeze(0) for idx in range(len(assigned_masks) 87 | )]) 88 | ice_loss = output_loss / outputs.shape[0] 89 | return ice_loss, {"ice loss": ice_loss, "contrastive loss": torch.tensor(0), "matched masks": assigned_masks_list} 90 | 91 | 92 | def greedy_match_and_cross_entropy(inputs, outputs, targets, mask_idx_info, params={}): 93 | assigned_masks = {} 94 | same_mask_constraints = {} 95 | diff_mask_constraints = {} 96 | all_masks = {} 97 | mask_preferences = {} 98 | output_loss = torch.Tensor([0.0]).cuda(non_blocking=True) 99 | softmaxed_outputs = softmax(outputs) 100 | 101 | for img_idx, (input, output, merged_target) in enumerate(zip(inputs, outputs, targets)): 102 | all_instance_masks = separate_instance_masks(merged_target) 103 | 104 | same_img_indices = torch.where((inputs.reshape((inputs.shape[0], -1)) == input.reshape(-1)).all(dim=1))[0] 105 | 
same_img_same_mask_indices = torch.where(mask_idx_info[same_img_indices] == mask_idx_info[img_idx])[0] + same_img_indices[0].cpu() 106 | same_img_diff_mask_indices = torch.where(mask_idx_info[same_img_indices] != mask_idx_info[img_idx])[0] + same_img_indices[0].cpu() 107 | 108 | if len(all_instance_masks) == 1: 109 | # there's only one mask, match it 110 | output_loss += cross_entropy(output.unsqueeze(0), all_instance_masks[0].unsqueeze(0).to(torch.long)) 111 | assigned_masks[img_idx] = 0 112 | 113 | same_mask_constraints[img_idx] = same_img_same_mask_indices 114 | diff_mask_constraints[img_idx] = same_img_diff_mask_indices 115 | mask_preferences[img_idx] = torch.tensor([ 116 | pseudo_intersection_over_union(softmaxed_outputs[img_idx], instance_mask) 117 | for instance_mask in all_instance_masks 118 | ]) 119 | all_masks[img_idx] = all_instance_masks 120 | 121 | # if there are any outputs that have more than one mask, then assign them through Hungarian matching 122 | for img_idx in same_mask_constraints.keys(): 123 | # this object's mask has been assigned already, no need to do anything 124 | if img_idx in assigned_masks: 125 | continue 126 | 127 | # this object's mask has not been choosen yet, assign masks to same and diff mask constraints based on Hungarian matching 128 | matching_indices = torch.cat((same_mask_constraints[img_idx], diff_mask_constraints[img_idx])).numpy() 129 | matching_mask_preferences = [] 130 | matching_masks = [] 131 | for idx in matching_indices: 132 | matching_mask_preferences.append(mask_preferences[idx]) 133 | matching_masks.append(all_masks[idx]) 134 | 135 | # matching_mask_preferences = torch.vstack(matching_mask_preferences) 136 | 137 | # iteratively select the highest score masks and eliminate the ones that have been selected already 138 | while True: 139 | highest_iou_idx = torch.argmax(torch.Tensor([torch.max(idx_masks) for idx_masks in matching_mask_preferences])) 140 | mask_idx = torch.argmax(matching_mask_preferences[highest_iou_idx]) 141 | 142 | # if the best if -Inf, all existing masks have been assigned 143 | if matching_mask_preferences[highest_iou_idx][mask_idx] == -float('inf'): 144 | break 145 | 146 | highest_index = matching_indices[highest_iou_idx] 147 | highest_mask = matching_masks[highest_iou_idx][mask_idx] 148 | 149 | # assign this mask to all same_mask_constraints[highest_index] 150 | for same_mask_index in same_mask_constraints[highest_index]: 151 | assigned_masks[int(same_mask_index)] = int(mask_idx) 152 | output_loss += cross_entropy(outputs[same_mask_index].unsqueeze(0), highest_mask.unsqueeze(0).to(torch.long)) 153 | 154 | index = np.where(matching_indices == same_mask_index.numpy())[0][0] 155 | matching_mask_preferences[index][:] = -float('inf') 156 | 157 | # look for the same mask in diff_mask_constraints and change the IoU to -inf to blacklist this mask 158 | for diff_mask_index in diff_mask_constraints[highest_index]: 159 | if diff_mask_index in assigned_masks: 160 | continue 161 | 162 | diff_mask_matching_index = np.where(matching_indices == diff_mask_index.numpy())[0][0] 163 | 164 | for i, diff_mask in enumerate(matching_masks[diff_mask_matching_index]): 165 | if (diff_mask == highest_mask).all(): 166 | matching_mask_preferences[diff_mask_matching_index][i] = -float('inf') 167 | break 168 | 169 | # if there are any masks that have not been assigned, it's because there's not enough masks for the number of objects 170 | # assign a random mask 171 | for img_idx in same_mask_constraints.keys(): 172 | # this object's mask has been 
assigned already, no need to do anything 173 | if img_idx in assigned_masks: 174 | continue 175 | 176 | random_mask_idx = random.choice(range(len(all_masks[img_idx]))) 177 | assigned_masks[img_idx] = random_mask_idx 178 | 179 | assigned_masks_list = torch.cat([ 180 | all_masks[idx][assigned_masks[idx]].unsqueeze(0) for idx in range(len(assigned_masks) 181 | )]) 182 | ice_loss = output_loss / outputs.shape[0] 183 | return ice_loss, {"ice loss": ice_loss, "contrastive loss": torch.tensor(0), "matched masks": assigned_masks_list} 184 | 185 | 186 | def greedy_match_and_contrastive(inputs, outputs, targets, mask_idx_info, params={"contrastive_alpha": 0.01}): 187 | ice_loss, ice_log = greedy_match_and_cross_entropy( 188 | inputs, outputs, targets, mask_idx_info 189 | ) 190 | matched_masks = ice_log["matched masks"] 191 | 192 | contrastive_loss = contrastive_kl_loss(inputs, outputs, matched_masks, mask_idx_info) 193 | 194 | return ice_loss + params["contrastive_alpha"] * contrastive_loss, {"ice loss": ice_loss, "contrastive loss": contrastive_loss, "matched masks": matched_masks} 195 | -------------------------------------------------------------------------------- /ssc_ris/correct/loss_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn, Tensor 8 | 9 | kl_loss = nn.KLDivLoss(reduction="batchmean") 10 | cos_sim = torch.nn.CosineSimilarity(dim=0) 11 | 12 | 13 | def cross_entropy(input, target): 14 | weight = torch.FloatTensor([0.9, 1.1]).cuda() 15 | return nn.functional.cross_entropy(input, target, weight=weight) 16 | 17 | 18 | def pseudo_intersection_over_union(prediction, target_mask): 19 | # from Recurrent Instance Segmentation paper (https://arxiv.org/pdf/1511.08250.pdf) 20 | # assumes prediction is the result of a softmax function 21 | y_hat_y_dot = (prediction[1] * target_mask).sum() 22 | return y_hat_y_dot / torch.max(prediction[1].sum() + target_mask.sum() - y_hat_y_dot, torch.tensor(1e-7).to(y_hat_y_dot.device)) 23 | 24 | 25 | def contrastive_kl_loss(batch_images, output, batch_targets, batch_mask_idx_info): 26 | # compute KL divergence on negative terms; don't add anything for positive terms 27 | negative_examples_constant = 0.1 28 | contrastive_loss = torch.Tensor([0.0]).cuda(non_blocking=True) 29 | output = torch.softmax(output, dim=1) 30 | 31 | n_terms = 0 32 | for img_idx, img in enumerate(batch_images): 33 | same_img_indices = torch.where((batch_images.reshape((batch_images.shape[0], -1)) == img.reshape(-1)).all(dim=1))[0] 34 | 35 | same_img_same_mask_indices = torch.where(batch_mask_idx_info[same_img_indices] == batch_mask_idx_info[img_idx])[0] + same_img_indices[0].cpu() 36 | same_img_diff_mask_indices = torch.where(batch_mask_idx_info[same_img_indices] != batch_mask_idx_info[img_idx])[0] + same_img_indices[0].cpu() 37 | 38 | # sum masks of the same object 39 | for other_pos_img_idx in same_img_same_mask_indices: 40 | if img_idx == other_pos_img_idx: 41 | continue 42 | 43 | contrastive_loss += kl_loss(torch.log(output[img_idx]), output[other_pos_img_idx]) / (output[img_idx].shape[1] * output[img_idx].shape[2]) 44 | n_terms += 1 45 | 46 | # 1 / kl for masks of different objects 47 | for other_neg_img_idx in same_img_diff_mask_indices: 48 | active_pixels = (batch_targets[img_idx] == 1) | (batch_targets[other_neg_img_idx] == 1) 49 | masked_img_output = output[img_idx][:, active_pixels] 50 | 
masked_other_img_output = output[other_neg_img_idx][:, active_pixels] 51 | 52 | # hinge loss combined with an inverse KL to incentivize low similarity 53 | negative_loss_term = torch.min( 54 | torch.Tensor([1]).to(masked_img_output.device), 55 | 1 / (negative_examples_constant * kl_loss(torch.log(masked_img_output), masked_other_img_output)) 56 | ) 57 | contrastive_loss += negative_loss_term 58 | 59 | n_terms += 1 60 | 61 | # save_batch_images_and_masks(batch_images.cpu(), batch_targets.cpu(), 'example.png') 62 | 63 | if n_terms: 64 | contrastive_loss /= n_terms 65 | 66 | return contrastive_loss 67 | 68 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .unsupervised_dataset import * -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/bert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/ssc_ris/refer_dataset/bert/__init__.py -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/bert/activations.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def swish(x): 12 | return x * torch.sigmoid(x) 13 | 14 | 15 | def _gelu_python(x): 16 | """ Original Implementation of the gelu activation function in Google Bert repo when initially created. 17 | For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 18 | 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 19 | This is now written in C in torch.nn.functional 20 | Also see https://arxiv.org/abs/1606.08415 21 | """ 22 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 23 | 24 | 25 | def gelu_new(x): 26 | """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). 27 | Also see https://arxiv.org/abs/1606.08415 28 | """ 29 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 30 | 31 | 32 | if torch.__version__ < "1.4.0": 33 | gelu = _gelu_python 34 | else: 35 | gelu = F.gelu 36 | 37 | 38 | def gelu_fast(x): 39 | return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) 40 | 41 | 42 | ACT2FN = { 43 | "relu": F.relu, 44 | "swish": swish, 45 | "gelu": gelu, 46 | "tanh": torch.tanh, 47 | "gelu_new": gelu_new, 48 | "gelu_fast": gelu_fast, 49 | } 50 | 51 | 52 | def get_activation(activation_string): 53 | if activation_string in ACT2FN: 54 | return ACT2FN[activation_string] 55 | else: 56 | raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys()))) 57 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/bert/configuration_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ BERT model configuration """ 17 | 18 | 19 | import logging 20 | 21 | from .configuration_utils import PretrainedConfig 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { 27 | "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", 28 | "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", 29 | "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", 30 | "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", 31 | "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", 32 | "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", 33 | "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", 34 | "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", 35 | "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", 36 | "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", 37 | "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", 38 | "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", 39 | "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", 40 | "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json", 41 | "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json", 42 | "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json", 43 | "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json", 44 | "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json", 45 | "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json", 46 | "TurkuNLP/bert-base-finnish-cased-v1": 
"https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json", 47 | "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json", 48 | "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json", 49 | # See all BERT models at https://huggingface.co/models?filter=bert 50 | } 51 | 52 | 53 | class BertConfig(PretrainedConfig): 54 | r""" 55 | This is the configuration class to store the configuration of a :class:`~transformers.BertModel`. 56 | It is used to instantiate an BERT model according to the specified arguments, defining the model 57 | architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of 58 | the BERT `bert-base-uncased `__ architecture. 59 | 60 | Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used 61 | to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig` 62 | for more information. 63 | 64 | 65 | Args: 66 | vocab_size (:obj:`int`, optional, defaults to 30522): 67 | Vocabulary size of the BERT model. Defines the different tokens that 68 | can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`. 69 | hidden_size (:obj:`int`, optional, defaults to 768): 70 | Dimensionality of the encoder layers and the pooler layer. 71 | num_hidden_layers (:obj:`int`, optional, defaults to 12): 72 | Number of hidden layers in the Transformer encoder. 73 | num_attention_heads (:obj:`int`, optional, defaults to 12): 74 | Number of attention heads for each attention layer in the Transformer encoder. 75 | intermediate_size (:obj:`int`, optional, defaults to 3072): 76 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 77 | hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): 78 | The non-linear activation function (function or string) in the encoder and pooler. 79 | If string, "gelu", "relu", "swish" and "gelu_new" are supported. 80 | hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): 81 | The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. 82 | attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): 83 | The dropout ratio for the attention probabilities. 84 | max_position_embeddings (:obj:`int`, optional, defaults to 512): 85 | The maximum sequence length that this model might ever be used with. 86 | Typically set this to something large just in case (e.g., 512 or 1024 or 2048). 87 | type_vocab_size (:obj:`int`, optional, defaults to 2): 88 | The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`. 89 | initializer_range (:obj:`float`, optional, defaults to 0.02): 90 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 91 | layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): 92 | The epsilon used by the layer normalization layers. 93 | gradient_checkpointing (:obj:`bool`, optional, defaults to False): 94 | If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
95 | 96 | Example:: 97 | 98 | >>> from transformers import BertModel, BertConfig 99 | 100 | >>> # Initializing a BERT bert-base-uncased style configuration 101 | >>> configuration = BertConfig() 102 | 103 | >>> # Initializing a model from the bert-base-uncased style configuration 104 | >>> model = BertModel(configuration) 105 | 106 | >>> # Accessing the model configuration 107 | >>> configuration = model.config 108 | """ 109 | model_type = "bert" 110 | 111 | def __init__( 112 | self, 113 | vocab_size=30522, 114 | hidden_size=768, 115 | num_hidden_layers=12, 116 | num_attention_heads=12, 117 | intermediate_size=3072, 118 | hidden_act="gelu", 119 | hidden_dropout_prob=0.1, 120 | attention_probs_dropout_prob=0.1, 121 | max_position_embeddings=512, 122 | type_vocab_size=2, 123 | initializer_range=0.02, 124 | layer_norm_eps=1e-12, 125 | pad_token_id=0, 126 | gradient_checkpointing=False, 127 | **kwargs 128 | ): 129 | super().__init__(pad_token_id=pad_token_id, **kwargs) 130 | 131 | self.vocab_size = vocab_size 132 | self.hidden_size = hidden_size 133 | self.num_hidden_layers = num_hidden_layers 134 | self.num_attention_heads = num_attention_heads 135 | self.hidden_act = hidden_act 136 | self.intermediate_size = intermediate_size 137 | self.hidden_dropout_prob = hidden_dropout_prob 138 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 139 | self.max_position_embeddings = max_position_embeddings 140 | self.type_vocab_size = type_vocab_size 141 | self.initializer_range = initializer_range 142 | self.layer_norm_eps = layer_norm_eps 143 | self.gradient_checkpointing = gradient_checkpointing 144 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .bert.tokenization_bert import BertTokenizer 8 | from .refer.refer import REFER 9 | 10 | 11 | class ReferDataset(data.Dataset): 12 | def __init__( 13 | self, 14 | dataset: str, 15 | splitBy: str = 'unc', 16 | refer_data_root: str = 'dataset/refer', 17 | image_transforms: object = None, 18 | target_transforms: object = None, 19 | split: str = 'train', 20 | bert_tokenizer: str = 'bert-base-uncased', 21 | COCO_image_root: str = None, 22 | eval_mode: bool = False, 23 | return_attributes: bool = False 24 | ): 25 | """Wrapper for the REFER class in refer (https://github.com/lichengunc/refer) 26 | 27 | Args: 28 | dataset (str): dataset description (one of "refcoco", "refcoco+" or "refcocog") 29 | splitBy (str, optional): data split. Defaults to 'unc'. 30 | refer_data_root (str, optional): root of the REFER data. Defaults to 'dataset/refer'. 31 | image_transforms (object, optional): transformations to be applied to input images. Defaults to None. 32 | target_transforms (object, optional): transformations to be applied to target masks. Defaults to None. 33 | split (str, optional): dataset split. Defaults to 'train'. 34 | bert_tokenizer (str, optional): type of BERT tokenizer. Defaults to 'bert-base-uncased'. 35 | COCO_image_root (str, optional): image root. Defaults to None. 36 | eval_mode (bool, optional): if True, load all sentences, otherwise chose only one. Defaults to False. 37 | return_attributes (bool, optional): whether additional data attributes should be returned. Defaults to False. 
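                Example (illustrative sketch only; the paths and split below are assumptions and
                require the REFER annotations and COCO images to be downloaded first)::

                    >>> dataset = ReferDataset(
                    ...     dataset='refcoco', splitBy='unc',
                    ...     refer_data_root='dataset/refer',
                    ...     COCO_image_root='dataset/refer/images/mscoco/images',
                    ...     split='val')
                    >>> img, target, token_ids, attention_mask = dataset[0]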
38 | """ 39 | 40 | self.classes = [] 41 | self.image_transforms = image_transforms 42 | self.target_transform = target_transforms 43 | self.split = split 44 | self.refer = REFER(refer_data_root, dataset, splitBy, image_root=COCO_image_root) 45 | self.return_attributes = return_attributes 46 | 47 | self.max_tokens = 20 48 | 49 | ref_ids = self.refer.getRefIds(split=self.split) 50 | img_ids = self.refer.getImgIds(ref_ids) 51 | 52 | all_imgs = self.refer.Imgs 53 | self.imgs = list(all_imgs[i] for i in img_ids) 54 | self.ref_ids = ref_ids 55 | 56 | self.input_ids = [] 57 | self.attention_masks = [] 58 | self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer) 59 | 60 | self.eval_mode = eval_mode 61 | # if we are testing on a dataset, test all sentences of an object; 62 | # o/w, we are validating during training, randomly sample one sentence for efficiency 63 | for r in ref_ids: 64 | ref = self.refer.Refs[r] 65 | 66 | sentences_for_ref = [] 67 | attentions_for_ref = [] 68 | 69 | for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])): 70 | sentence_raw = el['sent'] 71 | attention_mask = [0] * self.max_tokens 72 | padded_input_ids = [0] * self.max_tokens 73 | 74 | input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) 75 | 76 | # truncation of tokens 77 | input_ids = input_ids[:self.max_tokens] 78 | 79 | padded_input_ids[:len(input_ids)] = input_ids 80 | attention_mask[:len(input_ids)] = [1]*len(input_ids) 81 | 82 | sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0)) 83 | attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0)) 84 | 85 | self.input_ids.append(sentences_for_ref) 86 | self.attention_masks.append(attentions_for_ref) 87 | 88 | def get_classes(self): 89 | return self.classes 90 | 91 | def __len__(self): 92 | return len(self.ref_ids) 93 | 94 | def __getitem__(self, index): 95 | this_ref_id = self.ref_ids[index] 96 | this_img_id = self.refer.getImgIds(this_ref_id) 97 | this_img = self.refer.Imgs[this_img_id[0]] 98 | 99 | img = Image.open(os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])).convert("RGB") 100 | 101 | ref = self.refer.loadRefs(this_ref_id) 102 | 103 | ref_mask = np.array(self.refer.getMask(ref[0])['mask']) 104 | annot = np.zeros(ref_mask.shape) 105 | annot[ref_mask == 1] = 1 106 | 107 | annot = Image.fromarray(annot.astype(np.uint8), mode="P") 108 | target = annot 109 | 110 | if self.image_transforms is not None: 111 | # resize, from PIL to tensor, and mean and std normalization 112 | img, target = self.image_transforms(img, annot) 113 | 114 | if self.eval_mode: 115 | embedding = [] 116 | att = [] 117 | for s in range(len(self.input_ids[index])): 118 | e = self.input_ids[index][s] 119 | a = self.attention_masks[index][s] 120 | embedding.append(e.unsqueeze(-1)) 121 | att.append(a.unsqueeze(-1)) 122 | 123 | tensor_embeddings = torch.cat(embedding, dim=-1) 124 | attention_mask = torch.cat(att, dim=-1) 125 | else: 126 | choice_sent = np.random.choice(len(self.input_ids[index])) 127 | tensor_embeddings = self.input_ids[index][choice_sent] 128 | attention_mask = self.attention_masks[index][choice_sent] 129 | 130 | if self.return_attributes: 131 | attributes = { 132 | "file_name": this_img["file_name"], 133 | "sentence_ids": [sent["sent_id"] for sent in ref[0]["sentences"]], 134 | "sentences_raw": [sent["raw"] for sent in ref[0]["sentences"]], 135 | "sentences_sent": [sent["sent"] for sent in ref[0]["sentences"]], 136 | "ref_id": ref[0]["ref_id"], 137 | "ann_id": ref[0]["ann_id"], 138 | "ref": 
ref[0] 139 | } 140 | 141 | return img, target, tensor_embeddings, attention_mask, attributes 142 | else: 143 | return img, target, tensor_embeddings, attention_mask 144 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | # install pycocotools/mask locally 3 | # copy from https://github.com/pdollar/coco.git 4 | python setup.py build_ext --inplace 5 | rm -rf build 6 | 7 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/README.md: -------------------------------------------------------------------------------- 1 | ## Note 2 | This API is able to load all 4 referring expression datasets, i.e., RefClef, RefCOCO, RefCOCO+ and RefCOCOg. 3 | They come with different train/val/test splits defined by UNC, Google and UC Berkeley respectively; we provide all of these splits here. 4 | 5 | [figure: example annotated image, alt text "Mountain View"] 6 | 7 | 8 |
9 | 10 | ## Citation 11 | If you use the three datasets RefClef, RefCOCO and RefCOCO+ collected by UNC, please consider citing our EMNLP2014 paper; if you want to compare with our recent results, please check our ECCV2016 paper. 12 | ```bash 13 | Kazemzadeh, Sahar, et al. "ReferItGame: Referring to Objects in Photographs of Natural Scenes." EMNLP 2014. 14 | Yu, Licheng, et al. "Modeling Context in Referring Expressions." ECCV 2016. 15 | ``` 16 | 17 | ## Setup 18 | Run "make" before using the code. 19 | It will generate ``_mask.c`` and ``_mask.so`` in the ``external/`` folder. 20 | This mask-related code is copied from the mscoco [API](https://github.com/pdollar/coco). 21 | 22 | ## Download 23 | Download the cleaned data and extract it into the "data" folder: 24 | - 1) http://bvisionweb1.cs.unc.edu/licheng/referit/data/refclef.zip 25 | - 2) http://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco.zip 26 | - 3) http://bvisionweb1.cs.unc.edu/licheng/referit/data/refcoco+.zip 27 | - 4) http://bvisionweb1.cs.unc.edu/licheng/referit/data/refcocog.zip 28 | 29 | ## Prepare Images: 30 | In addition, add "mscoco" into the ``data/images`` folder; the images can be downloaded from [mscoco](http://mscoco.org/dataset/#overview). 31 | COCO's images are used for RefCOCO, RefCOCO+ and RefCOCOg. 32 | For RefCLEF, please add ``saiapr_tc-12`` into the ``data/images`` folder. We extracted the 19997 relevant images for the cleaned RefCLEF dataset, which is a subset of the original [imageCLEF](http://imageclef.org/SIAPRdata). Download the [subset](http://bvisionweb1.cs.unc.edu/licheng/referit/data/images/saiapr_tc-12.zip) and unzip it to ``data/images/saiapr_tc-12``. 33 | 34 | ## How to use 35 | "refer.py" is able to load all 4 datasets with the different data splits provided by UNC, Google, UMD and UC Berkeley.
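As a minimal usage sketch (the ``data`` path here is an assumption), a loaded instance is typically queried with the same calls used by ``ssc_ris/refer_dataset/dataset.py``; the constructor examples below list the valid ``dataset``/``splitBy`` combinations:
```python
refer = REFER('data', dataset='refcoco', splitBy='unc')
ref_ids = refer.getRefIds(split='val')         # ids of the referring expressions in this split
ref = refer.loadRefs(ref_ids[0])[0]            # one annotated object together with its sentences
print([s['sent'] for s in ref['sentences']])   # the referring expressions for that object
mask = refer.getMask(ref)['mask']              # binary segmentation mask of the referred object
```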
36 | **Note for RefCOCOg, we suggest use UMD's split which has train/val/test splits and there is no overlap of images between different split.** 37 | ```bash 38 | # locate your own data_root, and choose the dataset_splitBy you want to use 39 | refer = REFER(data_root, dataset='refclef', splitBy='unc') 40 | refer = REFER(data_root, dataset='refclef', splitBy='berkeley') # 2 train and 1 test images missed 41 | refer = REFER(data_root, dataset='refcoco', splitBy='unc') 42 | refer = REFER(data_root, dataset='refcoco', splitBy='google') 43 | refer = REFER(data_root, dataset='refcoco+', splitBy='unc') 44 | refer = REFER(data_root, dataset='refcocog', splitBy='google') # test split not released yet 45 | refer = REFER(data_root, dataset='refcocog', splitBy='umd') # Recommended, including train/val/test 46 | ``` 47 | 48 | 49 | 56 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'licheng' 2 | 3 | 4 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/bleu/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 
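# Usage           : compute_score(gts, res) expects two dicts keyed by the same ids;
#                   res[id] must be a single-element list holding the hypothesis
#                   sentence and gts[id] a list of one or more reference sentences.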
6 | # 7 | # Creation Date : 06-01-2015 8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | from bleu_scorer import BleuScorer 12 | 13 | 14 | class Bleu: 15 | def __init__(self, n=4): 16 | # default compute Blue score up to 4 17 | self._n = n 18 | self._hypo_for_image = {} 19 | self.ref_for_image = {} 20 | 21 | def compute_score(self, gts, res): 22 | 23 | assert(gts.keys() == res.keys()) 24 | imgIds = gts.keys() 25 | 26 | bleu_scorer = BleuScorer(n=self._n) 27 | for id in imgIds: 28 | hypo = res[id] 29 | ref = gts[id] 30 | 31 | # Sanity check. 32 | assert(type(hypo) is list) 33 | assert(len(hypo) == 1) 34 | assert(type(ref) is list) 35 | assert(len(ref) >= 1) 36 | 37 | bleu_scorer += (hypo[0], ref) 38 | 39 | #score, scores = bleu_scorer.compute_score(option='shortest') 40 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1) 41 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1) 42 | 43 | # return (bleu, bleu_info) 44 | return score, scores 45 | 46 | def method(self): 47 | return "Bleu" 48 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/bleu/bleu_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # bleu_scorer.py 4 | # David Chiang 5 | 6 | # Copyright (c) 2004-2006 University of Maryland. All rights 7 | # reserved. Do not redistribute without permission from the 8 | # author. Not for commercial use. 9 | 10 | # Modified by: 11 | # Hao Fang 12 | # Tsung-Yi Lin 13 | 14 | '''Provides: 15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 17 | ''' 18 | 19 | import copy 20 | import sys, math, re 21 | from collections import defaultdict 22 | 23 | def precook(s, n=4, out=False): 24 | """Takes a string as input and returns an object that can be given to 25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 26 | can take string arguments as well.""" 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in xrange(1,n+1): 30 | for i in xrange(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return (len(words), counts) 34 | 35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them.''' 39 | 40 | reflen = [] 41 | maxcounts = {} 42 | for ref in refs: 43 | rl, counts = precook(ref, n) 44 | reflen.append(rl) 45 | for (ngram,count) in counts.iteritems(): 46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 47 | 48 | # Calculate effective reference sentence length. 49 | if eff == "shortest": 50 | reflen = min(reflen) 51 | elif eff == "average": 52 | reflen = float(sum(reflen))/len(reflen) 53 | 54 | ## lhuang: N.B.: leave reflen computaiton to the very end!! 55 | 56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! 
(bad design) 57 | 58 | return (reflen, maxcounts) 59 | 60 | def cook_test(test, (reflen, refmaxcounts), eff=None, n=4): 61 | '''Takes a test sentence and returns an object that 62 | encapsulates everything that BLEU needs to know about it.''' 63 | 64 | testlen, counts = precook(test, n, True) 65 | 66 | result = {} 67 | 68 | # Calculate effective reference sentence length. 69 | 70 | if eff == "closest": 71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1] 72 | else: ## i.e., "average" or "shortest" or None 73 | result["reflen"] = reflen 74 | 75 | result["testlen"] = testlen 76 | 77 | result["guess"] = [max(0,testlen-k+1) for k in xrange(1,n+1)] 78 | 79 | result['correct'] = [0]*n 80 | for (ngram, count) in counts.iteritems(): 81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count) 82 | 83 | return result 84 | 85 | class BleuScorer(object): 86 | """Bleu scorer. 87 | """ 88 | 89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 90 | # special_reflen is used in oracle (proportional effective ref len for a node). 91 | 92 | def copy(self): 93 | ''' copy the refs.''' 94 | new = BleuScorer(n=self.n) 95 | new.ctest = copy.copy(self.ctest) 96 | new.crefs = copy.copy(self.crefs) 97 | new._score = None 98 | return new 99 | 100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 101 | ''' singular instance ''' 102 | 103 | self.n = n 104 | self.crefs = [] 105 | self.ctest = [] 106 | self.cook_append(test, refs) 107 | self.special_reflen = special_reflen 108 | 109 | def cook_append(self, test, refs): 110 | '''called by constructor and __iadd__ to avoid creating new instances.''' 111 | 112 | if refs is not None: 113 | self.crefs.append(cook_refs(refs)) 114 | if test is not None: 115 | cooked_test = cook_test(test, self.crefs[-1]) 116 | self.ctest.append(cooked_test) ## N.B.: -1 117 | else: 118 | self.ctest.append(None) # lens of crefs and ctest have to match 119 | 120 | self._score = None ## need to recompute 121 | 122 | def ratio(self, option=None): 123 | self.compute_score(option=option) 124 | return self._ratio 125 | 126 | def score_ratio(self, option=None): 127 | '''return (bleu, len_ratio) pair''' 128 | return (self.fscore(option=option), self.ratio(option=option)) 129 | 130 | def score_ratio_str(self, option=None): 131 | return "%.4f (%.2f)" % self.score_ratio(option) 132 | 133 | def reflen(self, option=None): 134 | self.compute_score(option=option) 135 | return self._reflen 136 | 137 | def testlen(self, option=None): 138 | self.compute_score(option=option) 139 | return self._testlen 140 | 141 | def retest(self, new_test): 142 | if type(new_test) is str: 143 | new_test = [new_test] 144 | assert len(new_test) == len(self.crefs), new_test 145 | self.ctest = [] 146 | for t, rs in zip(new_test, self.crefs): 147 | self.ctest.append(cook_test(t, rs)) 148 | self._score = None 149 | 150 | return self 151 | 152 | def rescore(self, new_test): 153 | ''' replace test(s) with new test(s), and returns the new score.''' 154 | 155 | return self.retest(new_test).compute_score() 156 | 157 | def size(self): 158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 159 | return len(self.crefs) 160 | 161 | def __iadd__(self, other): 162 | '''add an instance (e.g., from another sentence).''' 163 | 164 | if type(other) is tuple: 165 | ## avoid creating new BleuScorer instances 166 | self.cook_append(other[0], other[1]) 167 | else: 168 | assert self.compatible(other), "incompatible BLEUs." 169 | self.ctest.extend(other.ctest) 170 | self.crefs.extend(other.crefs) 171 | self._score = None ## need to recompute 172 | 173 | return self 174 | 175 | def compatible(self, other): 176 | return isinstance(other, BleuScorer) and self.n == other.n 177 | 178 | def single_reflen(self, option="average"): 179 | return self._single_reflen(self.crefs[0][0], option) 180 | 181 | def _single_reflen(self, reflens, option=None, testlen=None): 182 | 183 | if option == "shortest": 184 | reflen = min(reflens) 185 | elif option == "average": 186 | reflen = float(sum(reflens))/len(reflens) 187 | elif option == "closest": 188 | reflen = min((abs(l-testlen), l) for l in reflens)[1] 189 | else: 190 | assert False, "unsupported reflen option %s" % option 191 | 192 | return reflen 193 | 194 | def recompute_score(self, option=None, verbose=0): 195 | self._score = None 196 | return self.compute_score(option, verbose) 197 | 198 | def compute_score(self, option=None, verbose=0): 199 | n = self.n 200 | small = 1e-9 201 | tiny = 1e-15 ## so that if guess is 0 still return 0 202 | bleu_list = [[] for _ in range(n)] 203 | 204 | if self._score is not None: 205 | return self._score 206 | 207 | if option is None: 208 | option = "average" if len(self.crefs) == 1 else "closest" 209 | 210 | self._testlen = 0 211 | self._reflen = 0 212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n} 213 | 214 | # for each sentence 215 | for comps in self.ctest: 216 | testlen = comps['testlen'] 217 | self._testlen += testlen 218 | 219 | if self.special_reflen is None: ## need computation 220 | reflen = self._single_reflen(comps['reflen'], option, testlen) 221 | else: 222 | reflen = self.special_reflen 223 | 224 | self._reflen += reflen 225 | 226 | for key in ['guess','correct']: 227 | for k in xrange(n): 228 | totalcomps[key][k] += comps[key][k] 229 | 230 | # append per image bleu score 231 | bleu = 1. 232 | for k in xrange(n): 233 | bleu *= (float(comps['correct'][k]) + tiny) \ 234 | /(float(comps['guess'][k]) + small) 235 | bleu_list[k].append(bleu ** (1./(k+1))) 236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 237 | if ratio < 1: 238 | for k in xrange(n): 239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio) 240 | 241 | if verbose > 1: 242 | print comps, reflen 243 | 244 | totalcomps['reflen'] = self._reflen 245 | totalcomps['testlen'] = self._testlen 246 | 247 | bleus = [] 248 | bleu = 1. 
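# corpus-level BLEU-k for k = 1..n: geometric mean of the clipped n-gram precisions,
# with the brevity penalty applied further below when the candidate corpus is shorter
# than the effective reference length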
249 | for k in xrange(n): 250 | bleu *= float(totalcomps['correct'][k] + tiny) \ 251 | / (totalcomps['guess'][k] + small) 252 | bleus.append(bleu ** (1./(k+1))) 253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 254 | if ratio < 1: 255 | for k in xrange(n): 256 | bleus[k] *= math.exp(1 - 1/ratio) 257 | 258 | if verbose > 0: 259 | print totalcomps 260 | print "ratio:", ratio 261 | 262 | self._score = bleus 263 | return self._score, bleu_list 264 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/cider/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/cider/cider.py: -------------------------------------------------------------------------------- 1 | # Filename: cider.py 2 | # 3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726) 5 | # 6 | # Creation Date: Sun Feb 8 14:16:54 2015 7 | # 8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin 9 | 10 | from cider_scorer import CiderScorer 11 | import pdb 12 | 13 | class Cider: 14 | """ 15 | Main Class to compute the CIDEr metric 16 | 17 | """ 18 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 19 | # set cider to sum over 1 to 4-grams 20 | self._n = n 21 | # set the standard deviation parameter for gaussian penalty 22 | self._sigma = sigma 23 | 24 | def compute_score(self, gts, res): 25 | """ 26 | Main function to compute CIDEr score 27 | :param hypo_for_image (dict) : dictionary with key and value 28 | ref_for_image (dict) : dictionary with key and value 29 | :return: cider (float) : computed CIDEr score for the corpus 30 | """ 31 | 32 | assert(gts.keys() == res.keys()) 33 | imgIds = gts.keys() 34 | 35 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 36 | 37 | for id in imgIds: 38 | hypo = res[id] 39 | ref = gts[id] 40 | 41 | # Sanity check. 42 | assert(type(hypo) is list) 43 | assert(len(hypo) == 1) 44 | assert(type(ref) is list) 45 | assert(len(ref) > 0) 46 | 47 | cider_scorer += (hypo[0], ref) 48 | 49 | (score, scores) = cider_scorer.compute_score() 50 | 51 | return score, scores 52 | 53 | def method(self): 54 | return "CIDEr" -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Tsung-Yi Lin 3 | # Ramakrishna Vedantam 4 | 5 | import copy 6 | from collections import defaultdict 7 | import numpy as np 8 | import pdb 9 | import math 10 | 11 | def precook(s, n=4, out=False): 12 | """ 13 | Takes a string as input and returns an object that can be given to 14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 15 | can take string arguments as well. 
16 | :param s: string : sentence to be converted into ngrams 17 | :param n: int : number of ngrams for which representation is calculated 18 | :return: term frequency vector for occuring ngrams 19 | """ 20 | words = s.split() 21 | counts = defaultdict(int) 22 | for k in xrange(1,n+1): 23 | for i in xrange(len(words)-k+1): 24 | ngram = tuple(words[i:i+k]) 25 | counts[ngram] += 1 26 | return counts 27 | 28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 29 | '''Takes a list of reference sentences for a single segment 30 | and returns an object that encapsulates everything that BLEU 31 | needs to know about them. 32 | :param refs: list of string : reference sentences for some image 33 | :param n: int : number of ngrams for which (ngram) representation is calculated 34 | :return: result (list of dict) 35 | ''' 36 | return [precook(ref, n) for ref in refs] 37 | 38 | def cook_test(test, n=4): 39 | '''Takes a test sentence and returns an object that 40 | encapsulates everything that BLEU needs to know about it. 41 | :param test: list of string : hypothesis sentence for some image 42 | :param n: int : number of ngrams for which (ngram) representation is calculated 43 | :return: result (dict) 44 | ''' 45 | return precook(test, n, True) 46 | 47 | class CiderScorer(object): 48 | """CIDEr scorer. 49 | """ 50 | 51 | def copy(self): 52 | ''' copy the refs.''' 53 | new = CiderScorer(n=self.n) 54 | new.ctest = copy.copy(self.ctest) 55 | new.crefs = copy.copy(self.crefs) 56 | return new 57 | 58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 59 | ''' singular instance ''' 60 | self.n = n 61 | self.sigma = sigma 62 | self.crefs = [] 63 | self.ctest = [] 64 | self.document_frequency = defaultdict(float) 65 | self.cook_append(test, refs) 66 | self.ref_len = None 67 | 68 | def cook_append(self, test, refs): 69 | '''called by constructor and __iadd__ to avoid creating new instances.''' 70 | 71 | if refs is not None: 72 | self.crefs.append(cook_refs(refs)) 73 | if test is not None: 74 | self.ctest.append(cook_test(test)) ## N.B.: -1 75 | else: 76 | self.ctest.append(None) # lens of crefs and ctest have to match 77 | 78 | def size(self): 79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest)) 80 | return len(self.crefs) 81 | 82 | def __iadd__(self, other): 83 | '''add an instance (e.g., from another sentence).''' 84 | 85 | if type(other) is tuple: 86 | ## avoid creating new CiderScorer instances 87 | self.cook_append(other[0], other[1]) 88 | else: 89 | self.ctest.extend(other.ctest) 90 | self.crefs.extend(other.crefs) 91 | 92 | return self 93 | def compute_doc_freq(self): 94 | ''' 95 | Compute term frequency for reference data. 96 | This will be used to compute idf (inverse document frequency later) 97 | The term frequency is stored in the object 98 | :return: None 99 | ''' 100 | for refs in self.crefs: 101 | # refs, k ref captions of one image 102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.iteritems()]): 103 | self.document_frequency[ngram] += 1 104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 105 | 106 | def compute_cider(self): 107 | def counts2vec(cnts): 108 | """ 109 | Function maps counts of ngram to vector of tfidf weights. 110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 111 | The n-th entry of array denotes length of n-grams. 
112 | :param cnts: 113 | :return: vec (array of dict), norm (array of float), length (int) 114 | """ 115 | vec = [defaultdict(float) for _ in range(self.n)] 116 | length = 0 117 | norm = [0.0 for _ in range(self.n)] 118 | for (ngram,term_freq) in cnts.iteritems(): 119 | # give word count 1 if it doesn't appear in reference corpus 120 | df = np.log(max(1.0, self.document_frequency[ngram])) 121 | # ngram index 122 | n = len(ngram)-1 123 | # tf (term_freq) * idf (precomputed idf) for n-grams 124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 125 | # compute norm for the vector. the norm will be used for computing similarity 126 | norm[n] += pow(vec[n][ngram], 2) 127 | 128 | if n == 1: 129 | length += term_freq 130 | norm = [np.sqrt(n) for n in norm] 131 | return vec, norm, length 132 | 133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 134 | ''' 135 | Compute the cosine similarity of two vectors. 136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 137 | :param vec_ref: array of dictionary for vector corresponding to reference 138 | :param norm_hyp: array of float for vector corresponding to hypothesis 139 | :param norm_ref: array of float for vector corresponding to reference 140 | :param length_hyp: int containing length of hypothesis 141 | :param length_ref: int containing length of reference 142 | :return: array of score for each n-grams cosine similarity 143 | ''' 144 | delta = float(length_hyp - length_ref) 145 | # measure consine similarity 146 | val = np.array([0.0 for _ in range(self.n)]) 147 | for n in range(self.n): 148 | # ngram 149 | for (ngram,count) in vec_hyp[n].iteritems(): 150 | # vrama91 : added clipping 151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 152 | 153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 154 | val[n] /= (norm_hyp[n]*norm_ref[n]) 155 | 156 | assert(not math.isnan(val[n])) 157 | # vrama91: added a length based gaussian penalty 158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 159 | return val 160 | 161 | # compute log reference length 162 | self.ref_len = np.log(float(len(self.crefs))) 163 | 164 | scores = [] 165 | for test, refs in zip(self.ctest, self.crefs): 166 | # compute vector for test captions 167 | vec, norm, length = counts2vec(test) 168 | # compute vector for ref captions 169 | score = np.array([0.0 for _ in range(self.n)]) 170 | for ref in refs: 171 | vec_ref, norm_ref, length_ref = counts2vec(ref) 172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 173 | # change by vrama91 - mean of ngram scores, instead of sum 174 | score_avg = np.mean(score) 175 | # divide by number of references 176 | score_avg /= len(refs) 177 | # multiply score by 10 178 | score_avg *= 10.0 179 | # append score of an image to the score list 180 | scores.append(score_avg) 181 | return scores 182 | 183 | def compute_score(self, option=None, verbose=0): 184 | # compute idf 185 | self.compute_doc_freq() 186 | # assert to check document frequency 187 | assert(len(self.ctest) >= max(self.document_frequency.values())) 188 | # compute cider score 189 | score = self.compute_cider() 190 | # debug 191 | # print score 192 | return np.mean(np.array(score)), np.array(score) -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | 
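The CiderScorer above turns each sentence into per-order n-gram tf-idf vectors and scores a hypothesis against every reference with a clipped cosine similarity damped by a Gaussian length penalty (sigma = 6.0), averaged over n-gram orders and references and scaled by 10. Below is a minimal Python 3 sketch of those two ingredients, collapsed into a single flat n-gram dictionary rather than one vector per order; it is an illustration only, not a drop-in replacement for the vendored Python 2 scorer.

import numpy as np
from collections import Counter

def ngram_counts(sentence, n=4):
    """Term frequencies of all 1..n-grams of a whitespace-tokenized sentence (as in precook)."""
    words = sentence.split()
    counts = Counter()
    for k in range(1, n + 1):
        for i in range(len(words) - k + 1):
            counts[tuple(words[i:i + k])] += 1
    return counts

def clipped_cosine_with_length_penalty(hyp_tfidf, ref_tfidf, len_hyp, len_ref, sigma=6.0):
    """Clipped dot product of two tf-idf dicts, normalized, times the Gaussian length penalty."""
    num = sum(min(w, ref_tfidf.get(g, 0.0)) * ref_tfidf.get(g, 0.0) for g, w in hyp_tfidf.items())
    norm_h = np.sqrt(sum(w * w for w in hyp_tfidf.values()))
    norm_r = np.sqrt(sum(w * w for w in ref_tfidf.values()))
    sim = num / (norm_h * norm_r) if norm_h > 0 and norm_r > 0 else 0.0
    delta = float(len_hyp - len_ref)
    return sim * np.exp(-delta ** 2 / (2 * sigma ** 2))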
-------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import threading 10 | 11 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 12 | METEOR_JAR = 'meteor-1.5.jar' 13 | # print METEOR_JAR 14 | 15 | class Meteor: 16 | 17 | def __init__(self): 18 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \ 19 | '-', '-', '-stdio', '-l', 'en', '-norm'] 20 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \ 21 | cwd=os.path.dirname(os.path.abspath(__file__)), \ 22 | stdin=subprocess.PIPE, \ 23 | stdout=subprocess.PIPE, \ 24 | stderr=subprocess.PIPE) 25 | # Used to guarantee thread safety 26 | self.lock = threading.Lock() 27 | 28 | def compute_score(self, gts, res): 29 | assert(gts.keys() == res.keys()) 30 | imgIds = gts.keys() 31 | scores = [] 32 | 33 | eval_line = 'EVAL' 34 | self.lock.acquire() 35 | for i in imgIds: 36 | assert(len(res[i]) == 1) 37 | stat = self._stat(res[i][0], gts[i]) 38 | eval_line += ' ||| {}'.format(stat) 39 | 40 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 41 | for i in range(0,len(imgIds)): 42 | scores.append(float(self.meteor_p.stdout.readline().strip())) 43 | score = float(self.meteor_p.stdout.readline().strip()) 44 | self.lock.release() 45 | 46 | return score, scores 47 | 48 | def method(self): 49 | return "METEOR" 50 | 51 | def _stat(self, hypothesis_str, reference_list): 52 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 53 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 54 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 55 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 56 | return self.meteor_p.stdout.readline().strip() 57 | 58 | def _score(self, hypothesis_str, reference_list): 59 | self.lock.acquire() 60 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 61 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ') 62 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 63 | self.meteor_p.stdin.write('{}\n'.format(score_line)) 64 | stats = self.meteor_p.stdout.readline().strip() 65 | eval_line = 'EVAL ||| {}'.format(stats) 66 | # EVAL ||| stats 67 | self.meteor_p.stdin.write('{}\n'.format(eval_line)) 68 | score = float(self.meteor_p.stdout.readline().strip()) 69 | self.lock.release() 70 | return score 71 | 72 | def __exit__(self): 73 | self.lock.acquire() 74 | self.meteor_p.stdin.close() 75 | self.meteor_p.wait() 76 | self.lock.release() 77 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/readme.txt: -------------------------------------------------------------------------------- 1 | This folder contains modified coco-caption evaluation, which is downloaded from https://github.com/tylin/coco-caption.git 2 | and refEvaluation which is to be called by the refer algorithm. 3 | 4 | More specifically, this folder contains: 5 | 1. bleu/ 6 | 2. cider/ 7 | 3. meteor/ 8 | 4. rouge/ 9 | 5. tokenizer/ 10 | 6. __init__.py 11 | 7. 
refEvaluation.py 12 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/refEvaluation.py: -------------------------------------------------------------------------------- 1 | from tokenizer.ptbtokenizer import PTBTokenizer 2 | from bleu.bleu import Bleu 3 | from meteor.meteor import Meteor 4 | from rouge.rouge import Rouge 5 | from cider.cider import Cider 6 | 7 | """ 8 | Input: refer and Res = [{ref_id, sent}] 9 | 10 | Things of interest 11 | evalRefs - list of ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR'] 12 | eval - dict of {metric: score} 13 | refToEval - dict of {ref_id: ['ref_id', 'CIDEr', 'Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4', 'ROUGE_L', 'METEOR']} 14 | """ 15 | 16 | class RefEvaluation: 17 | def __init__ (self, refer, Res): 18 | """ 19 | :param refer: refer class of current dataset 20 | :param Res: [{'ref_id', 'sent'}] 21 | """ 22 | self.evalRefs = [] 23 | self.eval = {} 24 | self.refToEval = {} 25 | self.refer = refer 26 | self.Res = Res 27 | 28 | def evaluate(self): 29 | 30 | evalRefIds = [ann['ref_id'] for ann in self.Res] 31 | 32 | refToGts = {} 33 | for ref_id in evalRefIds: 34 | ref = self.refer.Refs[ref_id] 35 | gt_sents = [sent['sent'].encode('ascii', 'ignore').decode('ascii') for sent in ref['sentences']] # up to 3 expressions 36 | refToGts[ref_id] = gt_sents 37 | refToRes = {ann['ref_id']: [ann['sent']] for ann in self.Res} 38 | 39 | print 'tokenization...' 40 | tokenizer = PTBTokenizer() 41 | self.refToRes = tokenizer.tokenize(refToRes) 42 | self.refToGts = tokenizer.tokenize(refToGts) 43 | 44 | # ================================================= 45 | # Set up scorers 46 | # ================================================= 47 | print 'setting up scorers...' 
48 | scorers = [ 49 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]), 50 | (Meteor(),"METEOR"), 51 | (Rouge(), "ROUGE_L"), 52 | (Cider(), "CIDEr") 53 | ] 54 | 55 | # ================================================= 56 | # Compute scores 57 | # ================================================= 58 | for scorer, method in scorers: 59 | print 'computing %s score...'%(scorer.method()) 60 | score, scores = scorer.compute_score(self.refToGts, self.refToRes) 61 | if type(method) == list: 62 | for sc, scs, m in zip(score, scores, method): 63 | self.setEval(sc, m) 64 | self.setRefToEvalRefs(scs, self.refToGts.keys(), m) 65 | print "%s: %0.3f"%(m, sc) 66 | else: 67 | self.setEval(score, method) 68 | self.setRefToEvalRefs(scores, self.refToGts.keys(), method) 69 | print "%s: %0.3f"%(method, score) 70 | self.setEvalRefs() 71 | 72 | def setEval(self, score, method): 73 | self.eval[method] = score 74 | 75 | def setRefToEvalRefs(self, scores, refIds, method): 76 | for refId, score in zip(refIds, scores): 77 | if not refId in self.refToEval: 78 | self.refToEval[refId] = {} 79 | self.refToEval[refId]["ref_id"] = refId 80 | self.refToEval[refId][method] = score 81 | 82 | def setEvalRefs(self): 83 | self.evalRefs = [eval for refId, eval in self.refToEval.items()] 84 | 85 | 86 | if __name__ == '__main__': 87 | 88 | import os.path as osp 89 | import sys 90 | ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) 91 | sys.path.insert(0, osp.join(ROOT_DIR, 'lib', 'datasets')) 92 | from refer import REFER 93 | 94 | # load refer of dataset 95 | dataset = 'refcoco' 96 | refer = REFER(dataset, splitBy = 'google') 97 | 98 | # mimic some Res 99 | val_refIds = refer.getRefIds(split='test') 100 | ref_id = 49767 101 | print "GD: %s" % refer.Refs[ref_id]['sentences'] 102 | Res = [{'ref_id': ref_id, 'sent': 'left bottle'}] 103 | 104 | # evaluate some refer expressions 105 | refEval = RefEvaluation(refer, Res) 106 | refEval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in refEval.eval.items(): 110 | print '%s: %.3f'%(metric, score) 111 | 112 | # demo how to use evalImgs to retrieve low score result 113 | # evals = [eva for eva in refEval.evalRefs if eva['CIDEr']<30] 114 | # print 'ground truth sents' 115 | # refId = evals[0]['ref_id'] 116 | # print 'refId: %s' % refId 117 | # print [sent['sent'] for sent in refer.Refs[refId]['sentences']] 118 | # 119 | # print 'generated sent (CIDEr score %0.1f)' % (evals[0]['CIDEr']) 120 | 121 | # print refEval.refToEval[8] 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'vrama91' 2 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | import pdb 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 
17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | 20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 21 | """ 22 | if(len(string)< len(sub)): 23 | sub, string = string, sub 24 | 25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)] 26 | 27 | for j in range(1,len(sub)+1): 28 | for i in range(1,len(string)+1): 29 | if(string[i-1] == sub[j-1]): 30 | lengths[i][j] = lengths[i-1][j-1] + 1 31 | else: 32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1]) 33 | 34 | return lengths[len(string)][len(sub)] 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | 40 | ''' 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert(len(candidate)==1) 53 | assert(len(refs)>0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs/float(len(token_c))) 66 | rec.append(lcs/float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if(prec_max!=0 and rec_max !=0): 72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert(gts.keys() == res.keys()) 86 | imgIds = gts.keys() 87 | 88 | score = [] 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | score.append(self.calc_score(hypo, ref)) 94 | 95 | # Sanity check. 
96 | assert(type(hypo) is list) 97 | assert(len(hypo) == 1) 98 | assert(type(ref) is list) 99 | assert(len(ref) > 0) 100 | 101 | average_score = np.mean(np.array(score)) 102 | return average_score, np.array(score) 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'hfang' 2 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/tokenizer/ptbtokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : ptbtokenizer.py 4 | # 5 | # Description : Do the PTB Tokenization and remove punctuations. 6 | # 7 | # Creation Date : 29-12-2014 8 | # Last Modified : Thu Mar 19 09:53:35 2015 9 | # Authors : Hao Fang and Tsung-Yi Lin 10 | 11 | import os 12 | import sys 13 | import subprocess 14 | import tempfile 15 | import itertools 16 | 17 | # path to the stanford corenlp jar 18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar' 19 | 20 | # punctuations to be removed from the sentences 21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \ 22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"] 23 | 24 | class PTBTokenizer: 25 | """Python wrapper of Stanford PTBTokenizer""" 26 | 27 | def tokenize(self, captions_for_image): 28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \ 29 | 'edu.stanford.nlp.process.PTBTokenizer', \ 30 | '-preserveLines', '-lowerCase'] 31 | 32 | # ====================================================== 33 | # prepare data for PTB Tokenizer 34 | # ====================================================== 35 | final_tokenized_captions_for_image = {} 36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))] 37 | sentences = '\n'.join([c.replace('\n', ' ') for k, v in captions_for_image.items() for c in v]) 38 | 39 | # ====================================================== 40 | # save sentences to temporary file 41 | # ====================================================== 42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__)) 43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname) 44 | tmp_file.write(sentences) 45 | tmp_file.close() 46 | 47 | # ====================================================== 48 | # tokenize sentence 49 | # ====================================================== 50 | cmd.append(os.path.basename(tmp_file.name)) 51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \ 52 | stdout=subprocess.PIPE) 53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0] 54 | lines = token_lines.split('\n') 55 | # remove temp file 56 | os.remove(tmp_file.name) 57 | 58 | # ====================================================== 59 | # create dictionary for tokenized captions 60 | # ====================================================== 61 | for k, line in zip(image_id, lines): 62 | if not k in final_tokenized_captions_for_image: 63 | final_tokenized_captions_for_image[k] = [] 64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \ 65 | if w not in PUNCTUATIONS]) 66 | final_tokenized_captions_for_image[k].append(tokenized_caption) 67 | 68 | return final_tokenized_captions_for_image 69 | 
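The Rouge class above scores a candidate against its references with an LCS-based F-measure (beta = 1.2), taking the maximum precision and recall over the references. A minimal Python 3 sketch of that computation for a single candidate/reference pair follows; it is an illustration only, independent of the vendored Python 2 modules.

def lcs_len(a, b):
    """Length of the longest common subsequence of two token lists (simple DP, as in my_lcs)."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i in range(1, len(a) + 1):
        for j in range(1, len(b) + 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if a[i - 1] == b[j - 1] else max(dp[i - 1][j], dp[i][j - 1])
    return dp[len(a)][len(b)]

def rouge_l(candidate, reference, beta=1.2):
    """ROUGE-L F-measure for one candidate and one reference, mirroring Rouge.calc_score."""
    cand, ref = candidate.split(), reference.split()
    lcs = lcs_len(cand, ref)
    prec, rec = lcs / len(cand), lcs / len(ref)
    if prec == 0 or rec == 0:
        return 0.0
    return ((1 + beta ** 2) * prec * rec) / (rec + beta ** 2 * prec)

# rouge_l("the cat sat on the mat", "the cat is on the mat")  ->  LCS = 5, prec = rec = 5/6, F ~ 0.83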
-------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fgirbal/segment-select-correct/0d36714c1166a598051e467b0bddc5ec23a67d6a/ssc_ris/refer_dataset/refer/evaluation/tokenizer/stanford-corenlp-3.4.1.jar -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/external/README.md: -------------------------------------------------------------------------------- 1 | The codes inside this folder are copied from pycocotools: https://github.com/pdollar/coco -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/external/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/external/_mask.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # distutils: sources = external/maskApi.c 3 | 4 | #************************************************************************** 5 | # Microsoft COCO Toolbox. version 2.0 6 | # Data, paper, and tutorials available at: http://mscoco.org/ 7 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 8 | # Licensed under the Simplified BSD License [see coco/license.txt] 9 | #************************************************************************** 10 | 11 | __author__ = 'tsungyi' 12 | 13 | # import both Python-level and C-level symbols of Numpy 14 | # the API uses Numpy to interface C and Python 15 | import numpy as np 16 | cimport numpy as np 17 | from libc.stdlib cimport malloc, free 18 | 19 | # intialized Numpy. must do. 
20 | np.import_array() 21 | 22 | # import numpy C function 23 | # we use PyArray_ENABLEFLAGS to make Numpy ndarray responsible to memoery management 24 | cdef extern from "numpy/arrayobject.h": 25 | void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) 26 | 27 | # Declare the prototype of the C functions in MaskApi.h 28 | cdef extern from "maskApi.h": 29 | ctypedef unsigned int uint 30 | ctypedef unsigned long siz 31 | ctypedef unsigned char byte 32 | ctypedef double* BB 33 | ctypedef struct RLE: 34 | siz h, 35 | siz w, 36 | siz m, 37 | uint* cnts, 38 | void rlesInit( RLE **R, siz n ) 39 | void rleEncode( RLE *R, const byte *M, siz h, siz w, siz n ) 40 | void rleDecode( const RLE *R, byte *mask, siz n ) 41 | void rleMerge( const RLE *R, RLE *M, siz n, bint intersect ) 42 | void rleArea( const RLE *R, siz n, uint *a ) 43 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) 44 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) 45 | void rleToBbox( const RLE *R, BB bb, siz n ) 46 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ) 47 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ) 48 | char* rleToString( const RLE *R ) 49 | void rleFrString( RLE *R, char *s, siz h, siz w ) 50 | 51 | # python class to wrap RLE array in C 52 | # the class handles the memory allocation and deallocation 53 | cdef class RLEs: 54 | cdef RLE *_R 55 | cdef siz _n 56 | 57 | def __cinit__(self, siz n =0): 58 | rlesInit(&self._R, n) 59 | self._n = n 60 | 61 | # free the RLE array here 62 | def __dealloc__(self): 63 | if self._R is not NULL: 64 | for i in range(self._n): 65 | free(self._R[i].cnts) 66 | free(self._R) 67 | def __getattr__(self, key): 68 | if key == 'n': 69 | return self._n 70 | raise AttributeError(key) 71 | 72 | # python class to wrap Mask array in C 73 | # the class handles the memory allocation and deallocation 74 | cdef class Masks: 75 | cdef byte *_mask 76 | cdef siz _h 77 | cdef siz _w 78 | cdef siz _n 79 | 80 | def __cinit__(self, h, w, n): 81 | self._mask = malloc(h*w*n* sizeof(byte)) 82 | self._h = h 83 | self._w = w 84 | self._n = n 85 | # def __dealloc__(self): 86 | # the memory management of _mask has been passed to np.ndarray 87 | # it doesn't need to be freed here 88 | 89 | # called when passing into np.array() and return an np.ndarray in column-major order 90 | def __array__(self): 91 | cdef np.npy_intp shape[1] 92 | shape[0] = self._h*self._w*self._n 93 | # Create a 1D array, and reshape it to fortran/Matlab column-major array 94 | ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT8, self._mask).reshape((self._h, self._w, self._n), order='F') 95 | # The _mask allocated by Masks is now handled by ndarray 96 | PyArray_ENABLEFLAGS(ndarray, np.NPY_OWNDATA) 97 | return ndarray 98 | 99 | # internal conversion from Python RLEs object to compressed RLE format 100 | def _toString(RLEs Rs): 101 | cdef siz n = Rs.n 102 | cdef bytes py_string 103 | cdef char* c_string 104 | objs = [] 105 | for i in range(n): 106 | c_string = rleToString( &Rs._R[i] ) 107 | py_string = c_string 108 | objs.append({ 109 | 'size': [Rs._R[i].h, Rs._R[i].w], 110 | 'counts': py_string 111 | }) 112 | free(c_string) 113 | return objs 114 | 115 | # internal conversion from compressed RLE format to Python RLEs object 116 | def _frString(rleObjs): 117 | cdef siz n = len(rleObjs) 118 | Rs = RLEs(n) 119 | cdef bytes py_string 120 | cdef char* c_string 121 | for i, obj in enumerate(rleObjs): 122 | py_string = str(obj['counts']) 123 | c_string = py_string 124 | 
rleFrString( &Rs._R[i], c_string, obj['size'][0], obj['size'][1] ) 125 | return Rs 126 | 127 | # encode mask to RLEs objects 128 | # list of RLE string can be generated by RLEs member function 129 | def encode(np.ndarray[np.uint8_t, ndim=3, mode='fortran'] mask): 130 | h, w, n = mask.shape[0], mask.shape[1], mask.shape[2] 131 | cdef RLEs Rs = RLEs(n) 132 | rleEncode(Rs._R,mask.data,h,w,n) 133 | objs = _toString(Rs) 134 | return objs 135 | 136 | # decode mask from compressed list of RLE string or RLEs object 137 | def decode(rleObjs): 138 | cdef RLEs Rs = _frString(rleObjs) 139 | h, w, n = Rs._R[0].h, Rs._R[0].w, Rs._n 140 | masks = Masks(h, w, n) 141 | rleDecode( Rs._R, masks._mask, n ); 142 | return np.array(masks) 143 | 144 | def merge(rleObjs, bint intersect=0): 145 | cdef RLEs Rs = _frString(rleObjs) 146 | cdef RLEs R = RLEs(1) 147 | rleMerge(Rs._R, R._R, Rs._n, intersect) 148 | obj = _toString(R)[0] 149 | return obj 150 | 151 | def area(rleObjs): 152 | cdef RLEs Rs = _frString(rleObjs) 153 | cdef uint* _a = malloc(Rs._n* sizeof(uint)) 154 | rleArea(Rs._R, Rs._n, _a) 155 | cdef np.npy_intp shape[1] 156 | shape[0] = Rs._n 157 | a = np.array((Rs._n, ), dtype=np.uint8) 158 | a = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT32, _a) 159 | PyArray_ENABLEFLAGS(a, np.NPY_OWNDATA) 160 | return a 161 | 162 | # iou computation. support function overload (RLEs-RLEs and bbox-bbox). 163 | def iou( dt, gt, pyiscrowd ): 164 | def _preproc(objs): 165 | if len(objs) == 0: 166 | return objs 167 | if type(objs) == np.ndarray: 168 | if len(objs.shape) == 1: 169 | objs = objs.reshape((objs[0], 1)) 170 | # check if it's Nx4 bbox 171 | if not len(objs.shape) == 2 or not objs.shape[1] == 4: 172 | raise Exception('numpy ndarray input is only for *bounding boxes* and should have Nx4 dimension') 173 | objs = objs.astype(np.double) 174 | elif type(objs) == list: 175 | # check if list is in box format and convert it to np.ndarray 176 | isbox = np.all(np.array([(len(obj)==4) and ((type(obj)==list) or (type(obj)==np.ndarray)) for obj in objs])) 177 | isrle = np.all(np.array([type(obj) == dict for obj in objs])) 178 | if isbox: 179 | objs = np.array(objs, dtype=np.double) 180 | if len(objs.shape) == 1: 181 | objs = objs.reshape((1,objs.shape[0])) 182 | elif isrle: 183 | objs = _frString(objs) 184 | else: 185 | raise Exception('list input can be bounding box (Nx4) or RLEs ([RLE])') 186 | else: 187 | raise Exception('unrecognized type. 
The following type: RLEs (rle), np.ndarray (box), and list (box) are supported.') 188 | return objs 189 | def _rleIou(RLEs dt, RLEs gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 190 | rleIou( dt._R, gt._R, m, n, iscrowd.data, _iou.data ) 191 | def _bbIou(np.ndarray[np.double_t, ndim=2] dt, np.ndarray[np.double_t, ndim=2] gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 192 | bbIou( dt.data, gt.data, m, n, iscrowd.data, _iou.data ) 193 | def _len(obj): 194 | cdef siz N = 0 195 | if type(obj) == RLEs: 196 | N = obj.n 197 | elif len(obj)==0: 198 | pass 199 | elif type(obj) == np.ndarray: 200 | N = obj.shape[0] 201 | return N 202 | # convert iscrowd to numpy array 203 | cdef np.ndarray[np.uint8_t, ndim=1] iscrowd = np.array(pyiscrowd, dtype=np.uint8) 204 | # simple type checking 205 | cdef siz m, n 206 | dt = _preproc(dt) 207 | gt = _preproc(gt) 208 | m = _len(dt) 209 | n = _len(gt) 210 | if m == 0 or n == 0: 211 | return [] 212 | if not type(dt) == type(gt): 213 | raise Exception('The dt and gt should have the same data type, either RLEs, list or np.ndarray') 214 | 215 | # define local variables 216 | cdef double* _iou = 0 217 | cdef np.npy_intp shape[1] 218 | # check type and assign iou function 219 | if type(dt) == RLEs: 220 | _iouFun = _rleIou 221 | elif type(dt) == np.ndarray: 222 | _iouFun = _bbIou 223 | else: 224 | raise Exception('input data type not allowed.') 225 | _iou = malloc(m*n* sizeof(double)) 226 | iou = np.zeros((m*n, ), dtype=np.double) 227 | shape[0] = m*n 228 | iou = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _iou) 229 | PyArray_ENABLEFLAGS(iou, np.NPY_OWNDATA) 230 | _iouFun(dt, gt, iscrowd, m, n, iou) 231 | return iou.reshape((m,n), order='F') 232 | 233 | def toBbox( rleObjs ): 234 | cdef RLEs Rs = _frString(rleObjs) 235 | cdef siz n = Rs.n 236 | cdef BB _bb = malloc(4*n* sizeof(double)) 237 | rleToBbox( Rs._R, _bb, n ) 238 | cdef np.npy_intp shape[1] 239 | shape[0] = 4*n 240 | bb = np.array((1,4*n), dtype=np.double) 241 | bb = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _bb).reshape((n, 4)) 242 | PyArray_ENABLEFLAGS(bb, np.NPY_OWNDATA) 243 | return bb 244 | 245 | def frBbox(np.ndarray[np.double_t, ndim=2] bb, siz h, siz w ): 246 | cdef siz n = bb.shape[0] 247 | Rs = RLEs(n) 248 | rleFrBbox( Rs._R, bb.data, h, w, n ) 249 | objs = _toString(Rs) 250 | return objs 251 | 252 | def frPoly( poly, siz h, siz w ): 253 | cdef np.ndarray[np.double_t, ndim=1] np_poly 254 | n = len(poly) 255 | Rs = RLEs(n) 256 | for i, p in enumerate(poly): 257 | np_poly = np.array(p, dtype=np.double, order='F') 258 | rleFrPoly( &Rs._R[i], np_poly.data, len(np_poly)/2, h, w ) 259 | objs = _toString(Rs) 260 | return objs 261 | 262 | def frUncompressedRLE(ucRles, siz h, siz w): 263 | cdef np.ndarray[np.uint32_t, ndim=1] cnts 264 | cdef RLE R 265 | cdef uint *data 266 | n = len(ucRles) 267 | objs = [] 268 | for i in range(n): 269 | Rs = RLEs(1) 270 | cnts = np.array(ucRles[i]['counts'], dtype=np.uint32) 271 | # time for malloc can be saved here but it's fine 272 | data = malloc(len(cnts)* sizeof(uint)) 273 | for j in range(len(cnts)): 274 | data[j] = cnts[j] 275 | R = RLE(ucRles[i]['size'][0], ucRles[i]['size'][1], len(cnts), data) 276 | Rs._R[0] = R 277 | objs.append(_toString(Rs)[0]) 278 | return objs 279 | 280 | def frPyObjects(pyobj, siz h, w): 281 | if type(pyobj) == np.ndarray: 282 | objs = frBbox(pyobj, h, w ) 283 | elif type(pyobj) == list and len(pyobj[0]) == 4: 284 | objs = 
frBbox(pyobj, h, w ) 285 | elif type(pyobj) == list and len(pyobj[0]) > 4: 286 | objs = frPoly(pyobj, h, w ) 287 | elif type(pyobj) == list and type(pyobj[0]) == dict: 288 | objs = frUncompressedRLE(pyobj, h, w) 289 | else: 290 | raise Exception('input type is not supported.') 291 | return objs 292 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/external/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import external._mask as _mask 4 | 5 | # Interface for manipulating masks stored in RLE format. 6 | # 7 | # RLE is a simple yet efficient format for storing binary masks. RLE 8 | # first divides a vector (or vectorized image) into a series of piecewise 9 | # constant regions and then for each piece simply stores the length of 10 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 12 | # (note that the odd counts are always the numbers of zeros). Instead of 13 | # storing the counts directly, additional compression is achieved with a 14 | # variable bitrate representation based on a common scheme called LEB128. 15 | # 16 | # Compression is greatest given large piecewise constant regions. 17 | # Specifically, the size of the RLE is proportional to the number of 18 | # *boundaries* in M (or for an image the number of boundaries in the y 19 | # direction). Assuming fairly simple shapes, the RLE representation is 20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 21 | # is substantially lower, especially for large simple objects (large n). 22 | # 23 | # Many common operations on masks can be computed directly using the RLE 24 | # (without need for decoding). This includes computations such as area, 25 | # union, intersection, etc. All of these operations are linear in the 26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 27 | # of the object. Computing these operations on the original mask is O(n). 28 | # Thus, using the RLE can result in substantial computational savings. 29 | # 30 | # The following API functions are defined: 31 | # encode - Encode binary masks using RLE. 32 | # decode - Decode binary masks encoded via RLE. 33 | # merge - Compute union or intersection of encoded masks. 34 | # iou - Compute intersection over union between masks. 35 | # area - Compute area of encoded masks. 36 | # toBbox - Get bounding boxes surrounding encoded masks. 37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 38 | # 39 | # Usage: 40 | # Rs = encode( masks ) 41 | # masks = decode( Rs ) 42 | # R = merge( Rs, intersect=false ) 43 | # o = iou( dt, gt, iscrowd ) 44 | # a = area( Rs ) 45 | # bbs = toBbox( Rs ) 46 | # Rs = frPyObjects( [pyObjects], h, w ) 47 | # 48 | # In the API the following formats are used: 49 | # Rs - [dict] Run-length encoding of binary masks 50 | # R - dict Run-length encoding of binary mask 51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 53 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 55 | # dt,gt - May be either bounding boxes or encoded masks 56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 
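# A minimal standalone sketch of the uncompressed run-length convention described
# above (an illustration only, independent of the compiled _mask extension):
# counts alternate runs of zeros and ones and always start with the zero-run,
# so a mask that begins with a 1 gets a leading count of 0.
def _toy_rle_counts(m):
    """Run lengths of a flat binary vector, starting with the count of leading zeros."""
    counts, prev, run = [], 0, 0
    for v in list(m) + [None]:  # the sentinel flushes the final run
        if v == prev:
            run += 1
        else:
            counts.append(run)
            prev, run = v, 1
    return counts
# _toy_rle_counts([0, 0, 1, 1, 1, 0, 1])  ->  [2, 3, 1, 1]
# _toy_rle_counts([1, 1, 1, 1, 1, 1, 0])  ->  [0, 6, 1]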
57 | # 58 | # Finally, a note about the intersection over union (iou) computation. 59 | # The standard iou of a ground truth (gt) and detected (dt) object is 60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 61 | # For "crowd" regions, we use a modified criteria. If a gt object is 62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 63 | # Choosing gt' in the crowd gt that best matches the dt can be done using 64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 66 | # For crowd gt regions we use this modified criteria above for the iou. 67 | # 68 | # To compile run "python setup.py build_ext --inplace" 69 | # Please do not contact us for help with compiling. 70 | # 71 | # Microsoft COCO Toolbox. version 2.0 72 | # Data, paper, and tutorials available at: http://mscoco.org/ 73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 74 | # Licensed under the Simplified BSD License [see coco/license.txt] 75 | 76 | encode = _mask.encode 77 | decode = _mask.decode 78 | iou = _mask.iou 79 | merge = _mask.merge 80 | area = _mask.area 81 | toBbox = _mask.toBbox 82 | frPyObjects = _mask.frPyObjects -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/external/maskApi.c: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #include "maskApi.h" 8 | #include 9 | #include 10 | 11 | uint umin( uint a, uint b ) { return (ab) ? a : b; } 13 | 14 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ) { 15 | R->h=h; R->w=w; R->m=m; R->cnts=(m==0)?0:malloc(sizeof(uint)*m); 16 | siz j; if(cnts) for(j=0; jcnts[j]=cnts[j]; 17 | } 18 | 19 | void rleFree( RLE *R ) { 20 | free(R->cnts); R->cnts=0; 21 | } 22 | 23 | void rlesInit( RLE **R, siz n ) { 24 | siz i; *R = (RLE*) malloc(sizeof(RLE)*n); 25 | for(i=0; i0 ) { 61 | c=umin(ca,cb); cc+=c; ct=0; 62 | ca-=c; if(!ca && a0) { 83 | crowd=iscrowd!=NULL && iscrowd[g]; 84 | if(dt[d].h!=gt[g].h || dt[d].w!=gt[g].w) { o[g*m+d]=-1; continue; } 85 | siz ka, kb, a, b; uint c, ca, cb, ct, i, u; int va, vb; 86 | ca=dt[d].cnts[0]; ka=dt[d].m; va=vb=0; 87 | cb=gt[g].cnts[0]; kb=gt[g].m; a=b=1; i=u=0; ct=1; 88 | while( ct>0 ) { 89 | c=umin(ca,cb); if(va||vb) { u+=c; if(va&&vb) i+=c; } ct=0; 90 | ca-=c; if(!ca && athr) keep[j]=0; 105 | } 106 | } 107 | } 108 | 109 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) { 110 | double h, w, i, u, ga, da; siz g, d; int crowd; 111 | for( g=0; gthr) keep[j]=0; 129 | } 130 | } 131 | } 132 | 133 | void rleToBbox( const RLE *R, BB bb, siz n ) { 134 | siz i; for( i=0; id?1:c=dy && xs>xe) || (dxye); 173 | if(flip) { t=xs; xs=xe; xe=t; t=ys; ys=ye; ye=t; } 174 | s = dx>=dy ? 
(double)(ye-ys)/dx : (double)(xe-xs)/dy; 175 | if(dx>=dy) for( d=0; d<=dx; d++ ) { 176 | t=flip?dx-d:d; u[m]=t+xs; v[m]=(int)(ys+s*t+.5); m++; 177 | } else for( d=0; d<=dy; d++ ) { 178 | t=flip?dy-d:d; v[m]=t+ys; u[m]=(int)(xs+s*t+.5); m++; 179 | } 180 | } 181 | /* get points along y-boundary and downsample */ 182 | free(x); free(y); k=m; m=0; double xd, yd; 183 | x=malloc(sizeof(int)*k); y=malloc(sizeof(int)*k); 184 | for( j=1; jw-1 ) continue; 187 | yd=(double)(v[j]h) yd=h; yd=ceil(yd); 189 | x[m]=(int) xd; y[m]=(int) yd; m++; 190 | } 191 | /* compute rle encoding given y-boundary points */ 192 | k=m; a=malloc(sizeof(uint)*(k+1)); 193 | for( j=0; j0) b[m++]=a[j++]; else { 199 | j++; if(jm, p=0; long x; int more; 206 | char *s=malloc(sizeof(char)*m*6); 207 | for( i=0; icnts[i]; if(i>2) x-=(long) R->cnts[i-2]; more=1; 209 | while( more ) { 210 | char c=x & 0x1f; x >>= 5; more=(c & 0x10) ? x!=-1 : x!=0; 211 | if(more) c |= 0x20; c+=48; s[p++]=c; 212 | } 213 | } 214 | s[p]=0; return s; 215 | } 216 | 217 | void rleFrString( RLE *R, char *s, siz h, siz w ) { 218 | siz m=0, p=0, k; long x; int more; uint *cnts; 219 | while( s[m] ) m++; cnts=malloc(sizeof(uint)*m); m=0; 220 | while( s[p] ) { 221 | x=0; k=0; more=1; 222 | while( more ) { 223 | char c=s[p]-48; x |= (c & 0x1f) << 5*k; 224 | more = c & 0x20; p++; k++; 225 | if(!more && (c & 0x10)) x |= -1 << 5*k; 226 | } 227 | if(m>2) x+=(long) cnts[m-2]; cnts[m++]=(uint) x; 228 | } 229 | rleInit(R,h,w,m,cnts); free(cnts); 230 | } 231 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/external/maskApi.h: -------------------------------------------------------------------------------- 1 | /************************************************************************** 2 | * Microsoft COCO Toolbox. version 2.0 3 | * Data, paper, and tutorials available at: http://mscoco.org/ 4 | * Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 5 | * Licensed under the Simplified BSD License [see coco/license.txt] 6 | **************************************************************************/ 7 | #pragma once 8 | 9 | typedef unsigned int uint; 10 | typedef unsigned long siz; 11 | typedef unsigned char byte; 12 | typedef double* BB; 13 | typedef struct { siz h, w, m; uint *cnts; } RLE; 14 | 15 | /* Initialize/destroy RLE. */ 16 | void rleInit( RLE *R, siz h, siz w, siz m, uint *cnts ); 17 | void rleFree( RLE *R ); 18 | 19 | /* Initialize/destroy RLE array. */ 20 | void rlesInit( RLE **R, siz n ); 21 | void rlesFree( RLE **R, siz n ); 22 | 23 | /* Encode binary masks using RLE. */ 24 | void rleEncode( RLE *R, const byte *mask, siz h, siz w, siz n ); 25 | 26 | /* Decode binary masks encoded via RLE. */ 27 | void rleDecode( const RLE *R, byte *mask, siz n ); 28 | 29 | /* Compute union or intersection of encoded masks. */ 30 | void rleMerge( const RLE *R, RLE *M, siz n, int intersect ); 31 | 32 | /* Compute area of encoded masks. */ 33 | void rleArea( const RLE *R, siz n, uint *a ); 34 | 35 | /* Compute intersection over union between masks. */ 36 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ); 37 | 38 | /* Compute non-maximum suppression between bounding masks */ 39 | void rleNms( RLE *dt, siz n, uint *keep, double thr ); 40 | 41 | /* Compute intersection over union between bounding boxes. 
*/ 42 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ); 43 | 44 | /* Compute non-maximum suppression between bounding boxes */ 45 | void bbNms( BB dt, siz n, uint *keep, double thr ); 46 | 47 | /* Get bounding boxes surrounding encoded masks. */ 48 | void rleToBbox( const RLE *R, BB bb, siz n ); 49 | 50 | /* Convert bounding boxes to encoded masks. */ 51 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ); 52 | 53 | /* Convert polygon to encoded mask. */ 54 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ); 55 | 56 | /* Get compressed string representation of encoded mask. */ 57 | char* rleToString( const RLE *R ); 58 | 59 | /* Convert from compressed string representation of encoded mask. */ 60 | void rleFrString( RLE *R, char *s, siz h, siz w ); 61 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/refer/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from Cython.Build import cythonize 3 | from distutils.extension import Extension 4 | import numpy as np 5 | 6 | ext_modules = [ 7 | Extension( 8 | 'external._mask', 9 | sources=['external/maskApi.c', 'external/_mask.pyx'], 10 | include_dirs = [np.get_include(), 'external'], 11 | extra_compile_args=['-Wno-cpp', '-Wno-unused-function', '-std=c99'], 12 | ) 13 | ] 14 | 15 | setup( 16 | name='external', 17 | packages=['external'], 18 | package_dir = {'external': 'external'}, 19 | version='2.0', 20 | ext_modules=cythonize(ext_modules) 21 | ) 22 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/unsupervised_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data as data 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from .bert.tokenization_bert import BertTokenizer 8 | from .refer.refer import REFER 9 | 10 | 11 | class UnsupervisedReferDataset(data.Dataset): 12 | def __init__( 13 | self, 14 | dataset: str, 15 | splitBy: str = 'unc', 16 | refer_data_root: str = 'dataset/refer', 17 | pseudo_masks_data_root='outputs', 18 | image_transforms: object = None, 19 | target_transforms: object = None, 20 | split: str = 'train', 21 | bert_tokenizer: str = 'bert-base-uncased', 22 | COCO_image_root: str = None, 23 | eval_mode: bool = False, 24 | groundtruth_masks=False, 25 | one_sentence=False 26 | ): 27 | """Wrapper for the REFER class in refer (https://github.com/lichengunc/refer) for the case of unsupervised data 28 | (potentially including multiple masks per image). 29 | 30 | Args: 31 | dataset (str): dataset description (one of "refcoco", "refcoco+" or "refcocog") 32 | splitBy (str, optional): data split. Defaults to 'unc'. 33 | refer_data_root (str, optional): root of the REFER data. Defaults to 'dataset/refer'. 34 | pseudo_masks_data_root (str, optional): file location where the unsupervised masks are located. Defaults to 'outputs'. 35 | image_transforms (object, optional): transformations to be applied to input images. Defaults to None. 36 | target_transforms (object, optional): transformations to be applied to target masks. Defaults to None. 37 | split (str, optional): dataset split. Defaults to 'train'. 38 | bert_tokenizer (str, optional): type of BERT tokenizer. Defaults to 'bert-base-uncased'. 39 | COCO_image_root (str, optional): image root. Defaults to None. 
40 | eval_mode (bool, optional): if True, load all sentences, otherwise chose only one. Defaults to False. 41 | return_attributes (bool, optional): whether additional data attributes should be returned. Defaults to False. 42 | """ 43 | 44 | self.classes = [] 45 | self.image_transforms = image_transforms 46 | self.target_transform = target_transforms 47 | self.split = split 48 | self.refer = REFER(refer_data_root, dataset, splitBy, image_root=COCO_image_root) 49 | self.pseudo_masks_data_root = os.path.join(pseudo_masks_data_root, dataset) 50 | if dataset == "refcocog": 51 | self.pseudo_masks_data_root += f"_{splitBy}" 52 | 53 | self.groundtruth_masks = groundtruth_masks 54 | self.one_sentence = one_sentence 55 | 56 | if groundtruth_masks: 57 | print("WARNING: you are using UnsupervisedReferDataset with groundtruth masks, which should only be used for testing...") 58 | 59 | self.max_tokens = 20 60 | 61 | ref_ids = self.refer.getRefIds(split=self.split) 62 | img_ids = self.refer.getImgIds(ref_ids) 63 | 64 | all_imgs = self.refer.Imgs 65 | self.imgs = list(all_imgs[i] for i in img_ids) 66 | self.ref_ids = ref_ids 67 | 68 | self.input_ids = {} 69 | self.attention_masks = {} 70 | self.tokenizer = BertTokenizer.from_pretrained(bert_tokenizer) 71 | 72 | self.eval_mode = eval_mode 73 | # if we are testing on a dataset, test all sentences of an object; 74 | # o/w, we are validating during training, randomly sample one sentence for efficiency 75 | for r in ref_ids: 76 | ref = self.refer.Refs[r] 77 | 78 | sentences_for_ref = [] 79 | attentions_for_ref = [] 80 | 81 | for i, (el, sent_id) in enumerate(zip(ref['sentences'], ref['sent_ids'])): 82 | sentence_raw = el['sent'] 83 | attention_mask = [0] * self.max_tokens 84 | padded_input_ids = [0] * self.max_tokens 85 | 86 | input_ids = self.tokenizer.encode(text=sentence_raw, add_special_tokens=True) 87 | 88 | # truncation of tokens 89 | input_ids = input_ids[:self.max_tokens] 90 | 91 | padded_input_ids[:len(input_ids)] = input_ids 92 | attention_mask[:len(input_ids)] = [1]*len(input_ids) 93 | 94 | sentences_for_ref.append(torch.tensor(padded_input_ids).unsqueeze(0)) 95 | attentions_for_ref.append(torch.tensor(attention_mask).unsqueeze(0)) 96 | 97 | self.input_ids[ref['ref_id']] = torch.vstack(sentences_for_ref) 98 | self.attention_masks[ref['ref_id']] = torch.vstack(attentions_for_ref) 99 | 100 | def get_classes(self): 101 | return self.classes 102 | 103 | def __len__(self): 104 | return len(self.imgs) 105 | 106 | def __getitem__(self, index): 107 | this_img = self.imgs[index] 108 | this_img_id = this_img['id'] 109 | 110 | # get the ids of the references corresponding to this image 111 | possible_refs = self.refer.imgToRefs[this_img_id] 112 | refs = [ref for ref in possible_refs if ref['ref_id'] in self.ref_ids] 113 | PIL_img = Image.open(os.path.join(self.refer.IMAGE_DIR, this_img['file_name'])).convert("RGB") 114 | if self.image_transforms is None: 115 | img = PIL_img 116 | 117 | target_per_ref = [] 118 | tensor_embeddings_per_ref = [] 119 | attention_mask_per_ref = [] 120 | attributes_per_ref = [] 121 | for ref in refs: 122 | ref_id = ref['ref_id'] 123 | 124 | if self.groundtruth_masks: 125 | # groundtruth masks 126 | ref_mask = np.array(self.refer.getMask(ref)['mask']) 127 | annot = np.zeros(ref_mask.shape) 128 | annot[ref_mask == 1] = 1 129 | 130 | annot = Image.fromarray(annot.astype(np.uint8), mode="P") 131 | else: 132 | # pseudo-gt masks 133 | annot = Image.open(os.path.join(self.pseudo_masks_data_root, f"pseudo_gt_mask_{ref_id}.png")).convert("P") 
134 | 135 | target = annot 136 | 137 | if self.image_transforms is not None: 138 | # resize, from PIL to tensor, and mean and std normalization 139 | img, target = self.image_transforms(PIL_img, annot) 140 | 141 | # embedding = [] 142 | # att = [] 143 | # for e, a in zip(self.input_ids[ref_id], self.attention_masks[ref_id]): 144 | # embedding.append(e.unsqueeze(-1)) 145 | # att.append(a.unsqueeze(-1)) 146 | 147 | if not self.one_sentence: 148 | tensor_embeddings = self.input_ids[ref_id] 149 | attention_mask = self.attention_masks[ref_id] 150 | else: 151 | choice_sent = np.random.choice(len(self.input_ids[ref_id])) 152 | tensor_embeddings = self.input_ids[ref_id][choice_sent:choice_sent+1] 153 | attention_mask = self.attention_masks[ref_id][choice_sent:choice_sent+1] 154 | 155 | target_per_ref.append(target) 156 | tensor_embeddings_per_ref.append(tensor_embeddings) 157 | attention_mask_per_ref.append(attention_mask) 158 | attributes_per_ref.append({ 159 | "sentence_ids": [sent["sent_id"] for sent in ref["sentences"]], 160 | "sentences_raw": [sent["raw"] for sent in ref["sentences"]], 161 | "sentences_sent": [sent["sent"] for sent in ref["sentences"]], 162 | "ref_id": ref["ref_id"], 163 | "ann_id": ref["ann_id"], 164 | "ref": ref 165 | }) 166 | 167 | return img, target_per_ref, tensor_embeddings_per_ref, attention_mask_per_ref, attributes_per_ref 168 | 169 | @staticmethod 170 | def collate_fn(data_items): 171 | return tuple(zip(*data_items)) 172 | -------------------------------------------------------------------------------- /ssc_ris/refer_dataset/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Tuple 3 | 4 | import matplotlib.pyplot as plt 5 | import matplotlib.cm as cm 6 | import numpy as np 7 | 8 | from .dataset import ReferDataset 9 | from .unsupervised_dataset import UnsupervisedReferDataset 10 | import ssc_ris.utils.transforms as T 11 | 12 | # --- Fetching functions --- 13 | 14 | def get_dataset( 15 | dataset: str, 16 | dataset_root: str, 17 | data_split: str, 18 | transforms: object = None, 19 | split_by: str = 'unc', 20 | return_attributes: bool = True, 21 | eval_model: bool = False 22 | ) -> Tuple[ReferDataset, int]: 23 | """Obtains a ReferDataset object 24 | 25 | Args: 26 | dataset (str): dataset descriptior 27 | dataset_root (str): root for the dataset 28 | data_split (str): dataset split 29 | transforms (object, optional): transforms to apply to image data. Defaults to None. 30 | split_by (str, optional): RIS dataset splits ('unc', 'umd', 'google'). Defaults to 'unc'. 31 | return_attributes (bool, optional): if True, extra dataset attributes are returned. Defaults to True. 32 | eval_model (bool, optional): if True, all referring sentences are returned. Defaults to False. 
33 | 34 | Returns: 35 | Tuple[ReferDataset, int]: dataset object and number of classes 36 | """ 37 | ds = ReferDataset( 38 | dataset, 39 | refer_data_root=dataset_root, 40 | splitBy=split_by, 41 | split=data_split, 42 | image_transforms=transforms, 43 | target_transforms=None, 44 | return_attributes=return_attributes, 45 | eval_mode=eval_model 46 | ) 47 | num_classes = 2 48 | 49 | return ds, num_classes 50 | 51 | def get_unsupervised_dataset( 52 | dataset: str, 53 | dataset_root: str, 54 | pseudo_masks_root: str, 55 | data_split: str, 56 | transforms: object = None, 57 | split_by: str = 'unc', 58 | one_sentence: bool = True 59 | ) -> Tuple[UnsupervisedReferDataset, int]: 60 | ds = UnsupervisedReferDataset( 61 | dataset, 62 | refer_data_root=dataset_root, 63 | pseudo_masks_data_root=pseudo_masks_root, 64 | splitBy=split_by, 65 | split=data_split, 66 | image_transforms=transforms, 67 | target_transforms=None, 68 | one_sentence=one_sentence 69 | ) 70 | num_classes = 2 71 | 72 | return ds, num_classes 73 | 74 | 75 | def get_transform(img_size: int) -> object: 76 | """Get the dataset transform. 77 | 78 | Args: 79 | img_size (int): size of the image 80 | 81 | Returns: 82 | object: composed transform 83 | """ 84 | transforms = [ 85 | T.Resize(img_size, img_size), 86 | T.ToTensor(), 87 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 88 | ] 89 | 90 | return T.Compose(transforms) 91 | 92 | # --- Saving unsupervised masks functions --- 93 | 94 | def save_binary_object_masks(object_masks: List[np.ndarray], ref_ids: List[str], output_mask_folder: str) -> None: 95 | """Save the list of binary masks. 96 | 97 | Args: 98 | object_masks (List[np.ndarray]): list of masks to save 99 | ref_ids (List[str]): ids of the masks to save 100 | output_mask_folder (str): folder where to save the masks 101 | """ 102 | for mask, ref_id in zip(object_masks, ref_ids): 103 | save_binary_object_mask(mask, ref_id, output_mask_folder) 104 | 105 | def save_binary_object_mask(mask: np.ndarray, ref_id: str, output_mask_folder: str) -> None: 106 | """Save the list of binary masks. 107 | 108 | Args: 109 | object_masks (List[np.ndarray]): list of masks to save 110 | ref_ids (List[str]): ids of the masks to save 111 | output_mask_folder (str): folder where to save the masks 112 | """ 113 | img_name = os.path.join(output_mask_folder, f"pseudo_gt_mask_{ref_id}.png") 114 | plt.imsave(img_name, mask, cmap=cm.gray) 115 | -------------------------------------------------------------------------------- /ssc_ris/segment/__init__.py: -------------------------------------------------------------------------------- 1 | from .segment_fns import * -------------------------------------------------------------------------------- /ssc_ris/segment/_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from typing import List 3 | 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def save_all_masks( 9 | image_path: str, 10 | binary_masks: np.ndarray, 11 | sentences: List[str], 12 | output_name: str = "mask_example.png" 13 | ) -> None: 14 | """Save all masks in a Matplotlib viz for debugging purposes. 15 | 16 | Args: 17 | image_path (str): input image location on file 18 | binary_masks (np.ndarray): array with all the binary masks 19 | sentences (List[str]): list of sentences that generated the binary masks 20 | output_name (str, optional): name of the outputted figure. Defaults to "mask_example.png". 
21 | """ 22 | # 1 subplot for the image itself, 1 for the overall segmentation mask, and 1 per binary segmentation masks 23 | fig, axs = plt.subplots(1, 1+len(binary_masks), figsize=(4*(1+len(binary_masks)), 4)) 24 | 25 | img = Image.open(image_path) 26 | for ax in axs: 27 | ax.imshow(img) 28 | ax.axis('off') 29 | 30 | for i, mask in enumerate(binary_masks): 31 | axs[1+i].imshow(mask, alpha=0.4) 32 | axs[1+i].set_title(sentences[i]) 33 | 34 | plt.tight_layout() 35 | plt.savefig(output_name) 36 | 37 | -------------------------------------------------------------------------------- /ssc_ris/select/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip_sim import * -------------------------------------------------------------------------------- /ssc_ris/select/clip_sim.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from typing import List, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | import clip 7 | 8 | from .visual_prompting_fns import vp_rectangle, vp_red_ellipse, vp_dense_mask, vp_reverse_blur 9 | 10 | 11 | def vp_and_max_sim_clip_choice( 12 | clip_model: torch.nn.Module, 13 | preprocess: torch.nn.Module, 14 | image: Image, 15 | masks: List[np.ndarray], 16 | sentences: List[str], 17 | method: str = "reverse_blur", 18 | thickness: int = 4, 19 | margin: int = 4, 20 | alpha: float = 0.25, 21 | color: Tuple[int, int, int] = (255, 0, 0), 22 | std_dev: int = 50 23 | ) -> np.ndarray: 24 | """Visually prompt and select the mask with the highest CLIP similarity. 25 | 26 | Args: 27 | clip_model (torch.nn.Module): CLIP model 28 | preprocess (torch.nn.Module): CLIP pre-processing module 29 | image (Image): image 30 | masks (List[np.ndarray]): list of candidate masks 31 | sentences (List[str]): sentences to encode and compute similarity to 32 | method (str, optional): visual prompting method. Defaults to "reverse_blur". 33 | thickness (int, optional): "rectangle"/"red_ellipse" prompting parameter. Defaults to 4. 34 | margin (int, optional): "rectangle"/"red_ellipse" prompting parameter. Defaults to 4. 35 | alpha (float, optional): "red_dense_mask" prompting parameter. Defaults to 0.25. 36 | color (Tuple[int, int, int], optional): "red_dense_mask" prompting parameter. Defaults to (255, 0, 0). 37 | std_dev (int, optional): "reverse_blur" prompting parameter. Defaults to 50. 
38 | 39 | Returns: 40 | np.ndarray: the highest similarity mask 41 | """ 42 | if len(masks) == 1: 43 | return masks[0] 44 | 45 | device = next(clip_model.parameters()).device 46 | with torch.no_grad(): 47 | sentences_embedding = clip_model.encode_text(clip.tokenize(sentences).to(device)) 48 | 49 | sentences_embedding /= sentences_embedding.norm(dim=-1, keepdim=True) 50 | 51 | open_cv_image = np.array(image) 52 | open_cv_image = open_cv_image[:, :, ::-1].copy() 53 | 54 | best_similarity = -np.Inf 55 | best_mask = None 56 | for mask in masks: 57 | if mask.sum() < 500: 58 | continue 59 | 60 | if method == "rectangle": 61 | annotated_img = vp_rectangle( 62 | open_cv_image, 63 | mask, 64 | thickness=thickness, 65 | margin=margin 66 | ) 67 | elif method == "red_ellipse": 68 | annotated_img = vp_red_ellipse( 69 | open_cv_image, 70 | mask, 71 | thickness=thickness, 72 | margin=margin 73 | ) 74 | elif method == "red_dense_mask": 75 | annotated_img = vp_dense_mask( 76 | open_cv_image, 77 | mask, 78 | alpha=alpha, 79 | color=color 80 | ) 81 | elif method == "reverse_blur": 82 | annotated_img = vp_reverse_blur( 83 | open_cv_image, 84 | mask, 85 | std_dev=std_dev 86 | ) 87 | else: 88 | raise NotImplemented 89 | 90 | # get the CLIP image embedding 91 | image_input = preprocess(Image.fromarray(annotated_img)).unsqueeze(0).to(device) 92 | with torch.no_grad(): 93 | image_features = clip_model.encode_image(image_input) 94 | 95 | image_features /= image_features.norm(dim=-1, keepdim=True) 96 | similarity = (100.0 * image_features @ sentences_embedding.T).mean() 97 | 98 | if similarity >= best_similarity: 99 | best_mask = mask 100 | best_similarity = similarity 101 | 102 | return best_mask 103 | -------------------------------------------------------------------------------- /ssc_ris/select/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import numpy as np 5 | 6 | 7 | def separate_instance_masks(instance_merged_mask: np.ndarray) -> List[np.ndarray]: 8 | """Given an image with all instance masks, split them into separate binary masks. 9 | 10 | Args: 11 | instance_merged_mask (np.ndarray): input n-dim mask 12 | 13 | Returns: 14 | List[np.ndarray]: binary mask 15 | """ 16 | instance_vals = np.unique(instance_merged_mask) 17 | return [ 18 | (instance_merged_mask == val).astype(np.int32) 19 | for val in instance_vals if val != 0.0 20 | ] 21 | 22 | def separate_instance_masks_torch(instance_merged_mask: torch.Tensor) -> List[torch.Tensor]: 23 | """Given an image with all instance masks, split them into separate binary masks. 
24 | 25 | Args: 26 | instance_merged_mask (torch.Tensor): input n-dim mask 27 | 28 | Returns: 29 | List[torch.Tensor]: binary mask 30 | """ 31 | instance_vals = torch.unique(instance_merged_mask) 32 | return [ 33 | (instance_merged_mask == val).to(torch.int32) 34 | for val in instance_vals if val != 0.0 35 | ] 36 | -------------------------------------------------------------------------------- /ssc_ris/select/visual_prompting_fns.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def vp_rectangle(original_image: np.ndarray, mask: np.ndarray, thickness: int = 4, margin: int = 4) -> np.ndarray: 8 | """Visual prompting using a red rectangle 9 | 10 | Args: 11 | original_image (np.ndarray): image 12 | mask (np.ndarray): dense segmentation mask 13 | thickness (int, optional): thickness of the rectangle. Defaults to 4. 14 | margin (int, optional): margin of the rectangle w.r.t. the bounding box of the mask. Defaults to 4. 15 | 16 | Returns: 17 | np.ndarray: annotated image 18 | """ 19 | mask_indices = np.where(mask == 1) 20 | start_corner = (mask_indices[1].min()-margin, mask_indices[0].min()-margin) 21 | end_corner = (mask_indices[1].max()+margin, mask_indices[0].max()+margin) 22 | 23 | annotated_img = cv2.rectangle(original_image.copy(), start_corner, end_corner, (0, 0, 255), thickness) 24 | return annotated_img 25 | 26 | def vp_red_ellipse(original_image: np.ndarray, mask: np.ndarray, thickness: int = 4, margin: int = 4) -> np.ndarray: 27 | """Visual prompting using a red ellipse 28 | 29 | Args: 30 | original_image (np.ndarray): image 31 | mask (np.ndarray): dense segmentation mask 32 | thickness (int, optional): thickness of the ellipse. Defaults to 4. 33 | margin (int, optional): margin of the ellipse w.r.t. the bounding box of the mask. Defaults to 4. 34 | 35 | Returns: 36 | np.ndarray: annotated image 37 | """ 38 | mask_indices = np.where(mask == 1) 39 | start_corner = (mask_indices[1].min()-margin, mask_indices[0].min()-margin) 40 | end_corner = (mask_indices[1].max()+margin, mask_indices[0].max()+margin) 41 | 42 | annotated_img = cv2.ellipse( 43 | original_image.copy(), 44 | np.array(((start_corner[0] + end_corner[0]) / 2, (start_corner[1] + end_corner[1]) / 2), dtype=np.int64), 45 | np.array(((end_corner[0] - start_corner[0]) / 2 + margin, (end_corner[1] - start_corner[1]) / 2 + margin), dtype=np.int64), 46 | 0, 47 | 0, 48 | 360, 49 | (0, 0, 255), 50 | thickness 51 | ) 52 | return annotated_img 53 | 54 | def vp_dense_mask(original_image: np.ndarray, mask: np.ndarray, alpha: float = 0.25, color: Tuple[float, float, float] = (255, 0, 0)) -> np.ndarray: 55 | """Visual prompting using a dense mask 56 | 57 | Args: 58 | original_image (np.ndarray): image 59 | mask (np.ndarray): dense segmentation mask 60 | alpha (float, optional): transparency of the mask. Defaults to 0.25. 61 | color (Tuple[float, float, float], optional): color of the mask. Defaults to (255, 0, 0). 
62 | 63 | Returns: 64 | np.ndarray: annotated image 65 | """ 66 | color = color[::-1] 67 | colored_mask = np.expand_dims(mask, 0).repeat(3, axis=0) 68 | colored_mask = np.moveaxis(colored_mask, 0, -1) 69 | 70 | masked = np.ma.MaskedArray(original_image, mask=colored_mask, fill_value=color) 71 | image_overlay = masked.filled() 72 | 73 | annotated_img = cv2.addWeighted(original_image, 1 - alpha, image_overlay, alpha, 0) 74 | 75 | return annotated_img 76 | 77 | def vp_reverse_blur(original_image: np.ndarray, mask: np.ndarray, std_dev: int = 100) -> np.ndarray: 78 | """Visual prompting using the reverse blur mechanism from the Fine-Tuning paper 79 | 80 | Args: 81 | original_image (np.ndarray): image 82 | mask (np.ndarray): dense segmentation mask 83 | std_dev (int, optional): standard deviation of the Gaussian kernel used to blur the background. Defaults to 100. 84 | 85 | Returns: 86 | np.ndarray: annotated image 87 | """ 88 | blur_background = cv2.GaussianBlur(original_image.copy(), [0, 0], sigmaX=std_dev, sigmaY=std_dev) 89 | 90 | masked_blur = cv2.bitwise_and(blur_background, blur_background, mask=(255*(1-mask)).astype(np.uint8)) 91 | masked_object = cv2.bitwise_and(original_image, original_image, mask=(255*mask).astype(np.uint8)) 92 | annotated_img = masked_blur + masked_object 93 | 94 | return annotated_img -------------------------------------------------------------------------------- /ssc_ris/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from ._utils import * -------------------------------------------------------------------------------- /ssc_ris/utils/_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Tuple 3 | 4 | import torch 5 | import numpy as np 6 | import wandb 7 | 8 | from ssc_ris.refer_dataset.bert.modeling_bert import BertModel 9 | 10 | 11 | def IoU(pred: torch.Tensor, gt: torch.Tensor) -> Tuple[float, torch.Tensor, torch.Tensor]: 12 | """Compute IoU. 13 | 14 | Args: 15 | pred (torch.Tensor): a 2-class segmentation model output (the argmax is taken over dim 1) 16 | gt (torch.Tensor): the ground-truth mask w.r.t. which the IoU is computed 17 | 18 | Returns: 19 | Tuple[float, torch.Tensor, torch.Tensor]: IoU, intersection and union 20 | """ 21 | pred = pred.argmax(1) 22 | 23 | intersection = torch.sum(torch.mul(pred, gt)) 24 | union = torch.sum(torch.add(pred, gt)) - intersection 25 | 26 | if intersection == 0 or union == 0: 27 | iou = 0 28 | else: 29 | iou = float(intersection) / float(union) 30 | 31 | return iou, intersection, union 32 | 33 | 34 | def get_sentence_from_tokens(token_sequence: torch.Tensor, bert_model: BertModel) -> str: 35 | """Given a token sequence encoded using bert_model, decode it as a string. 36 | 37 | Args: 38 | token_sequence (torch.Tensor): input token sequence 39 | bert_model (BertModel): Bert model used to encode it; must include a tokenizer.
40 | 41 | Returns: 42 | str: decoded sequence, stripped of special tokens 43 | """ 44 | actual_sentence = bert_model.tokenizer.decode(token_sequence) 45 | actual_sentence = actual_sentence.replace('[CLS]', '').replace('[PAD]', '').replace('[SEP]', '')[1:] 46 | actual_sentence = re.sub(" +", " ", actual_sentence) 47 | return actual_sentence 48 | 49 | 50 | def wandb_mask( 51 | image: np.ndarray, 52 | sentence: torch.Tensor, 53 | pred_mask: np.ndarray, 54 | true_mask: np.ndarray, 55 | matched_mask: np.ndarray, 56 | bert_model: BertModel 57 | ) -> wandb.Image: 58 | """Obtain a wandb image which includes the predicted mask, true mask as well as mask options. 59 | 60 | Args: 61 | image (np.ndarray): original image 62 | sentence (torch.Tensor): token sequence describing the original sentence 63 | pred_mask (np.ndarray): predicted mask 64 | true_mask (np.ndarray): ground-truth/pseudo-ground-truth mask 65 | matched_mask (np.ndarray): matched mask 66 | bert_model (BertModel): Bert model used to encode sentence 67 | 68 | Returns: 69 | wandb.Image: output wandb image for logging 70 | """ 71 | actual_sentence = get_sentence_from_tokens(sentence, bert_model) 72 | 73 | if matched_mask is None: 74 | return wandb.Image( 75 | image, 76 | masks={ 77 | "prediction" : {"mask_data" : pred_mask, "class_labels" : {1: "object"}}, 78 | "pseudo GT" : {"mask_data" : true_mask + 1, "class_labels" : {obj_int + 1: f"instance {obj_int}" for obj_int in range(1, 10)}} 79 | }, 80 | caption=actual_sentence 81 | ) 82 | else: 83 | # masks must be int64, can't be int32 84 | matched_mask = matched_mask.astype(np.int64) 85 | 86 | return wandb.Image( 87 | image, 88 | masks={ 89 | "prediction" : {"mask_data" : pred_mask, "class_labels" : {1: "pred object"}}, 90 | "matched mask": {"mask_data": matched_mask + 1, "class_labels": {2: "ice_object"}}, 91 | "mask options": {"mask_data" : true_mask + 2, "class_labels" : {obj_int + 2: f"instance {obj_int}" for obj_int in range(1, 10)}} 92 | }, 93 | caption=actual_sentence 94 | ) -------------------------------------------------------------------------------- /ssc_ris/utils/lavt_lib/_utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import functional as F 3 | 4 | from ssc_ris.refer_dataset.bert.modeling_bert import BertModel 5 | 6 | 7 | class _LAVTSimpleDecode(nn.Module): 8 | def __init__(self, backbone, classifier): 9 | super(_LAVTSimpleDecode, self).__init__() 10 | self.backbone = backbone 11 | self.classifier = classifier 12 | 13 | def forward(self, x, l_feats, l_mask): 14 | input_shape = x.shape[-2:] 15 | features = self.backbone(x, l_feats, l_mask) 16 | x_c1, x_c2, x_c3, x_c4 = features 17 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 18 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 19 | 20 | return x 21 | 22 | 23 | class LAVT(_LAVTSimpleDecode): 24 | pass 25 | 26 | 27 | ############################################### 28 | # LAVT One: put BERT inside the overall model # 29 | ############################################### 30 | class _LAVTOneSimpleDecode(nn.Module): 31 | def __init__(self, backbone, classifier, args): 32 | super(_LAVTOneSimpleDecode, self).__init__() 33 | self.backbone = backbone 34 | self.classifier = classifier 35 | self.text_encoder = BertModel.from_pretrained(args.ck_bert) 36 | self.text_encoder.pooler = None 37 | 38 | def forward(self, x, text, l_mask): 39 | input_shape = x.shape[-2:] 40 | ### language inference ### 41 | l_feats = 
self.text_encoder(text, attention_mask=l_mask)[0] # (B, N_l, 768) 42 | l_feats = l_feats.permute(0, 2, 1) # (B, 768, N_l) to make Conv1d happy 43 | l_mask = l_mask.unsqueeze(dim=-1) # (batch, N_l, 1) 44 | ########################## 45 | features = self.backbone(x, l_feats, l_mask) 46 | x_c1, x_c2, x_c3, x_c4 = features 47 | x = self.classifier(x_c4, x_c3, x_c2, x_c1) 48 | x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=True) 49 | 50 | return x 51 | 52 | 53 | class LAVTOne(_LAVTOneSimpleDecode): 54 | pass 55 | -------------------------------------------------------------------------------- /ssc_ris/utils/lavt_lib/lavt_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from collections import defaultdict, deque 3 | import datetime 4 | import math 5 | import time 6 | import torch 7 | import torch.distributed as dist 8 | import torch.backends.cudnn as cudnn 9 | 10 | import errno 11 | import os 12 | 13 | import sys 14 | 15 | 16 | class SmoothedValue(object): 17 | """Track a series of values and provide access to smoothed values over a 18 | window or the global series average. 19 | """ 20 | 21 | def __init__(self, window_size=20, fmt=None): 22 | if fmt is None: 23 | fmt = "{median:.4f} ({global_avg:.4f})" 24 | self.deque = deque(maxlen=window_size) 25 | self.total = 0.0 26 | self.count = 0 27 | self.fmt = fmt 28 | 29 | def update(self, value, n=1): 30 | self.deque.append(value) 31 | self.count += n 32 | self.total += value * n 33 | 34 | def synchronize_between_processes(self): 35 | """ 36 | Warning: does not synchronize the deque! 37 | """ 38 | if not is_dist_avail_and_initialized(): 39 | return 40 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 41 | dist.barrier() 42 | dist.all_reduce(t) 43 | t = t.tolist() 44 | self.count = int(t[0]) 45 | self.total = t[1] 46 | 47 | @property 48 | def median(self): 49 | d = torch.tensor(list(self.deque)) 50 | return d.median().item() 51 | 52 | @property 53 | def avg(self): 54 | d = torch.tensor(list(self.deque), dtype=torch.float32) 55 | return d.mean().item() 56 | 57 | @property 58 | def global_avg(self): 59 | return self.total / self.count 60 | 61 | @property 62 | def max(self): 63 | return max(self.deque) 64 | 65 | @property 66 | def value(self): 67 | return self.deque[-1] 68 | 69 | def __str__(self): 70 | return self.fmt.format( 71 | median=self.median, 72 | avg=self.avg, 73 | global_avg=self.global_avg, 74 | max=self.max, 75 | value=self.value) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, **kwargs): 84 | for k, v in kwargs.items(): 85 | if isinstance(v, torch.Tensor): 86 | v = v.item() 87 | try: 88 | assert isinstance(v, (float, int)) 89 | except AssertionError: 90 | print(k, v)  # report the offending metric name and value before dropping into the debugger 91 | import pdb 92 | pdb.set_trace() 93 | 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | loss_str.append( 108 | "{}: {}".format(name, str(meter)) 109 | ) 110 | return self.delimiter.join(loss_str) 111 | 112 | def synchronize_between_processes(self): 113 | for meter in
self.meters.values(): 114 | meter.synchronize_between_processes() 115 | 116 | def add_meter(self, name, meter): 117 | self.meters[name] = meter 118 | 119 | def log_every(self, iterable, print_freq, header=None, skip_first=False): 120 | i = 0 121 | if not header: 122 | header = '' 123 | start_time = time.time() 124 | end = time.time() 125 | iter_time = SmoothedValue(fmt='{avg:.4f}') 126 | data_time = SmoothedValue(fmt='{avg:.4f}') 127 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 128 | log_msg = self.delimiter.join([ 129 | header, 130 | '[{0' + space_fmt + '}/{1}]', 131 | 'eta: {eta}', 132 | '{meters}', 133 | 'time: {time}', 134 | 'data: {data}', 135 | 'max mem: {memory:.0f}' 136 | ]) 137 | MB = 1024.0 * 1024.0 138 | for obj in iterable: 139 | is_first_should_skip = (i == 0 if skip_first else False) 140 | if is_first_should_skip: 141 | i += 1 142 | continue 143 | 144 | data_time.update(time.time() - end) 145 | yield obj 146 | iter_time.update(time.time() - end) 147 | 148 | if i % print_freq == 0: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | print(log_msg.format( 152 | i, len(iterable), eta=eta_string, 153 | meters=str(self), 154 | time=str(iter_time), data=str(data_time), 155 | memory=torch.cuda.max_memory_allocated() / MB)) 156 | sys.stdout.flush() 157 | 158 | i += 1 159 | end = time.time() 160 | total_time = time.time() - start_time 161 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 162 | print('{} Total time: {}'.format(header, total_time_str)) 163 | 164 | 165 | def mkdir(path): 166 | try: 167 | os.makedirs(path) 168 | except OSError as e: 169 | if e.errno != errno.EEXIST: 170 | raise 171 | 172 | 173 | def setup_for_distributed(is_master): 174 | """ 175 | This function disables printing when not in master process 176 | """ 177 | import builtins as __builtin__ 178 | builtin_print = __builtin__.print 179 | 180 | def print(*args, **kwargs): 181 | force = kwargs.pop('force', False) 182 | if is_master or force: 183 | builtin_print(*args, **kwargs) 184 | 185 | __builtin__.print = print 186 | 187 | 188 | def is_dist_avail_and_initialized(): 189 | if not dist.is_available(): 190 | return False 191 | if not dist.is_initialized(): 192 | return False 193 | return True 194 | 195 | 196 | def get_world_size(): 197 | if not is_dist_avail_and_initialized(): 198 | return 1 199 | return dist.get_world_size() 200 | 201 | 202 | def get_rank(): 203 | if not is_dist_avail_and_initialized(): 204 | return 0 205 | return dist.get_rank() 206 | 207 | 208 | def is_main_process(): 209 | return get_rank() == 0 210 | 211 | 212 | def save_on_master(*args, **kwargs): 213 | if is_main_process(): 214 | torch.save(*args, **kwargs) 215 | 216 | 217 | def init_distributed_mode(args): 218 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 219 | rank = int(os.environ["RANK"]) 220 | world_size = int(os.environ['WORLD_SIZE']) 221 | print(f"RANK and WORLD_SIZE in environment: {rank}/{world_size}") 222 | else: 223 | rank = -1 224 | world_size = -1 225 | 226 | torch.cuda.set_device(args.local_rank) 227 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 228 | torch.distributed.barrier() 229 | setup_for_distributed(is_main_process()) 230 | 231 | if args.output_dir: 232 | mkdir(args.output_dir) 233 | if args.model_id: 234 | mkdir(os.path.join('./models/', args.model_id)) 235 | 
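The `SmoothedValue` and `MetricLogger` helpers above follow the usual LAVT-style training loop. As a reading aid, here is a minimal, self-contained usage sketch; the toy linear model, random data and the import path `ssc_ris.utils.lavt_lib.lavt_utils` are illustrative assumptions rather than the project's actual training code (the real training entry points are under `scripts/train/`). Note that `log_every` reports CUDA memory, so a CUDA-enabled PyTorch build is assumed.
```
import torch
from torch import nn
from ssc_ris.utils.lavt_lib.lavt_utils import MetricLogger, SmoothedValue

# toy stand-ins so the sketch is runnable on its own (not the real LAVT model/loader)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data_loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(20)]

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

for images, targets in metric_logger.log_every(data_loader, print_freq=10, header='Epoch: [0]'):
    loss = nn.functional.mse_loss(model(images), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # update() accepts floats/ints; tensors are converted via .item() internally
    metric_logger.update(loss=loss, lr=optimizer.param_groups[0]["lr"])

print("Averaged stats:", metric_logger)
```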
-------------------------------------------------------------------------------- /ssc_ris/utils/lavt_lib/mask_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | class SimpleDecoding(nn.Module): 8 | def __init__(self, c4_dims, factor=2): 9 | super(SimpleDecoding, self).__init__() 10 | 11 | hidden_size = c4_dims//factor 12 | c4_size = c4_dims 13 | c3_size = c4_dims//(factor**1) 14 | c2_size = c4_dims//(factor**2) 15 | c1_size = c4_dims//(factor**3) 16 | 17 | self.conv1_4 = nn.Conv2d(c4_size+c3_size, hidden_size, 3, padding=1, bias=False) 18 | self.bn1_4 = nn.BatchNorm2d(hidden_size) 19 | self.relu1_4 = nn.ReLU() 20 | self.conv2_4 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False) 21 | self.bn2_4 = nn.BatchNorm2d(hidden_size) 22 | self.relu2_4 = nn.ReLU() 23 | 24 | self.conv1_3 = nn.Conv2d(hidden_size + c2_size, hidden_size, 3, padding=1, bias=False) 25 | self.bn1_3 = nn.BatchNorm2d(hidden_size) 26 | self.relu1_3 = nn.ReLU() 27 | self.conv2_3 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False) 28 | self.bn2_3 = nn.BatchNorm2d(hidden_size) 29 | self.relu2_3 = nn.ReLU() 30 | 31 | self.conv1_2 = nn.Conv2d(hidden_size + c1_size, hidden_size, 3, padding=1, bias=False) 32 | self.bn1_2 = nn.BatchNorm2d(hidden_size) 33 | self.relu1_2 = nn.ReLU() 34 | self.conv2_2 = nn.Conv2d(hidden_size, hidden_size, 3, padding=1, bias=False) 35 | self.bn2_2 = nn.BatchNorm2d(hidden_size) 36 | self.relu2_2 = nn.ReLU() 37 | 38 | self.conv1_1 = nn.Conv2d(hidden_size, 2, 1) 39 | 40 | def forward(self, x_c4, x_c3, x_c2, x_c1): 41 | # fuse Y4 and Y3 42 | if x_c4.size(-2) < x_c3.size(-2) or x_c4.size(-1) < x_c3.size(-1): 43 | x_c4 = F.interpolate(input=x_c4, size=(x_c3.size(-2), x_c3.size(-1)), mode='bilinear', align_corners=True) 44 | x = torch.cat([x_c4, x_c3], dim=1) 45 | x = self.conv1_4(x) 46 | x = self.bn1_4(x) 47 | x = self.relu1_4(x) 48 | x = self.conv2_4(x) 49 | x = self.bn2_4(x) 50 | x = self.relu2_4(x) 51 | # fuse top-down features and Y2 features 52 | if x.size(-2) < x_c2.size(-2) or x.size(-1) < x_c2.size(-1): 53 | x = F.interpolate(input=x, size=(x_c2.size(-2), x_c2.size(-1)), mode='bilinear', align_corners=True) 54 | x = torch.cat([x, x_c2], dim=1) 55 | x = self.conv1_3(x) 56 | x = self.bn1_3(x) 57 | x = self.relu1_3(x) 58 | x = self.conv2_3(x) 59 | x = self.bn2_3(x) 60 | x = self.relu2_3(x) 61 | # fuse top-down features and Y1 features 62 | if x.size(-2) < x_c1.size(-2) or x.size(-1) < x_c1.size(-1): 63 | x = F.interpolate(input=x, size=(x_c1.size(-2), x_c1.size(-1)), mode='bilinear', align_corners=True) 64 | x = torch.cat([x, x_c1], dim=1) 65 | x = self.conv1_2(x) 66 | x = self.bn1_2(x) 67 | x = self.relu1_2(x) 68 | x = self.conv2_2(x) 69 | x = self.bn2_2(x) 70 | x = self.relu2_2(x) 71 | 72 | return self.conv1_1(x) 73 | -------------------------------------------------------------------------------- /ssc_ris/utils/lavt_lib/mmcv_custom/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import load_checkpoint 4 | 5 | __all__ = ['load_checkpoint'] 6 | -------------------------------------------------------------------------------- /ssc_ris/utils/lavt_lib/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from 
.mask_predictor import SimpleDecoding 4 | from .backbone import MultiModalSwinTransformer 5 | from ._utils import LAVT, LAVTOne 6 | 7 | __all__ = ['lavt', 'lavt_one'] 8 | 9 | 10 | # LAVT 11 | def _segm_lavt(pretrained, args): 12 | # initialize the SwinTransformer backbone with the specified version 13 | if args.swin_type == 'tiny': 14 | embed_dim = 96 15 | depths = [2, 2, 6, 2] 16 | num_heads = [3, 6, 12, 24] 17 | elif args.swin_type == 'small': 18 | embed_dim = 96 19 | depths = [2, 2, 18, 2] 20 | num_heads = [3, 6, 12, 24] 21 | elif args.swin_type == 'base': 22 | embed_dim = 128 23 | depths = [2, 2, 18, 2] 24 | num_heads = [4, 8, 16, 32] 25 | elif args.swin_type == 'large': 26 | embed_dim = 192 27 | depths = [2, 2, 18, 2] 28 | num_heads = [6, 12, 24, 48] 29 | else: 30 | assert False 31 | # args.window12 added for test.py because state_dict is loaded after model initialization 32 | if 'window12' in pretrained or args.window12: 33 | print('Window size 12!') 34 | window_size = 12 35 | else: 36 | window_size = 7 37 | 38 | if args.mha: 39 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 40 | mha = [int(a) for a in mha] 41 | else: 42 | mha = [1, 1, 1, 1] 43 | 44 | out_indices = (0, 1, 2, 3) 45 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 46 | window_size=window_size, 47 | ape=False, drop_path_rate=0.3, patch_norm=True, 48 | out_indices=out_indices, 49 | use_checkpoint=False, num_heads_fusion=mha, 50 | fusion_drop=args.fusion_drop 51 | ) 52 | if pretrained: 53 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 54 | backbone.init_weights(pretrained=pretrained) 55 | else: 56 | print('Randomly initialize Multi-modal Swin Transformer weights.') 57 | backbone.init_weights() 58 | 59 | model_map = [SimpleDecoding, LAVT] 60 | 61 | classifier = model_map[0](8*embed_dim) 62 | base_model = model_map[1] 63 | 64 | model = base_model(backbone, classifier) 65 | return model 66 | 67 | 68 | def _load_model_lavt(pretrained, args): 69 | model = _segm_lavt(pretrained, args) 70 | return model 71 | 72 | 73 | def lavt(pretrained='', args=None): 74 | return _load_model_lavt(pretrained, args) 75 | 76 | 77 | ############################################### 78 | # LAVT One: put BERT inside the overall model # 79 | ############################################### 80 | def _segm_lavt_one(pretrained, args): 81 | # initialize the SwinTransformer backbone with the specified version 82 | if args.swin_type == 'tiny': 83 | embed_dim = 96 84 | depths = [2, 2, 6, 2] 85 | num_heads = [3, 6, 12, 24] 86 | elif args.swin_type == 'small': 87 | embed_dim = 96 88 | depths = [2, 2, 18, 2] 89 | num_heads = [3, 6, 12, 24] 90 | elif args.swin_type == 'base': 91 | embed_dim = 128 92 | depths = [2, 2, 18, 2] 93 | num_heads = [4, 8, 16, 32] 94 | elif args.swin_type == 'large': 95 | embed_dim = 192 96 | depths = [2, 2, 18, 2] 97 | num_heads = [6, 12, 24, 48] 98 | else: 99 | assert False 100 | # args.window12 added for test.py because state_dict is loaded after model initialization 101 | if 'window12' in pretrained or args.window12: 102 | print('Window size 12!') 103 | window_size = 12 104 | else: 105 | window_size = 7 106 | 107 | if args.mha: 108 | mha = args.mha.split('-') # if non-empty, then ['a', 'b', 'c', 'd'] 109 | mha = [int(a) for a in mha] 110 | else: 111 | mha = [1, 1, 1, 1] 112 | 113 | out_indices = (0, 1, 2, 3) 114 | backbone = MultiModalSwinTransformer(embed_dim=embed_dim, depths=depths, num_heads=num_heads, 115 | 
window_size=window_size, 116 | ape=False, drop_path_rate=0.3, patch_norm=True, 117 | out_indices=out_indices, 118 | use_checkpoint=False, num_heads_fusion=mha, 119 | fusion_drop=args.fusion_drop 120 | ) 121 | if pretrained: 122 | print('Initializing Multi-modal Swin Transformer weights from ' + pretrained) 123 | backbone.init_weights(pretrained=pretrained) 124 | else: 125 | print('Randomly initialize Multi-modal Swin Transformer weights.') 126 | backbone.init_weights() 127 | 128 | model_map = [SimpleDecoding, LAVTOne] 129 | 130 | classifier = model_map[0](8*embed_dim) 131 | base_model = model_map[1] 132 | 133 | model = base_model(backbone, classifier, args) 134 | return model 135 | 136 | 137 | def _load_model_lavt_one(pretrained, args): 138 | model = _segm_lavt_one(pretrained, args) 139 | return model 140 | 141 | 142 | def lavt_one(pretrained='', args=None): 143 | return _load_model_lavt_one(pretrained, args) 144 | -------------------------------------------------------------------------------- /ssc_ris/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | 5 | import torch 6 | from torchvision import transforms as T 7 | from torchvision.transforms import functional as F 8 | 9 | 10 | def pad_if_smaller(img, size, fill=0): 11 | min_size = min(img.size) 12 | if min_size < size: 13 | ow, oh = img.size 14 | padh = size - oh if oh < size else 0 15 | padw = size - ow if ow < size else 0 16 | img = F.pad(img, (0, 0, padw, padh), fill=fill) 17 | return img 18 | 19 | 20 | class Compose(object): 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, image, target): 25 | for t in self.transforms: 26 | image, target = t(image, target) 27 | return image, target 28 | 29 | 30 | class Resize(object): 31 | def __init__(self, h, w): 32 | self.h = h 33 | self.w = w 34 | 35 | def __call__(self, image, target): 36 | image = F.resize(image, (self.h, self.w)) 37 | # If size is a sequence like (h, w), the output size will be matched to this. 38 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 39 | target = F.resize(target, (self.h, self.w), interpolation=Image.NEAREST) 40 | return image, target 41 | 42 | 43 | class RandomResize(object): 44 | def __init__(self, min_size, max_size=None): 45 | self.min_size = min_size 46 | if max_size is None: 47 | max_size = min_size 48 | self.max_size = max_size 49 | 50 | def __call__(self, image, target): 51 | size = random.randint(self.min_size, self.max_size) # Return a random integer N such that a <= N <= b. Alias for randrange(a, b+1) 52 | image = F.resize(image, size) 53 | # If size is a sequence like (h, w), the output size will be matched to this. 
54 | # If size is an int, the smaller edge of the image will be matched to this number maintaining the aspect ratio 55 | target = F.resize(target, size, interpolation=Image.NEAREST) 56 | return image, target 57 | 58 | 59 | class RandomHorizontalFlip(object): 60 | def __init__(self, flip_prob): 61 | self.flip_prob = flip_prob 62 | 63 | def __call__(self, image, target): 64 | if random.random() < self.flip_prob: 65 | image = F.hflip(image) 66 | target = F.hflip(target) 67 | return image, target 68 | 69 | 70 | class RandomCrop(object): 71 | def __init__(self, size): 72 | self.size = size 73 | 74 | def __call__(self, image, target): 75 | image = pad_if_smaller(image, self.size) 76 | target = pad_if_smaller(target, self.size, fill=255) 77 | crop_params = T.RandomCrop.get_params(image, (self.size, self.size)) 78 | image = F.crop(image, *crop_params) 79 | target = F.crop(target, *crop_params) 80 | return image, target 81 | 82 | 83 | class CenterCrop(object): 84 | def __init__(self, size): 85 | self.size = size 86 | 87 | def __call__(self, image, target): 88 | image = F.center_crop(image, self.size) 89 | target = F.center_crop(target, self.size) 90 | return image, target 91 | 92 | 93 | class ToTensor(object): 94 | def __call__(self, image, target): 95 | image = F.to_tensor(image) 96 | target = torch.as_tensor(np.asarray(target).copy(), dtype=torch.int64) 97 | return image, target 98 | 99 | 100 | class RandomAffine(object): 101 | def __init__(self, angle, translate, scale, shear, resample=0, fillcolor=None): 102 | self.angle = angle 103 | self.translate = translate 104 | self.scale = scale 105 | self.shear = shear 106 | self.resample = resample 107 | self.fillcolor = fillcolor 108 | 109 | def __call__(self, image, target): 110 | affine_params = T.RandomAffine.get_params(self.angle, self.translate, self.scale, self.shear, image.size) 111 | image = F.affine(image, *affine_params) 112 | target = F.affine(target, *affine_params) 113 | return image, target 114 | 115 | 116 | class Normalize(object): 117 | def __init__(self, mean, std): 118 | self.mean = mean 119 | self.std = std 120 | 121 | def __call__(self, image, target): 122 | image = F.normalize(image, mean=self.mean, std=self.std) 123 | return image, target 124 | 125 | --------------------------------------------------------------------------------
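A closing note on the paired transforms above: unlike stock `torchvision` transforms, each `__call__` takes and returns an `(image, target)` pair, so random operations are applied identically to the image and its segmentation mask, and only the image is normalized. The sketch below is an illustrative composition only; the 480x480 size and ImageNet normalization statistics are assumptions rather than values prescribed by this file, and it assumes the module is importable as `ssc_ris.utils.transforms`.
```
from PIL import Image
import ssc_ris.utils.transforms as T_pair

# placeholder inputs: any RGB image and a same-sized single-channel mask
image = Image.open("examples/image_1.jpg").convert("RGB")
target = Image.new("L", image.size)  # all-zero dummy mask, for illustration only

paired_transforms = T_pair.Compose([
    T_pair.Resize(480, 480),                      # resizes image and mask together (mask uses nearest-neighbour)
    T_pair.RandomHorizontalFlip(flip_prob=0.5),   # flips both or neither
    T_pair.ToTensor(),                            # image -> float tensor, mask -> int64 tensor
    T_pair.Normalize(mean=[0.485, 0.456, 0.406],  # illustrative ImageNet statistics
                     std=[0.229, 0.224, 0.225]),  # (only the image is normalized)
])

image_tensor, target_tensor = paired_transforms(image, target)
print(image_tensor.shape, target_tensor.shape)  # e.g. torch.Size([3, 480, 480]) torch.Size([480, 480])
```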