├── OD ├── CLIP │ ├── clip │ │ ├── __init__.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── simple_tokenizer.py │ │ ├── mask.py │ │ ├── clip.py │ │ └── model.py │ ├── run_install.sh │ ├── MANIFEST.in │ ├── requirements.txt │ ├── CLIP.png │ ├── dist │ │ └── clip-1.0-py3.8.egg │ ├── .gitignore │ ├── setup.py │ ├── data │ │ ├── rendered-sst2.md │ │ ├── yfcc100m.md │ │ └── country211.md │ ├── tests │ │ └── test_consistency.py │ ├── .github │ │ └── workflows │ │ │ └── test.yml │ ├── LICENSE │ ├── hubconf.py │ ├── model-card.md │ └── README.md ├── .DS_Store ├── modeling │ ├── .DS_Store │ ├── __init__.py │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── config-checkpoint.py │ │ ├── box_predictor-checkpoint.py │ │ ├── custom_pascal_evaluation-checkpoint.py │ │ ├── regularization-checkpoint.py │ │ ├── backbone-checkpoint.py │ │ ├── rpn-checkpoint.py │ │ ├── clip-checkpoint.py │ │ ├── roi_head-checkpoint.py │ │ └── meta_arch-checkpoint.py │ ├── config.py │ ├── box_predictor.py │ ├── custom_pascal_evaluation.py │ ├── regularization.py │ ├── backbone.py │ ├── rpn.py │ ├── clip.py │ ├── roi_head.py │ └── meta_arch.py ├── datasets │ └── diverseWeather │ │ └── .DS_Store ├── data │ └── datasets │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── builtin.cpython-36.pyc │ │ ├── builtin.cpython-37.pyc │ │ ├── builtin.cpython-38.pyc │ │ ├── diverse_weather.cpython-36.pyc │ │ ├── diverse_weather.cpython-37.pyc │ │ ├── diverse_weather.cpython-38.pyc │ │ ├── pascal_voc_adaptation.cpython-36.pyc │ │ ├── pascal_voc_adaptation.cpython-37.pyc │ │ ├── pascal_voc_adaptation.cpython-38.pyc │ │ ├── comic_water_adaptation.cpython-36.pyc │ │ ├── comic_water_adaptation.cpython-37.pyc │ │ └── comic_water_adaptation.cpython-38.pyc │ │ ├── __init__.py │ │ ├── builtin.py │ │ ├── pascal_voc_adaptation.py │ │ └── diverse_weather.py ├── prunedprompts.txt ├── configs │ ├── diverse_weather_dusk_rainy_test.yaml │ ├── diverse_weather_foggy_test.yaml │ ├── diverse_weather_night_rainy_test.yaml │ ├── diverse_weather_night_sunny_test.yaml │ └── diverse_weather.yaml └── train.py ├── Framework.png ├── Introduction.jpg ├── Introduction.png └── README.md /OD/CLIP/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /OD/CLIP/run_install.sh: -------------------------------------------------------------------------------- 1 | python setup.py develop -------------------------------------------------------------------------------- /OD/CLIP/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include clip/bpe_simple_vocab_16e6.txt.gz 2 | -------------------------------------------------------------------------------- /Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/Framework.png -------------------------------------------------------------------------------- /OD/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/.DS_Store -------------------------------------------------------------------------------- /OD/CLIP/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch 5 | torchvision 6 | 
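A minimal usage sketch (not part of the repository) for the bundled CLIP package once it has been installed via `run_install.sh` (`python setup.py develop`). It mirrors the calls exercised in `OD/CLIP/tests/test_consistency.py` and `OD/modeling/clip.py`; the `'RN101'` encoder name matches `MODEL.CLIP_IMAGE_ENCODER_NAME` in the OD configs, and the example scene prompts are taken from `OD/prunedprompts.txt`.

```python
import torch
from PIL import Image
import clip  # installed from OD/CLIP via run_install.sh

device = "cuda" if torch.cuda.is_available() else "cpu"
# Same encoder name as MODEL.CLIP_IMAGE_ENCODER_NAME in the OD configs.
model, preprocess = clip.load("RN101", device=device)

image = preprocess(Image.open("OD/CLIP/CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["an image taken on a fog day", "an image taken on a snow night"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print(probs)  # similarity of the image to each scene prompt
```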
-------------------------------------------------------------------------------- /Introduction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/Introduction.jpg -------------------------------------------------------------------------------- /Introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/Introduction.png -------------------------------------------------------------------------------- /OD/CLIP/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/CLIP/CLIP.png -------------------------------------------------------------------------------- /OD/modeling/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/modeling/.DS_Store -------------------------------------------------------------------------------- /OD/CLIP/dist/clip-1.0-py3.8.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/CLIP/dist/clip-1.0-py3.8.egg -------------------------------------------------------------------------------- /OD/datasets/diverseWeather/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/datasets/diverseWeather/.DS_Store -------------------------------------------------------------------------------- /OD/CLIP/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/CLIP/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/builtin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/builtin.cpython-36.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/builtin.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/builtin.cpython-37.pyc 
-------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/builtin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/builtin.cpython-38.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/diverse_weather.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/diverse_weather.cpython-36.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/diverse_weather.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/diverse_weather.cpython-37.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/diverse_weather.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/diverse_weather.cpython-38.pyc -------------------------------------------------------------------------------- /OD/CLIP/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-36.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-37.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-38.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/comic_water_adaptation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-36.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__pycache__/comic_water_adaptation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-37.pyc -------------------------------------------------------------------------------- 
/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-38.pyc -------------------------------------------------------------------------------- /OD/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from . import builtin # ensure the builtin datasets are registered 4 | 5 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /OD/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .rpn import SBRPN 2 | from .backbone import ClipRN101 3 | from .meta_arch import ClipRCNNWithClipBackbone 4 | from .roi_head import ClipRes5ROIHeads 5 | from .config import add_stn_config 6 | from .custom_pascal_evaluation import CustomPascalVOCDetectionEvaluator -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .rpn import SBRPN 2 | from .backbone import ClipRN101 3 | from .meta_arch import ClipRCNNWithClipBackbone 4 | from .roi_head import ClipRes5ROIHeads 5 | from .config import add_stn_config 6 | from .custom_pascal_evaluation import CustomPascalVOCDetectionEvaluator -------------------------------------------------------------------------------- /OD/prunedprompts.txt: -------------------------------------------------------------------------------- 1 | an image taken on a snow night 2 | an image taken on a fog night 3 | an image taken on a cloudy night 4 | an image taken on a rain night 5 | an image taken on a stormy night 6 | an image taken on a snow day 7 | an image taken on a fog day 8 | an image taken on a cloudy day 9 | an image taken on a rain day 10 | an image taken on a stormy day 11 | an image taken on a snow evening 12 | an image taken on a fog evening 13 | an image taken on a cloudy evening 14 | an image taken on a rain evening 15 | an image taken on a stormy evening 16 | -------------------------------------------------------------------------------- /OD/data/datasets/builtin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | 3 | from .diverse_weather import register_dataset as register_diverse_weather 4 | from .pascal_voc_adaptation import register_all_pascal_voc as register_pascal_voc 5 | from .comic_water_adaptation import register_dataset as register_comic_water 6 | import os 7 | 8 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) 9 | DEFAULT_DATASETS_ROOT = "data/" 10 | 11 | 12 | register_diverse_weather(_root) 13 | register_pascal_voc(_root) 14 | register_comic_water(_root) 15 | -------------------------------------------------------------------------------- /OD/CLIP/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={'dev': ['pytest']}, 21 | ) 22 | -------------------------------------------------------------------------------- /OD/CLIP/data/rendered-sst2.md: -------------------------------------------------------------------------------- 1 | # The Rendered SST2 Dataset 2 | 3 | In the paper, we used an image classification dataset called Rendered SST2, to evaluate the model's capability on optical character recognition. To do so, we rendered the sentences in the [Stanford Sentiment Treebank v2](https://nlp.stanford.edu/sentiment/treebank.html) dataset and used those as the input to the CLIP image encoder. 4 | 5 | The following command will download a 131MB archive containing the images and extract into a subdirectory `rendered-sst2`: 6 | 7 | ```bash 8 | wget https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz 9 | tar zxvf rendered-sst2.tgz 10 | ``` 11 | 12 | -------------------------------------------------------------------------------- /OD/modeling/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | def add_stn_config(cfg): 4 | cfg.OFFSET_DOMAIN = '' 5 | cfg.OFFSET_FROZENBN = False 6 | cfg.OFFSET_DOMAIN_TEXT = '' 7 | cfg.OFFSET_NAME = 0 8 | cfg.OFFSET_OPT_INTERVAL = [10] 9 | cfg.OFFSET_OPT_ITERS = 0 10 | cfg.AUG_PROB = 0.5 11 | cfg.DOMAIN_NAME = "" 12 | cfg.TEST.EVAL_SAVE_PERIOD = 5000 13 | cfg.INPUT.CLIP_WITH_IMG = False 14 | cfg.INPUT.CLIP_RANDOM_CROPS = False 15 | cfg.INPUT.IMAGE_JITTER = False 16 | cfg.INPUT.RANDOM_CROP_SIZE = 224 17 | cfg.MODEL.GLOBAL_GND = False 18 | cfg.BASE_YAML = "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml" 19 | cfg.MODEL.RENAME = list() 20 | cfg.MODEL.CLIP_IMAGE_ENCODER_NAME = 'ViT-B/32' 21 | cfg.MODEL.BACKBONE.UNFREEZE = ['layer3','layer4','attnpool'] 22 | cfg.MODEL.USE_PROJ = True 23 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/config-checkpoint.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | def add_stn_config(cfg): 4 | cfg.OFFSET_DOMAIN = '' 5 | cfg.OFFSET_FROZENBN = False 6 | cfg.OFFSET_DOMAIN_TEXT = '' 7 | cfg.OFFSET_NAME = 0 8 | cfg.OFFSET_OPT_INTERVAL = [10] 9 | cfg.OFFSET_OPT_ITERS = 0 10 | cfg.AUG_PROB = 0.5 11 | cfg.DOMAIN_NAME = "" 12 | 
cfg.TEST.EVAL_SAVE_PERIOD = 5000 13 | cfg.INPUT.CLIP_WITH_IMG = False 14 | cfg.INPUT.CLIP_RANDOM_CROPS = False 15 | cfg.INPUT.IMAGE_JITTER = False 16 | cfg.INPUT.RANDOM_CROP_SIZE = 224 17 | cfg.MODEL.GLOBAL_GND = False 18 | cfg.BASE_YAML = "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml" 19 | cfg.MODEL.RENAME = list() 20 | cfg.MODEL.CLIP_IMAGE_ENCODER_NAME = 'ViT-B/32' 21 | cfg.MODEL.BACKBONE.UNFREEZE = ['layer3','layer4','attnpool'] 22 | cfg.MODEL.USE_PROJ = True 23 | -------------------------------------------------------------------------------- /OD/CLIP/tests/test_consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | import clip 7 | 8 | 9 | @pytest.mark.parametrize('model_name', clip.available_models()) 10 | def test_consistency(model_name): 11 | device = "cpu" 12 | jit_model, transform = clip.load(model_name, device=device, jit=True) 13 | py_model, _ = clip.load(model_name, device=device, jit=False) 14 | 15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device) 16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 17 | 18 | with torch.no_grad(): 19 | logits_per_image, _ = jit_model(image, text) 20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 21 | 22 | logits_per_image, _ = py_model(image, text) 23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 24 | 25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1) 26 | -------------------------------------------------------------------------------- /OD/modeling/box_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | import torch 3 | 4 | from detectron2.layers import cat, cross_entropy 5 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers 6 | from .clip import ClipPredictor 7 | 8 | class ClipFastRCNNOutputLayers(FastRCNNOutputLayers): 9 | 10 | def __init__(self,cfg, input_shape, clsnames) -> None: 11 | super().__init__(cfg, input_shape) 12 | self.cls_score = ClipPredictor(cfg.MODEL.CLIP_IMAGE_ENCODER_NAME, input_shape.channels, cfg.MODEL.DEVICE,clsnames) 13 | def forward(self,x,gfeat=None): 14 | 15 | if isinstance(x,list): 16 | scores = self.cls_score(x[0],gfeat) 17 | proposal_deltas = self.bbox_pred(x[1]) 18 | else: 19 | scores = self.cls_score(x,gfeat) 20 | proposal_deltas = self.bbox_pred(x) 21 | 22 | return scores, proposal_deltas 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/box_predictor-checkpoint.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | import torch 3 | 4 | from detectron2.layers import cat, cross_entropy 5 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers 6 | from .clip import ClipPredictor 7 | 8 | class ClipFastRCNNOutputLayers(FastRCNNOutputLayers): 9 | 10 | def __init__(self,cfg, input_shape, clsnames) -> None: 11 | super().__init__(cfg, input_shape) 12 | self.cls_score = ClipPredictor(cfg.MODEL.CLIP_IMAGE_ENCODER_NAME, input_shape.channels, cfg.MODEL.DEVICE,clsnames) 13 | def forward(self,x,gfeat=None): 14 | 15 | if isinstance(x,list): 16 | scores = self.cls_score(x[0],gfeat) 17 | proposal_deltas = self.bbox_pred(x[1]) 18 | else: 19 | scores = self.cls_score(x,gfeat) 20 | proposal_deltas = self.bbox_pred(x) 21 | 
22 | return scores, proposal_deltas 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /OD/CLIP/data/yfcc100m.md: -------------------------------------------------------------------------------- 1 | # The YFCC100M Subset 2 | 3 | In the paper, we performed a dataset ablation using a subset of the YFCC100M dataset and showed that the performance remained largely similar. 4 | 5 | The subset contains 14,829,396 images, about 15% of the full dataset, which have been filtered to only keep those with natural language titles and/or descriptions in English. 6 | 7 | We provide the list of (line number, photo identifier, photo hash) of each image contained in this subset. These correspond to the first three columns in the dataset's metadata TSV file. 8 | 9 | ```bash 10 | wget https://openaipublic.azureedge.net/clip/data/yfcc100m_subset_data.tsv.bz2 11 | bunzip2 yfcc100m_subset_data.tsv.bz2 12 | ``` 13 | 14 | Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/). -------------------------------------------------------------------------------- /OD/CLIP/data/country211.md: -------------------------------------------------------------------------------- 1 | # The Country211 Dataset 2 | 3 | In the paper, we used an image classification dataset called Country211, to evaluate the model's capability on geolocation. To do so, we filtered the YFCC100m dataset for images that have GPS coordinates corresponding to an [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes) and created a balanced dataset by sampling 150 train images, 50 validation images, and 100 test images for each country. 4 | 5 | The following command will download an 11GB archive containing the images and extract into a subdirectory `country211`: 6 | 7 | ```bash 8 | wget https://openaipublic.azureedge.net/clip/data/country211.tgz 9 | tar zxvf country211.tgz 10 | ``` 11 | 12 | These images are a subset of the YFCC100m dataset. Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/). 
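A minimal sketch (not part of the repository) for reading the YFCC100M subset list described in `yfcc100m.md` above, assuming the decompressed `yfcc100m_subset_data.tsv` is tab-separated with no header row; only the three documented columns (line number, photo identifier, photo hash) are used.

```python
import csv

subset = []
with open("yfcc100m_subset_data.tsv", newline="") as f:
    reader = csv.reader(f, delimiter="\t")
    for row in reader:
        # First three columns per the note above: line number, photo identifier, photo hash.
        line_number, photo_id, photo_hash = row[0], row[1], row[2]
        subset.append((int(line_number), photo_id, photo_hash))

print(f"{len(subset)} images listed")  # expected: 14,829,396 according to yfcc100m.md
```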
-------------------------------------------------------------------------------- /OD/configs/diverse_weather_dusk_rainy_test.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ("daytime_clear_train",) 4 | TEST: ('dusk_rainy_train',) 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | INPUT: 8 | MIN_SIZE_TRAIN: (600,) 9 | MIN_SIZE_TEST: 600 10 | CLIP_RANDOM_CROPS: True 11 | RANDOM_CROP_SIZE: 400 12 | 13 | SOLVER: 14 | BASE_LR: 0.001 15 | MAX_ITER: 100000 16 | STEPS: [40000,] 17 | WARMUP_ITERS: 0 18 | IMS_PER_BATCH: 4 19 | CHECKPOINT_PERIOD: 1000000 20 | MODEL: 21 | BACKBONE: 22 | NAME: ClipRN101 23 | WEIGHTS: "" 24 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable' 26 | 27 | PROPOSAL_GENERATOR: 28 | NAME: 'SBRPN' 29 | ROI_HEADS: 30 | NAME: 'ClipRes5ROIHeadsAttn' 31 | NUM_CLASSES: 7 32 | TEST: 33 | EVAL_SAVE_PERIOD: 5000 34 | OUTPUT_DIR: "all_outs/diverse_weather_dusk_rainy" 35 | VIS_PERIOD: 5000 36 | -------------------------------------------------------------------------------- /OD/configs/diverse_weather_foggy_test.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ("daytime_clear_train",) 4 | TEST: ('daytime_foggy_train',) 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | INPUT: 8 | MIN_SIZE_TRAIN: (600,) 9 | MIN_SIZE_TEST: 600 10 | CLIP_RANDOM_CROPS: True 11 | RANDOM_CROP_SIZE: 400 12 | 13 | SOLVER: 14 | BASE_LR: 0.001 15 | MAX_ITER: 100000 16 | STEPS: [40000,] 17 | WARMUP_ITERS: 0 18 | IMS_PER_BATCH: 4 19 | CHECKPOINT_PERIOD: 1000000 20 | MODEL: 21 | BACKBONE: 22 | NAME: ClipRN101 23 | WEIGHTS: "" 24 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable' 26 | 27 | PROPOSAL_GENERATOR: 28 | NAME: 'SBRPN' 29 | ROI_HEADS: 30 | NAME: 'ClipRes5ROIHeadsAttn' 31 | NUM_CLASSES: 7 32 | TEST: 33 | EVAL_SAVE_PERIOD: 5000 34 | OUTPUT_DIR: "all_outs/diverse_weather_foggy" 35 | VIS_PERIOD: 5000 36 | 37 | -------------------------------------------------------------------------------- /OD/configs/diverse_weather_night_rainy_test.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ("daytime_clear_train",) 4 | TEST: ('night_rainy_train',) 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | INPUT: 8 | MIN_SIZE_TRAIN: (600,) 9 | MIN_SIZE_TEST: 600 10 | CLIP_RANDOM_CROPS: True 11 | RANDOM_CROP_SIZE: 400 12 | 13 | SOLVER: 14 | BASE_LR: 0.001 15 | MAX_ITER: 100000 16 | STEPS: [40000,] 17 | WARMUP_ITERS: 0 18 | IMS_PER_BATCH: 4 19 | CHECKPOINT_PERIOD: 1000000 20 | MODEL: 21 | BACKBONE: 22 | NAME: ClipRN101 23 | WEIGHTS: "" 24 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable' 26 | 27 | PROPOSAL_GENERATOR: 28 | NAME: 'SBRPN' 29 | ROI_HEADS: 30 | NAME: 'ClipRes5ROIHeadsAttn' 31 | NUM_CLASSES: 7 32 | TEST: 33 | EVAL_SAVE_PERIOD: 5000 34 | OUTPUT_DIR: "all_outs/diverse_weather_night_rainy" 35 | VIS_PERIOD: 5000 36 | 37 | -------------------------------------------------------------------------------- /OD/configs/diverse_weather_night_sunny_test.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ("daytime_clear_train",) 4 | TEST: 
('night_sunny_train',) 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | INPUT: 8 | MIN_SIZE_TRAIN: (600,) 9 | MIN_SIZE_TEST: 600 10 | CLIP_RANDOM_CROPS: True 11 | RANDOM_CROP_SIZE: 400 12 | 13 | SOLVER: 14 | BASE_LR: 0.001 15 | MAX_ITER: 100000 16 | STEPS: [40000,] 17 | WARMUP_ITERS: 0 18 | IMS_PER_BATCH: 4 19 | CHECKPOINT_PERIOD: 1000000 20 | MODEL: 21 | BACKBONE: 22 | NAME: ClipRN101 23 | WEIGHTS: "" 24 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable' 26 | 27 | PROPOSAL_GENERATOR: 28 | NAME: 'SBRPN' 29 | ROI_HEADS: 30 | NAME: 'ClipRes5ROIHeadsAttn' 31 | NUM_CLASSES: 7 32 | TEST: 33 | EVAL_SAVE_PERIOD: 5000 34 | OUTPUT_DIR: "all_outs/diverse_weather_night_sunny" 35 | VIS_PERIOD: 5000 36 | 37 | -------------------------------------------------------------------------------- /OD/CLIP/.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | CLIP-test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.8] 15 | pytorch-version: [1.7.1, 1.9.1, 1.10.1] 16 | include: 17 | - python-version: 3.8 18 | pytorch-version: 1.7.1 19 | torchvision-version: 0.8.2 20 | - python-version: 3.8 21 | pytorch-version: 1.9.1 22 | torchvision-version: 0.10.1 23 | - python-version: 3.8 24 | pytorch-version: 1.10.1 25 | torchvision-version: 0.11.2 26 | steps: 27 | - uses: conda-incubator/setup-miniconda@v2 28 | - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision=${{ matrix.torchvision-version }} cpuonly -c pytorch 29 | - uses: actions/checkout@v2 30 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH 31 | - run: pip install pytest 32 | - run: pip install . 33 | - run: pytest 34 | -------------------------------------------------------------------------------- /OD/CLIP/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /OD/configs/diverse_weather.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ("daytime_clear_train",) 4 | TEST: ('daytime_clear_test',) 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | ASPECT_RATIO_GROUPING: True 8 | INPUT: 9 | MIN_SIZE_TRAIN: (600,) 10 | MIN_SIZE_TEST: 600 11 | CLIP_RANDOM_CROPS: True 12 | RANDOM_CROP_SIZE: 400 13 | 14 | SOLVER: 15 | BASE_LR: 0.001 16 | MAX_ITER: 100000 17 | STEPS: [40000,] 18 | WARMUP_ITERS: 0 19 | IMS_PER_BATCH: 4 20 | CHECKPOINT_PERIOD: 1000000 21 | MODEL: 22 | BACKBONE: 23 | NAME: ClipRN101 24 | FREEZE_AT: 2 25 | UNFREEZE: 26 | - layer3 27 | - layer4 28 | - attnpool 29 | WEIGHTS: "" 30 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 31 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable' 32 | 33 | PROPOSAL_GENERATOR: 34 | NAME: 'SBRPN' 35 | ROI_HEADS: 36 | NAME: 'ClipRes5ROIHeadsAttn' 37 | NUM_CLASSES: 7 38 | TEST: 39 | EVAL_SAVE_PERIOD: 5000 40 | 41 | OUTPUT_DIR: "all_outs/diverse_weather" 42 | VIS_PERIOD: 5000 43 | 44 | -------------------------------------------------------------------------------- /OD/CLIP/hubconf.py: -------------------------------------------------------------------------------- 1 | from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models 2 | import re 3 | import string 4 | 5 | dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"] 6 | 7 | # For compatibility (cannot include special characters in function name) 8 | model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()} 9 | 10 | def _create_hub_entrypoint(model): 11 | def entrypoint(**kwargs): 12 | return _load(model, **kwargs) 13 | 14 | entrypoint.__doc__ = f"""Loads the {model} CLIP model 15 | 16 | Parameters 17 | ---------- 18 | device : Union[str, torch.device] 19 | The device to put the loaded model 20 | 21 | jit : bool 22 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 23 | 24 | download_root: str 25 | path to download the model files; by default, it uses "~/.cache/clip" 26 | 27 | Returns 28 | ------- 29 | model : torch.nn.Module 30 | The {model} CLIP model 31 | 32 | preprocess : Callable[[PIL.Image], torch.Tensor] 33 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 34 | """ 35 | return entrypoint 36 | 37 | def tokenize(): 38 | return _tokenize 39 | 40 | _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()} 41 | 42 | globals().update(_entrypoints) -------------------------------------------------------------------------------- /OD/modeling/custom_pascal_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import tempfile 4 | from collections import OrderedDict, defaultdict 5 | 6 | from detectron2.utils import comm 7 | from detectron2.evaluation import PascalVOCDetectionEvaluator 8 | from detectron2.evaluation.pascal_voc_evaluation import voc_eval 9 | 10 | class CustomPascalVOCDetectionEvaluator(PascalVOCDetectionEvaluator): 11 | def __init__(self,dataset_name): 12 | super().__init__(dataset_name) 13 | 14 | def evaluate(self): 15 | """ 16 | Returns: 17 | dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75". 
18 | """ 19 | all_predictions = comm.gather(self._predictions, dst=0) 20 | if not comm.is_main_process(): 21 | return 22 | predictions = defaultdict(list) 23 | for predictions_per_rank in all_predictions: 24 | for clsid, lines in predictions_per_rank.items(): 25 | predictions[clsid].extend(lines) 26 | del all_predictions 27 | 28 | self._logger.info( 29 | "Evaluating {} using {} metric. " 30 | "Note that results do not use the official Matlab API.".format( 31 | self._dataset_name, 2007 if self._is_2007 else 2012 32 | ) 33 | ) 34 | 35 | 36 | with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname: 37 | res_file_template = os.path.join(dirname, "{}.txt") 38 | 39 | aps = defaultdict(list) # iou -> ap per class 40 | for cls_id, cls_name in enumerate(self._class_names): 41 | lines = predictions.get(cls_id, [""]) 42 | 43 | with open(res_file_template.format(cls_name), "w") as f: 44 | f.write("\n".join(lines)) 45 | 46 | for thresh in range(50, 100, 5): 47 | rec, prec, ap = voc_eval( 48 | res_file_template, 49 | self._anno_file_template, 50 | self._image_set_path, 51 | cls_name, 52 | ovthresh=thresh / 100.0, 53 | use_07_metric=self._is_2007, 54 | ) 55 | aps[thresh].append(ap * 100) 56 | 57 | ret = OrderedDict() 58 | mAP = {iou: np.mean(x) for iou, x in aps.items()} 59 | 60 | clsaps = ','.join(['{:.2f}'.format(a) for a in aps[50]]) 61 | self._logger.info("classwise ap {}".format(clsaps)) 62 | 63 | ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]} 64 | return ret 65 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/custom_pascal_evaluation-checkpoint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import tempfile 4 | from collections import OrderedDict, defaultdict 5 | 6 | from detectron2.utils import comm 7 | from detectron2.evaluation import PascalVOCDetectionEvaluator 8 | from detectron2.evaluation.pascal_voc_evaluation import voc_eval 9 | 10 | class CustomPascalVOCDetectionEvaluator(PascalVOCDetectionEvaluator): 11 | def __init__(self,dataset_name): 12 | super().__init__(dataset_name) 13 | 14 | def evaluate(self): 15 | """ 16 | Returns: 17 | dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75". 18 | """ 19 | all_predictions = comm.gather(self._predictions, dst=0) 20 | if not comm.is_main_process(): 21 | return 22 | predictions = defaultdict(list) 23 | for predictions_per_rank in all_predictions: 24 | for clsid, lines in predictions_per_rank.items(): 25 | predictions[clsid].extend(lines) 26 | del all_predictions 27 | 28 | self._logger.info( 29 | "Evaluating {} using {} metric. 
" 30 | "Note that results do not use the official Matlab API.".format( 31 | self._dataset_name, 2007 if self._is_2007 else 2012 32 | ) 33 | ) 34 | 35 | 36 | with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname: 37 | res_file_template = os.path.join(dirname, "{}.txt") 38 | 39 | aps = defaultdict(list) # iou -> ap per class 40 | for cls_id, cls_name in enumerate(self._class_names): 41 | lines = predictions.get(cls_id, [""]) 42 | 43 | with open(res_file_template.format(cls_name), "w") as f: 44 | f.write("\n".join(lines)) 45 | 46 | for thresh in range(50, 100, 5): 47 | rec, prec, ap = voc_eval( 48 | res_file_template, 49 | self._anno_file_template, 50 | self._image_set_path, 51 | cls_name, 52 | ovthresh=thresh / 100.0, 53 | use_07_metric=self._is_2007, 54 | ) 55 | aps[thresh].append(ap * 100) 56 | 57 | ret = OrderedDict() 58 | mAP = {iou: np.mean(x) for iou, x in aps.items()} 59 | 60 | clsaps = ','.join(['{:.2f}'.format(a) for a in aps[50]]) 61 | self._logger.info("classwise ap {}".format(clsaps)) 62 | 63 | ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]} 64 | return ret 65 | -------------------------------------------------------------------------------- /OD/modeling/regularization.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class spar_loss(nn.Module): 8 | def __init__(self): 9 | super(spar_loss, self).__init__() 10 | 11 | def forward(self, flops_real, flops_mask, flops_ori, batch_size, den_target, lbda): 12 | # total sparsity 13 | flops_tensor, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2] 14 | # block flops 15 | flops_conv = flops_tensor[0:batch_size,:].mean(0).sum() 16 | flops_mask = flops_mask.mean(0).sum() 17 | flops_ori = flops_ori.mean(0).sum() + flops_conv1.mean() + flops_fc.mean() 18 | flops_real = flops_conv + flops_mask + flops_conv1.mean() + flops_fc.mean() 19 | # loss 20 | rloss = lbda * (flops_real / flops_ori - den_target)**2 21 | return rloss 22 | 23 | 24 | class blance_loss(nn.Module): 25 | def __init__(self): 26 | super(blance_loss, self).__init__() 27 | 28 | def forward(self, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size, 29 | den_target, gamma, p): 30 | norm_s = mask_norm_s 31 | norm_s_t = norm_s_t.mean(0) 32 | norm_c = mask_norm_c 33 | norm_c_t = norm_c_t.mean(0) 34 | den_s = norm_s[0:batch_size,:].mean(0) / norm_s_t 35 | den_c = norm_c[0:batch_size,:].mean(0) / norm_c_t 36 | den_tar = math.sqrt(den_target) 37 | bloss_s = get_bloss_basic(den_s, den_tar, batch_size, gamma, p) 38 | bloss_c = get_bloss_basic(den_c, den_tar, batch_size, gamma, p) 39 | bloss = bloss_s + bloss_c 40 | return bloss 41 | 42 | 43 | def get_bloss_basic(spar, spar_tar, batch_size, gamma, p): 44 | # bound 45 | bloss_l = (F.relu(p*spar_tar-spar)**2).mean() 46 | bloss_u = (F.relu(spar-1+p-p*spar_tar)**2).mean() 47 | bloss = gamma * (bloss_l + bloss_u) 48 | return bloss 49 | 50 | 51 | class Loss(nn.Module): 52 | def __init__(self): 53 | super(Loss, self).__init__() 54 | self.task_loss = nn.CrossEntropyLoss() 55 | self.spar_loss = spar_loss() 56 | self.balance_loss = blance_loss() 57 | 58 | def forward(self, output, targets, flops_real, flops_mask, flops_ori, batch_size, 59 | den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, 60 | gamma, p): 61 | closs = self.task_loss(output, targets) 62 | sloss = self.spar_loss(flops_real, flops_mask, flops_ori, batch_size, 
den_target, lbda) 63 | bloss = self.balance_loss(mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size, 64 | den_target, gamma, p) 65 | return closs, sloss, bloss 66 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/regularization-checkpoint.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class spar_loss(nn.Module): 8 | def __init__(self): 9 | super(spar_loss, self).__init__() 10 | 11 | def forward(self, flops_real, flops_mask, flops_ori, batch_size, den_target, lbda): 12 | # total sparsity 13 | flops_tensor, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2] 14 | # block flops 15 | flops_conv = flops_tensor[0:batch_size,:].mean(0).sum() 16 | flops_mask = flops_mask.mean(0).sum() 17 | flops_ori = flops_ori.mean(0).sum() + flops_conv1.mean() + flops_fc.mean() 18 | flops_real = flops_conv + flops_mask + flops_conv1.mean() + flops_fc.mean() 19 | # loss 20 | rloss = lbda * (flops_real / flops_ori - den_target)**2 21 | return rloss 22 | 23 | 24 | class blance_loss(nn.Module): 25 | def __init__(self): 26 | super(blance_loss, self).__init__() 27 | 28 | def forward(self, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size, 29 | den_target, gamma, p): 30 | norm_s = mask_norm_s 31 | norm_s_t = norm_s_t.mean(0) 32 | norm_c = mask_norm_c 33 | norm_c_t = norm_c_t.mean(0) 34 | den_s = norm_s[0:batch_size,:].mean(0) / norm_s_t 35 | den_c = norm_c[0:batch_size,:].mean(0) / norm_c_t 36 | den_tar = math.sqrt(den_target) 37 | bloss_s = get_bloss_basic(den_s, den_tar, batch_size, gamma, p) 38 | bloss_c = get_bloss_basic(den_c, den_tar, batch_size, gamma, p) 39 | bloss = bloss_s + bloss_c 40 | return bloss 41 | 42 | 43 | def get_bloss_basic(spar, spar_tar, batch_size, gamma, p): 44 | # bound 45 | bloss_l = (F.relu(p*spar_tar-spar)**2).mean() 46 | bloss_u = (F.relu(spar-1+p-p*spar_tar)**2).mean() 47 | bloss = gamma * (bloss_l + bloss_u) 48 | return bloss 49 | 50 | 51 | class Loss(nn.Module): 52 | def __init__(self): 53 | super(Loss, self).__init__() 54 | self.task_loss = nn.CrossEntropyLoss() 55 | self.spar_loss = spar_loss() 56 | self.balance_loss = blance_loss() 57 | 58 | def forward(self, output, targets, flops_real, flops_mask, flops_ori, batch_size, 59 | den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, 60 | gamma, p): 61 | closs = self.task_loss(output, targets) 62 | sloss = self.spar_loss(flops_real, flops_mask, flops_ori, batch_size, den_target, lbda) 63 | bloss = self.balance_loss(mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size, 64 | den_target, gamma, p) 65 | return closs, sloss, bloss 66 | -------------------------------------------------------------------------------- /OD/modeling/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision.transforms as T 6 | 7 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec 8 | from ipdb import set_trace as stxx 9 | 10 | @BACKBONE_REGISTRY.register() 11 | class ClipRN101(Backbone): 12 | def __init__(self, cfg, clip_visual): 13 | super().__init__() 14 | self.enc = None 15 | self.unfreeze = cfg.MODEL.BACKBONE.UNFREEZE 16 | self.proj = nn.Linear(512,512) 17 | self.global_proj = nn.Linear(512,512) 18 | self.use_proj = cfg.MODEL.USE_PROJ 19 | 20 | 21 | def 
set_backbone_model(self,model): 22 | self.enc = model 23 | for name,val in self.enc.named_parameters(): 24 | head = name.split('.')[0] 25 | if head not in self.unfreeze: 26 | val.requires_grad = False 27 | else: 28 | val.requires_grad = True 29 | 30 | self.backbone_unchanged = nn.Sequential(*self.enc.layer3[:19]) 31 | 32 | def forward(self, image): 33 | 34 | x = image 35 | 36 | batch_num, _, _, _ = x.shape 37 | 38 | 39 | gate_activations = [] 40 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x))) 41 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x))) 42 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x))) 43 | x = self.enc.avgpool(x) 44 | 45 | norm1 = torch.zeros(1, batch_num+1).to(x.device) 46 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 47 | flops = torch.zeros(1, batch_num+2).to(x.device) 48 | 49 | x = self.enc.layer1(x) 50 | x = self.enc.layer2(x) 51 | x = self.enc.layer3(x) 52 | return {"res4": x} 53 | 54 | 55 | def forward_l12(self, image): 56 | x = image 57 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x))) 58 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x))) 59 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x))) 60 | x = self.enc.avgpool(x) 61 | 62 | x = self.enc.layer1(x) 63 | x = self.enc.layer2(x) 64 | 65 | return x 66 | 67 | def forward_l3(self, x): 68 | x = self.enc.layer3(x) 69 | return {"res4": x} 70 | 71 | 72 | def output_shape(self): 73 | return {"res4": ShapeSpec(channels=1024, stride=16)} 74 | 75 | # def forward_res5(self,x): 76 | # def forward_res5(self, x, norm1, norm2, flops): 77 | # #detectron used last resnet layer for roi heads 78 | # x, norm1, norm2, flops = self.enc.layer4((x, norm1, norm2, flops)) 79 | # # x = self.enc.layer4(x) 80 | # return x 81 | 82 | 83 | 84 | def forward_res5(self, x, txt_emb, norm1, norm2, flops): 85 | #detectron used last resnet layer for roi heads 86 | x, txt_emb, norm1, norm2, flops = self.enc.layer4((x, txt_emb, norm1, norm2, flops)) 87 | return x 88 | 89 | 90 | 91 | def attention_global_pool(self,input): 92 | x = input 93 | x = self.enc.attnpool(x) 94 | return x 95 | 96 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/backbone-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision.transforms as T 6 | 7 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec 8 | from ipdb import set_trace as stxx 9 | 10 | @BACKBONE_REGISTRY.register() 11 | class ClipRN101(Backbone): 12 | def __init__(self, cfg, clip_visual): 13 | super().__init__() 14 | self.enc = None 15 | self.unfreeze = cfg.MODEL.BACKBONE.UNFREEZE 16 | self.proj = nn.Linear(512,512) 17 | self.global_proj = nn.Linear(512,512) 18 | self.use_proj = cfg.MODEL.USE_PROJ 19 | 20 | 21 | def set_backbone_model(self,model): 22 | self.enc = model 23 | for name,val in self.enc.named_parameters(): 24 | head = name.split('.')[0] 25 | if head not in self.unfreeze: 26 | val.requires_grad = False 27 | else: 28 | val.requires_grad = True 29 | 30 | self.backbone_unchanged = nn.Sequential(*self.enc.layer3[:19]) 31 | 32 | def forward(self, image): 33 | 34 | x = image 35 | 36 | batch_num, _, _, _ = x.shape 37 | 38 | 39 | gate_activations = [] 40 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x))) 41 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x))) 42 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x))) 43 | x = self.enc.avgpool(x) 44 | 45 | norm1 = 
torch.zeros(1, batch_num+1).to(x.device) 46 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 47 | flops = torch.zeros(1, batch_num+2).to(x.device) 48 | 49 | x = self.enc.layer1(x) 50 | x = self.enc.layer2(x) 51 | x = self.enc.layer3(x) 52 | return {"res4": x} 53 | 54 | 55 | def forward_l12(self, image): 56 | x = image 57 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x))) 58 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x))) 59 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x))) 60 | x = self.enc.avgpool(x) 61 | 62 | x = self.enc.layer1(x) 63 | x = self.enc.layer2(x) 64 | 65 | return x 66 | 67 | def forward_l3(self, x): 68 | x = self.enc.layer3(x) 69 | return {"res4": x} 70 | 71 | 72 | def output_shape(self): 73 | return {"res4": ShapeSpec(channels=1024, stride=16)} 74 | 75 | # def forward_res5(self,x): 76 | # def forward_res5(self, x, norm1, norm2, flops): 77 | # #detectron used last resnet layer for roi heads 78 | # x, norm1, norm2, flops = self.enc.layer4((x, norm1, norm2, flops)) 79 | # # x = self.enc.layer4(x) 80 | # return x 81 | 82 | 83 | 84 | def forward_res5(self, x, txt_emb, norm1, norm2, flops): 85 | #detectron used last resnet layer for roi heads 86 | x, txt_emb, norm1, norm2, flops = self.enc.layer4((x, txt_emb, norm1, norm2, flops)) 87 | return x 88 | 89 | 90 | 91 | def attention_global_pool(self,input): 92 | x = input 93 | x = self.enc.attnpool(x) 94 | return x 95 | 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt-Driven Dynamic Object-Centric Learning for Single Domain Generalization 2 | This repo is the implementation of Prompt-Driven Dynamic Object-Centric Learning for Single Domain Generalization (CVPR 2024). 3 | 4 | > **Abstract:** *Single-domain generalization aims to learn a model from single source domain data to achieve generalized performance on other unseen target domains. Existing works primarily focus on improving the generalization ability of static networks. However, static networks are unable to dynamically adapt to the diverse variations in different image scenes, leading to limited generalization capability. Different scenes exhibit varying levels of complexity, and the complexity of images further varies significantly in cross-domain scenarios. In this paper, we propose a dynamic object-centric perception network based on prompt learning, aiming to adapt to the variations in image complexity. Specifically, we propose an object-centric gating module based on prompt learning to focus attention on the object-centric features guided by the various scene prompts. Then, with the object-centric gating masks, the dynamic selective module dynamically selects highly correlated feature regions in both spatial and channel dimensions, enabling the model to adaptively perceive object-centric relevant features, thereby enhancing the generalization capability. Extensive experiments were conducted on single-domain generalization tasks in image classification and object detection. The experimental results demonstrate that our approach outperforms state-of-the-art methods, which validates the effectiveness and generality of our proposed method.* 5 | 6 | ## 1. Illustration of dynamic object-centric learning via prompts for single domain generalization. 7 | 

9 | 10 |

11 | 12 | Object-centric features capture the essential information related to individual objects. 13 | Incorporating the given scene prompts to dynamically optimize the extraction of object-centric features is beneficial for improving the generalization performance of models. 14 | 15 | ## 2. Method 16 | 17 |

18 | 19 |

20 | 21 | The proposed prompt-based dynamic object-centric learning framework. It mainly includes a prompt-based object-centric gating module and a dynamic selective module. First, the Slot Attention multimodal fusion module extracts object-centric features and leverages the various scene prompts to guide the object-centric gating mask learning for the input from different scenes. Next, the gating mask is used to dynamically select the relevant object-centric features to improve the 22 | generalization ability. 23 | 24 | ## 3. Usage 25 | ### 3.1 Prepare data 26 | #### Image Classification: PACS (Art paintings, Cartoons, Photos, and Sketches) 27 | #### Object Detection: Diverse-Weather Dataset (Daytime-Sunny, Night-Sunny, Dusk-Rainy, Night-Rainy, and Daytime-Foggy) 28 | 29 | 30 | 31 | ### 3.2 Dependencies 32 | 33 | Python: 3.8.10 34 | PyTorch: 1.9.1 35 | Pillow: 9.5.0 36 | Torchvision: 0.8.2 37 | CUDA: 11.8 38 | NumPy: 1.22.4 39 | PIL: 7.2.0 40 | clip: 1.0 41 | detectron2: 0.6 42 | 43 | ### 3.3 Train and Test 44 | 45 | - Train on source domain 46 | ``` 47 | CUDA_VISIBLE_DEVICES=0 python train.py --config-file configs/diverse_weather.yaml 48 | ``` 49 | 50 | - Test on target domain (Daytime-Foggy) 51 | 52 | ``` 53 | python train.py --config-file configs/diverse_weather_foggy_test.yaml --eval-only MODEL.WEIGHTS all_outs/diverse_weather/model_best.pth > diverse_weather_foggy_test.log 54 | ``` 55 | 56 | ## Acknowledgement 57 | Our code is based on the project [Detectron2](https://github.com/facebookresearch/detectron2). 58 | -------------------------------------------------------------------------------- /OD/data/datasets/pascal_voc_adaptation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import numpy as np 5 | import os 6 | import xml.etree.ElementTree as ET 7 | from typing import List, Tuple, Union 8 | 9 | from detectron2.data import DatasetCatalog, MetadataCatalog 10 | from detectron2.structures import BoxMode 11 | from detectron2.utils.file_io import PathManager 12 | 13 | __all__ = ["load_voc_instances", "register_all_pascal_voc"] 14 | 15 | 16 | # fmt: off 17 | CLASS_NAMES = ( 18 | "bicycle", "bird", "car", "cat", "dog", "person", 19 | ) 20 | # fmt: on 21 | 22 | 23 | def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 24 | """ 25 | Load Pascal VOC detection annotations to Detectron2 format. 26 | Args: 27 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 28 | split (str): one of "train", "test", "val", "trainval" 29 | class_names: list or tuple of class names 30 | """ 31 | with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f: 32 | fileids = np.loadtxt(f, dtype=np.str) 33 | 34 | # Needs to read many small annotation files. 
Makes sense at local 35 | annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/")) 36 | dicts = [] 37 | for fileid in fileids: 38 | anno_file = os.path.join(annotation_dirname, fileid + ".xml") 39 | jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg") 40 | 41 | with PathManager.open(anno_file) as f: 42 | tree = ET.parse(f) 43 | 44 | r = { 45 | "file_name": jpeg_file, 46 | "image_id": fileid, 47 | "height": int(tree.findall("./size/height")[0].text), 48 | "width": int(tree.findall("./size/width")[0].text), 49 | } 50 | instances = [] 51 | 52 | for obj in tree.findall("object"): 53 | cls = obj.find("name").text 54 | if cls not in CLASS_NAMES: 55 | continue 56 | # We include "difficult" samples in training. 57 | # Based on limited experiments, they don't hurt accuracy. 58 | # difficult = int(obj.find("difficult").text) 59 | # if difficult == 1: 60 | # continue 61 | bbox = obj.find("bndbox") 62 | bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]] 63 | # Original annotations are integers in the range [1, W or H] 64 | # Assuming they mean 1-based pixel indices (inclusive), 65 | # a box with annotation (xmin=1, xmax=W) covers the whole image. 66 | # In coordinate space this is represented by (xmin=0, xmax=W) 67 | bbox[0] -= 1.0 68 | bbox[1] -= 1.0 69 | instances.append( 70 | {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS} 71 | ) 72 | r["annotations"] = instances 73 | dicts.append(r) 74 | return dicts 75 | 76 | 77 | def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES): 78 | DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names)) 79 | MetadataCatalog.get(name).set( 80 | thing_classes=list(class_names), dirname=dirname, year=year, split=split 81 | ) 82 | 83 | def register_all_pascal_voc(root): 84 | SPLITS = [ 85 | ("voc_adapt_2007_trainval", "VOC2007", "trainval"), 86 | ("voc_adapt_2007_train", "VOC2007", "train"), 87 | ("voc_adapt_2007_val", "VOC2007", "val"), 88 | ("voc_adapt_2007_test", "VOC2007", "test"), 89 | ("voc_adapt_2012_trainval", "VOC2012", "trainval"), 90 | ("voc_adapt_2012_train", "VOC2012", "train"), 91 | ("voc_adapt_2012_val", "VOC2012", "val"), 92 | ] 93 | for name, dirname, split in SPLITS: 94 | year = 2007 if "2007" in name else 2012 95 | register_pascal_voc(name, os.path.join(root, dirname), split, year) 96 | MetadataCatalog.get(name).evaluator_type = "pascal_voc" -------------------------------------------------------------------------------- /OD/modeling/rpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY, RPN_HEAD_REGISTRY 6 | from detectron2.modeling.proposal_generator.rpn import RPN, StandardRPNHead 7 | from detectron2.structures import ImageList 8 | from typing import List 9 | import time 10 | 11 | 12 | @PROPOSAL_GENERATOR_REGISTRY.register() 13 | class SBRPN(RPN): 14 | 15 | def forward( 16 | self, 17 | images, 18 | features, 19 | gt_instances= None, 20 | ): 21 | """ 22 | Args: 23 | images (ImageList): input images of length `N` 24 | features (dict[str, Tensor]): input data as a mapping from feature 25 | map name to tensor. Axis 0 represents the number of images `N` in 26 | the input data; axes 1-3 are channels, height, and width, which may 27 | vary between feature maps (e.g., if a feature pyramid is used). 
28 | gt_instances (list[Instances], optional): a length `N` list of `Instances`s. 29 | Each `Instances` stores ground-truth instances for the corresponding image. 30 | Returns: 31 | proposals: list[Instances]: contains fields "proposal_boxes", "objectness_logits" 32 | loss: dict[Tensor] or None 33 | """ 34 | features = [features[f] for f in self.in_features] 35 | anchors = self.anchor_generator(features) 36 | val = self.rpn_head(features) 37 | 38 | if len(val) == 2: 39 | pred_objectness_logits, pred_anchor_deltas = val 40 | rep_clip = None 41 | else: 42 | pred_objectness_logits, pred_anchor_deltas, rep_clip = val 43 | 44 | # Transpose the Hi*Wi*A dimension to the middle: 45 | pred_objectness_logits= [ 46 | # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A) 47 | score.permute(0, 2, 3, 1).flatten(1) 48 | for score in pred_objectness_logits 49 | ] 50 | 51 | 52 | pred_anchor_deltas = [ 53 | # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B) 54 | x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1]) 55 | .permute(0, 3, 4, 1, 2) 56 | .flatten(1, -2) 57 | for x in pred_anchor_deltas 58 | ] 59 | 60 | if self.training: 61 | #assert gt_instances is not None, "RPN requires gt_instances in training!" 62 | if gt_instances is not None: 63 | gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances) 64 | losses = self.losses( 65 | anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes 66 | ) 67 | if rep_clip is not None : 68 | if not isinstance(rep_clip, dict): 69 | ll = torch.stack(gt_labels) 70 | ll = ll.reshape(ll.shape[0],-1,15) 71 | valid_mask = ll.ge(0).sum(dim=-1).gt(0) # remove ignored anchors 72 | ll=ll.eq(1).sum(dim=-1) # if an object is present at this location 73 | ll = ll.gt(0).float() 74 | 75 | clip_loss = torch.nn.functional.binary_cross_entropy_with_logits(rep_clip[valid_mask],ll[valid_mask],reduction='sum') 76 | losses.update({'loss_rpn_cls_clip':clip_loss/(self.batch_size_per_image*ll.shape[0])}) 77 | else: 78 | losses.update(rep_clip) 79 | else: 80 | losses = {} 81 | else: 82 | losses = {} 83 | proposals = self.predict_proposals( 84 | anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes 85 | ) 86 | out = [ 87 | # (N, Hi*Wi*A) -> (N, Hi, Wi, A) 88 | score.reshape(features[ind].shape[0],features[ind].shape[-2],features[ind].shape[-1],-1) 89 | for ind, score in enumerate(pred_objectness_logits) 90 | ] 91 | return out, proposals, losses 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/rpn-checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY, RPN_HEAD_REGISTRY 6 | from detectron2.modeling.proposal_generator.rpn import RPN, StandardRPNHead 7 | from detectron2.structures import ImageList 8 | from typing import List 9 | import time 10 | 11 | 12 | @PROPOSAL_GENERATOR_REGISTRY.register() 13 | class SBRPN(RPN): 14 | 15 | def forward( 16 | self, 17 | images, 18 | features, 19 | gt_instances= None, 20 | ): 21 | """ 22 | Args: 23 | images (ImageList): input images of length `N` 24 | features (dict[str, Tensor]): input data as a mapping from feature 25 | map name to tensor. 
Axis 0 represents the number of images `N` in 26 | the input data; axes 1-3 are channels, height, and width, which may 27 | vary between feature maps (e.g., if a feature pyramid is used). 28 | gt_instances (list[Instances], optional): a length `N` list of `Instances`s. 29 | Each `Instances` stores ground-truth instances for the corresponding image. 30 | Returns: 31 | proposals: list[Instances]: contains fields "proposal_boxes", "objectness_logits" 32 | loss: dict[Tensor] or None 33 | """ 34 | features = [features[f] for f in self.in_features] 35 | anchors = self.anchor_generator(features) 36 | val = self.rpn_head(features) 37 | 38 | if len(val) == 2: 39 | pred_objectness_logits, pred_anchor_deltas = val 40 | rep_clip = None 41 | else: 42 | pred_objectness_logits, pred_anchor_deltas, rep_clip = val 43 | 44 | # Transpose the Hi*Wi*A dimension to the middle: 45 | pred_objectness_logits= [ 46 | # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A) 47 | score.permute(0, 2, 3, 1).flatten(1) 48 | for score in pred_objectness_logits 49 | ] 50 | 51 | 52 | pred_anchor_deltas = [ 53 | # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B) 54 | x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1]) 55 | .permute(0, 3, 4, 1, 2) 56 | .flatten(1, -2) 57 | for x in pred_anchor_deltas 58 | ] 59 | 60 | if self.training: 61 | #assert gt_instances is not None, "RPN requires gt_instances in training!" 62 | if gt_instances is not None: 63 | gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances) 64 | losses = self.losses( 65 | anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes 66 | ) 67 | if rep_clip is not None : 68 | if not isinstance(rep_clip, dict): 69 | ll = torch.stack(gt_labels) 70 | ll = ll.reshape(ll.shape[0],-1,15) 71 | valid_mask = ll.ge(0).sum(dim=-1).gt(0) # remove ignored anchors 72 | ll=ll.eq(1).sum(dim=-1) # if an object is present at this location 73 | ll = ll.gt(0).float() 74 | 75 | clip_loss = torch.nn.functional.binary_cross_entropy_with_logits(rep_clip[valid_mask],ll[valid_mask],reduction='sum') 76 | losses.update({'loss_rpn_cls_clip':clip_loss/(self.batch_size_per_image*ll.shape[0])}) 77 | else: 78 | losses.update(rep_clip) 79 | else: 80 | losses = {} 81 | else: 82 | losses = {} 83 | proposals = self.predict_proposals( 84 | anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes 85 | ) 86 | out = [ 87 | # (N, Hi*Wi*A) -> (N, Hi, Wi, A) 88 | score.reshape(features[ind].shape[0],features[ind].shape[-2],features[ind].shape[-1],-1) 89 | for ind, score in enumerate(pred_objectness_logits) 90 | ] 91 | return out, proposals, losses 92 | 93 | 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /OD/modeling/clip.py: -------------------------------------------------------------------------------- 1 | import clip 2 | 3 | import torch 4 | import torch.nn as nn 5 | import time 6 | import numpy as np 7 | import copy 8 | 9 | class SlotAttention(nn.Module): 10 | def __init__(self, dim=768, iters=3, eps=1e-8, hidden_dim=512, drop_rate=0.4, feature_size=512): 11 | super().__init__() 12 | self.iters = iters 13 | self.eps = eps 14 | self.scale = dim ** -0.5 15 | self.feature_size = feature_size 16 | 17 | self.to_q = nn.Linear(dim, dim) 18 | slot_share_qk = False 19 | if slot_share_qk: 20 | self.to_k = self.to_q 21 | else: 22 | self.to_k = nn.Linear(dim, dim) 23 | 24 | self.to_v = nn.Linear(feature_size, feature_size) 25 | 26 | hidden_dim = max(dim, 
hidden_dim, feature_size) 27 | 28 | self.gru = nn.GRUCell(feature_size, feature_size) 29 | self.mlp = nn.Sequential( 30 | nn.Linear(feature_size, hidden_dim), 31 | nn.ReLU(inplace=True), 32 | nn.Linear(hidden_dim, feature_size) 33 | ) 34 | 35 | self.norm_slots = nn.LayerNorm(feature_size) 36 | self.norm_pre_ff = nn.LayerNorm(feature_size) 37 | self.norm_input = nn.LayerNorm(feature_size) 38 | 39 | self.slot_dropout = nn.Dropout(drop_rate) 40 | self.input_dropout = nn.Dropout(drop_rate) 41 | 42 | def forward(self,cand_feat, pano_feat): 43 | 44 | b, d, device = *pano_feat.shape, pano_feat.device 45 | # original cand_feat as the initial slot 46 | slots = cand_feat.clone() 47 | slots = self.slot_dropout(slots) 48 | pano_feat = self.norm_input(pano_feat.clone()) 49 | pano_feat = self.input_dropout(pano_feat) 50 | # (bs, num_ctx, hidden_size) 51 | k = self.to_k(slots) 52 | v = self.to_v(slots) 53 | attn_weights = [] 54 | for t in range(self.iters): 55 | slots_prev = slots 56 | slots = self.norm_slots(slots.clone()) 57 | # (bs, num_slots, hidden_size) 58 | q = self.to_q(pano_feat.clone()) 59 | # (bs, num_slots, num_ctx) 60 | dots = torch.einsum('id,jd->ijd', k, q) * self.scale 61 | 62 | attn = dots.softmax(dim=1) 63 | attn_weights.append(attn) # for visualization 64 | # (bs, num_slots, feature_size) 65 | updates = torch.einsum('id,ijd->id', v, attn) 66 | gru_updates = self.gru( 67 | updates.reshape(-1, self.feature_size), 68 | slots_prev.clone().reshape(-1, self.feature_size) 69 | ) 70 | gru_updates = gru_updates + self.mlp(self.norm_pre_ff(gru_updates)) 71 | slots = gru_updates.clone() 72 | return slots 73 | 74 | 75 | class ClipPredictor(nn.Module): 76 | def __init__(self, clip_enocder_name,inshape, device, clsnames): 77 | super().__init__() 78 | self.model, self.preprocess = clip.load(clip_enocder_name, device) 79 | self.model.float() 80 | #freeze everything 81 | for name, val in self.model.named_parameters(): 82 | val.requires_grad = False 83 | # this is only used for inference 84 | self.frozen_clip_model = copy.deepcopy(self.model) 85 | 86 | self.visual_enc = self.model.visual 87 | prompt = 'a photo of a {}' 88 | print(clsnames) 89 | with torch.no_grad(): 90 | text_inputs = torch.cat([clip.tokenize(prompt.format(cls)) for cls in clsnames]).to(device) 91 | self.text_features = self.model.encode_text(text_inputs).float() 92 | self.text_features /= self.text_features.norm(dim=-1, keepdim=True) 93 | 94 | 95 | self.projection = nn.Linear(inshape,512) 96 | self.projection_global = nn.Linear(inshape,512) 97 | 98 | self.slot_attention = SlotAttention( 99 | dim=512, 100 | iters=3, 101 | drop_rate=0, 102 | ) 103 | 104 | def forward(self, feat, gfeat=None): 105 | 106 | if feat.shape[-1] > 512: 107 | feat = self.projection(feat) 108 | feat = 0.5* feat + 0.5* self.slot_attention(feat,self.text_features.detach()) 109 | feat = feat/feat.norm(dim=-1,keepdim=True) 110 | if gfeat is not None: 111 | 112 | feat = feat-gfeat 113 | feat = feat/feat.norm(dim=-1,keepdim=True) 114 | scores = (100.0 * torch.matmul(feat,self.text_features.detach().T)) 115 | 116 | # print(scores.min(),scores.max()) 117 | # add for bkg class a score 0 118 | scores = torch.cat([scores,torch.zeros(scores.shape[0],1,device=scores.device)],1) 119 | return scores 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/clip-checkpoint.py: 
-------------------------------------------------------------------------------- 1 | import clip 2 | 3 | import torch 4 | import torch.nn as nn 5 | import time 6 | import numpy as np 7 | import copy 8 | 9 | class SlotAttention(nn.Module): 10 | def __init__(self, dim=768, iters=3, eps=1e-8, hidden_dim=512, drop_rate=0.4, feature_size=512): 11 | super().__init__() 12 | self.iters = iters 13 | self.eps = eps 14 | self.scale = dim ** -0.5 15 | self.feature_size = feature_size 16 | 17 | self.to_q = nn.Linear(dim, dim) 18 | slot_share_qk = False 19 | if slot_share_qk: 20 | self.to_k = self.to_q 21 | else: 22 | self.to_k = nn.Linear(dim, dim) 23 | 24 | self.to_v = nn.Linear(feature_size, feature_size) 25 | 26 | hidden_dim = max(dim, hidden_dim, feature_size) 27 | 28 | self.gru = nn.GRUCell(feature_size, feature_size) 29 | self.mlp = nn.Sequential( 30 | nn.Linear(feature_size, hidden_dim), 31 | nn.ReLU(inplace=True), 32 | nn.Linear(hidden_dim, feature_size) 33 | ) 34 | 35 | self.norm_slots = nn.LayerNorm(feature_size) 36 | self.norm_pre_ff = nn.LayerNorm(feature_size) 37 | self.norm_input = nn.LayerNorm(feature_size) 38 | 39 | self.slot_dropout = nn.Dropout(drop_rate) 40 | self.input_dropout = nn.Dropout(drop_rate) 41 | 42 | def forward(self,cand_feat, pano_feat): 43 | 44 | b, d, device = *pano_feat.shape, pano_feat.device 45 | # original cand_feat as the initial slot 46 | slots = cand_feat.clone() 47 | slots = self.slot_dropout(slots) 48 | pano_feat = self.norm_input(pano_feat.clone()) 49 | pano_feat = self.input_dropout(pano_feat) 50 | # (bs, num_ctx, hidden_size) 51 | k = self.to_k(slots) 52 | v = self.to_v(slots) 53 | attn_weights = [] 54 | for t in range(self.iters): 55 | slots_prev = slots 56 | slots = self.norm_slots(slots.clone()) 57 | # (bs, num_slots, hidden_size) 58 | q = self.to_q(pano_feat.clone()) 59 | # (bs, num_slots, num_ctx) 60 | dots = torch.einsum('id,jd->ijd', k, q) * self.scale 61 | 62 | attn = dots.softmax(dim=1) 63 | attn_weights.append(attn) # for visualization 64 | # (bs, num_slots, feature_size) 65 | updates = torch.einsum('id,ijd->id', v, attn) 66 | gru_updates = self.gru( 67 | updates.reshape(-1, self.feature_size), 68 | slots_prev.clone().reshape(-1, self.feature_size) 69 | ) 70 | gru_updates = gru_updates + self.mlp(self.norm_pre_ff(gru_updates)) 71 | slots = gru_updates.clone() 72 | return slots 73 | 74 | 75 | class ClipPredictor(nn.Module): 76 | def __init__(self, clip_enocder_name,inshape, device, clsnames): 77 | super().__init__() 78 | self.model, self.preprocess = clip.load(clip_enocder_name, device) 79 | self.model.float() 80 | #freeze everything 81 | for name, val in self.model.named_parameters(): 82 | val.requires_grad = False 83 | # this is only used for inference 84 | self.frozen_clip_model = copy.deepcopy(self.model) 85 | 86 | self.visual_enc = self.model.visual 87 | prompt = 'a photo of a {}' 88 | print(clsnames) 89 | with torch.no_grad(): 90 | text_inputs = torch.cat([clip.tokenize(prompt.format(cls)) for cls in clsnames]).to(device) 91 | self.text_features = self.model.encode_text(text_inputs).float() 92 | self.text_features /= self.text_features.norm(dim=-1, keepdim=True) 93 | 94 | 95 | self.projection = nn.Linear(inshape,512) 96 | self.projection_global = nn.Linear(inshape,512) 97 | 98 | self.slot_attention = SlotAttention( 99 | dim=512, 100 | iters=3, 101 | drop_rate=0, 102 | ) 103 | 104 | def forward(self, feat, gfeat=None): 105 | 106 | if feat.shape[-1] > 512: 107 | feat = self.projection(feat) 108 | feat = 0.5* feat + 0.5* 
self.slot_attention(feat,self.text_features.detach()) 109 | feat = feat/feat.norm(dim=-1,keepdim=True) 110 | if gfeat is not None: 111 | 112 | feat = feat-gfeat 113 | feat = feat/feat.norm(dim=-1,keepdim=True) 114 | scores = (100.0 * torch.matmul(feat,self.text_features.detach().T)) 115 | 116 | # print(scores.min(),scores.max()) 117 | # add for bkg class a score 0 118 | scores = torch.cat([scores,torch.zeros(scores.shape[0],1,device=scores.device)],1) 119 | return scores 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /OD/data/datasets/diverse_weather.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | 4 | from tqdm import tqdm 5 | import pickle as pkl 6 | import xml.etree.ElementTree as ET 7 | 8 | import cv2 9 | import numpy as np 10 | from pymage_size import get_image_size 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | 15 | all_class_name = ['bus' ,'bike', 'car', 'motor', 'person', 'rider' ,'truck'] 16 | 17 | def get_annotation(root, image_id, ind): 18 | annotation_file = os.path.join(root,'VOC2007', "Annotations", "%s.xml" % image_id) 19 | 20 | et = ET.parse(annotation_file) 21 | 22 | 23 | objects = et.findall("object") 24 | 25 | record = {} 26 | record["file_name"] = os.path.join(root, 'VOC2007', "JPEGImages", "%s.jpg" % image_id) 27 | img_format = get_image_size(record["file_name"]) 28 | w, h = img_format.get_dimensions() 29 | 30 | record["image_id"] = image_id#ind for pascal evaluation actual image name is needed 31 | record["annotations"] = [] 32 | 33 | for obj in objects: 34 | class_name = obj.find('name').text.lower().strip() 35 | if class_name not in all_class_name: 36 | print(class_name) 37 | continue 38 | if obj.find('pose') is None: 39 | obj.append(ET.Element('pose')) 40 | obj.find('pose').text = '0' 41 | 42 | if obj.find('truncated') is None: 43 | obj.append(ET.Element('truncated')) 44 | obj.find('truncated').text = '0' 45 | 46 | if obj.find('difficult') is None: 47 | obj.append(ET.Element('difficult')) 48 | obj.find('difficult').text = '0' 49 | 50 | bbox = obj.find('bndbox') 51 | # VOC dataset format follows Matlab, in which indexes start from 0 52 | x1 = max(0,float(bbox.find('xmin').text) - 1) # fixing when -1 in anno 53 | y1 = max(0,float(bbox.find('ymin').text) - 1) # fixing when -1 in anno 54 | x2 = float(bbox.find('xmax').text) - 1 55 | y2 = float(bbox.find('ymax').text) - 1 56 | box = [x1, y1, x2, y2] 57 | 58 | #pascal voc evaluator requires int 59 | bbox.find('xmin').text = str(int(x1)) 60 | bbox.find('ymin').text = str(int(y1)) 61 | bbox.find('xmax').text = str(int(x2)) 62 | bbox.find('ymax').text = str(int(y2)) 63 | 64 | 65 | record_obj = { 66 | "bbox": box, 67 | "bbox_mode": BoxMode.XYXY_ABS, 68 | "category_id": all_class_name.index(class_name), 69 | } 70 | record["annotations"].append(record_obj) 71 | 72 | if len(record["annotations"]): 73 | #to convert float to int 74 | et.write(annotation_file) 75 | record["height"] = h 76 | record["width"] = w 77 | return record 78 | 79 | else: 80 | return None 81 | 82 | def files2dict(root,split): 83 | 84 | cache_dir = os.path.join(root, 'cache') 85 | 86 | pkl_filename = os.path.basename(root)+f'_{split}.pkl' 87 | pkl_path = os.path.join(cache_dir,pkl_filename) 88 | 89 | if os.path.exists(pkl_path): 90 | with open(pkl_path,'rb') as f: 91 | return pkl.load(f) 92 | else: 93 | try: 94 
| os.makedirs(cache_dir) 95 | except OSError as e: 96 | if e.errno != errno.EEXIST: 97 | print(e) 98 | pass 99 | 100 | dataset_dicts = [] 101 | image_sets_file = os.path.join( root,'VOC2007', "ImageSets", "Main", "%s.txt" % split) 102 | 103 | with open(image_sets_file) as f: 104 | count = 0 105 | 106 | for line in tqdm(f): 107 | record = get_annotation(root,line.rstrip(),count) 108 | 109 | if record is not None: 110 | dataset_dicts.append(record) 111 | count +=1 112 | 113 | with open(pkl_path, 'wb') as f: 114 | pkl.dump(dataset_dicts,f) 115 | return dataset_dicts 116 | 117 | 118 | def register_dataset(datasets_root): 119 | datasets_root = os.path.join(datasets_root,'diverseWeather') 120 | dataset_list = ['daytime_clear', 121 | 'daytime_foggy', 122 | 'night_sunny', 123 | 'night_rainy', 124 | 'dusk_rainy', 125 | ] 126 | settype = ['train','test'] 127 | 128 | for name in dataset_list: 129 | for ind, d in enumerate(settype): 130 | 131 | DatasetCatalog.register(name+"_" + d, lambda datasets_root=datasets_root,name=name,d=d \ 132 | : files2dict(os.path.join(datasets_root,name), d)) 133 | MetadataCatalog.get(name+ "_" + d).set(thing_classes=all_class_name,evaluator_type='pascal_voc') 134 | MetadataCatalog.get(name+ "_" + d).set(dirname=datasets_root+f'/{name}/VOC2007') 135 | MetadataCatalog.get(name+ "_" + d).set(split=d) 136 | MetadataCatalog.get(name+ "_" + d).set(year=2007) -------------------------------------------------------------------------------- /OD/CLIP/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /OD/CLIP/clip/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import ipdb 7 | 8 | class SlotAttention(nn.Module): 9 | def __init__(self, dim=768, iters=3, eps=1e-8, hidden_dim=512, drop_rate=0.4, feature_size=512): 10 | super().__init__() 11 | self.iters = iters 12 | self.eps = eps 13 | 
self.scale = dim ** -0.5 14 | self.feature_size = feature_size 15 | 16 | self.to_q = nn.Linear(dim, dim) 17 | slot_share_qk = False 18 | if slot_share_qk: 19 | self.to_k = self.to_q 20 | else: 21 | self.to_k = nn.Linear(dim, dim) 22 | 23 | self.to_v = nn.Linear(feature_size, feature_size) 24 | 25 | hidden_dim = max(dim, hidden_dim, feature_size) 26 | 27 | self.gru = nn.GRUCell(feature_size, feature_size) 28 | self.mlp = nn.Sequential( 29 | nn.Linear(feature_size, hidden_dim), 30 | nn.ReLU(inplace=True), 31 | nn.Linear(hidden_dim, feature_size) 32 | ) 33 | 34 | self.norm_slots = nn.LayerNorm(feature_size) 35 | self.norm_pre_ff = nn.LayerNorm(feature_size) 36 | self.norm_input = nn.LayerNorm(feature_size) 37 | 38 | self.slot_dropout = nn.Dropout(drop_rate) 39 | self.input_dropout = nn.Dropout(drop_rate) 40 | 41 | def forward(self,cand_feat, pano_feat): 42 | b, d = pano_feat.shape 43 | # original cand_feat as the initial slot 44 | slots = cand_feat.clone() 45 | slots = self.slot_dropout(slots) 46 | 47 | pano_feat = self.norm_input(pano_feat.clone()) 48 | pano_feat = self.input_dropout(pano_feat) 49 | 50 | # (bs, num_ctx, hidden_size) 51 | k = self.to_k(pano_feat) 52 | v = self.to_v(pano_feat) 53 | attn_weights = [] 54 | 55 | for t in range(self.iters): 56 | slots_prev = slots 57 | 58 | slots = self.norm_slots(slots.clone()) 59 | 60 | # (bs, num_slots, hidden_size) 61 | q = self.to_q(slots.clone()) 62 | 63 | # (bs, num_slots, num_ctx) 64 | dots = torch.einsum('id,jd->ij', q, k) * self.scale 65 | 66 | attn = dots.softmax(dim=1) 67 | 68 | attn_weights.append(attn) # for visualization 69 | 70 | # (bs, num_slots, feature_size) 71 | updates = torch.einsum('jd,ij->id', v, attn) 72 | 73 | gru_updates = self.gru( 74 | updates.reshape(-1, self.feature_size), 75 | slots_prev.clone().reshape(-1, self.feature_size) 76 | ) 77 | gru_updates = gru_updates + self.mlp(self.norm_pre_ff(gru_updates)) 78 | 79 | slots = gru_updates.clone() 80 | 81 | return slots # , np.stack([a.cpu().detach().numpy() for a in attn_weights], 0) 82 | 83 | 84 | class GumbelSoftmax(nn.Module): 85 | ''' 86 | gumbel softmax gate. 87 | ''' 88 | def __init__(self, eps=1): 89 | super(GumbelSoftmax, self).__init__() 90 | self.eps = eps 91 | self.sigmoid = nn.Sigmoid() 92 | 93 | def gumbel_sample(self, template_tensor, eps=1e-8): 94 | uniform_samples_tensor = template_tensor.clone().uniform_() 95 | gumble_samples_tensor = torch.log(uniform_samples_tensor+eps)-torch.log( 96 | 1-uniform_samples_tensor+eps) 97 | return gumble_samples_tensor 98 | 99 | def gumbel_softmax(self, logits): 100 | """ Draw a sample from the Gumbel-Softmax distribution""" 101 | gsamples = self.gumbel_sample(logits.data) 102 | logits = logits + Variable(gsamples) 103 | soft_samples = self.sigmoid(logits / self.eps) 104 | return soft_samples, logits 105 | 106 | def forward(self, logits): 107 | if not self.training: 108 | out_hard = (logits>=0).float() 109 | return out_hard 110 | out_soft, prob_soft = self.gumbel_softmax(logits) 111 | out_hard = ((out_soft >= 0.5).float() - out_soft).detach() + out_soft 112 | return out_hard 113 | 114 | 115 | 116 | class Mask_s(nn.Module): 117 | ''' 118 | Attention Mask spatial. 
119 | ''' 120 | def __init__(self, h, w, planes, block_w, block_h, eps=0.66667, 121 | bias=-1, **kwargs): 122 | super(Mask_s, self).__init__() 123 | # Parameter 124 | self.width, self.height, self.channel = w, h, planes 125 | self.mask_h, self.mask_w = int(np.ceil(h / block_h)), int(np.ceil(w / block_w)) 126 | self.eleNum_s = torch.Tensor([self.mask_h*self.mask_w]) 127 | # spatial attention 128 | self.atten_s = nn.Conv2d(planes, 1, kernel_size=3, stride=1, bias=bias>=0, padding=1) 129 | if bias>=0: 130 | nn.init.constant_(self.atten_s.bias, bias) 131 | # Gate 132 | self.gate_s = GumbelSoftmax(eps=eps) 133 | # Norm 134 | self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3)) 135 | 136 | def forward(self, x): 137 | 138 | batch, channel, height, width = x.size() # torch.Size([256, 64, 56, 56]) 139 | # Pooling 140 | input_ds = F.adaptive_avg_pool2d(input=x, output_size=(self.mask_h, self.mask_w)) # torch.Size([256, 64, 7, 7]) 141 | # spatial attention 142 | s_in = self.atten_s(input_ds) # [N, 1, h, w] 143 | 144 | # spatial gate 145 | mask_s = self.gate_s(s_in) # [N, 1, h, w] 146 | # norm 147 | norm = self.norm(mask_s) 148 | norm_t = self.eleNum_s.to(x.device) 149 | return mask_s, norm, norm_t 150 | 151 | def get_flops(self): 152 | flops = self.mask_h * self.mask_w * self.channel * 9 153 | return flops 154 | 155 | 156 | class Mask_c(nn.Module): 157 | ''' 158 | Attention Mask. 159 | ''' 160 | def __init__(self, inplanes, outplanes, fc_reduction=4, eps=0.66667, bias=-1, **kwargs): 161 | super(Mask_c, self).__init__() 162 | # Parameter 163 | self.bottleneck = 512 # inplanes // fc_reduction 164 | self.inplanes, self.outplanes = inplanes, outplanes 165 | self.eleNum_c = torch.Tensor([outplanes]) 166 | # channel attention 167 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 168 | self.atten_c_fc1 = nn.Conv2d(inplanes, 512, kernel_size=1) 169 | self.slot_attention = SlotAttention( 170 | dim=512, 171 | iters=3, 172 | drop_rate=0, 173 | ) 174 | 175 | self.atten_c_bn = nn.BatchNorm2d(self.bottleneck) 176 | self.atten_c_act = nn.ReLU(inplace=True) 177 | self.atten_c_conv = nn.Conv2d(self.bottleneck, outplanes, kernel_size=1, stride=1, bias=bias>=0) 178 | 179 | if bias>=0: 180 | nn.init.constant_(self.atten_c_conv.bias, bias) 181 | # Gate 182 | self.gate_c = GumbelSoftmax(eps=eps) 183 | # Norm 184 | self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3)) 185 | 186 | def forward(self, x, txt_emb): 187 | batch, channel, _, _ = x.size() 188 | context = self.avg_pool(x) # [N, C, 1, 1] 189 | context = self.atten_c_fc1(context) 190 | # transform 191 | c_in = context+self.slot_attention(context.squeeze(-1).squeeze(-1),txt_emb.detach()).unsqueeze(-1).unsqueeze(-1) 192 | c_in = self.atten_c_bn(c_in) 193 | c_in = self.atten_c_act(c_in) 194 | c_in = self.atten_c_conv(c_in) 195 | 196 | # channel gate 197 | mask_c = self.gate_c(c_in) # [N, C_out, 1, 1] 198 | # norm 199 | norm = self.norm(mask_c) 200 | norm_t = self.eleNum_c.to(x.device) 201 | return mask_c, norm, norm_t 202 | 203 | def get_flops(self): 204 | flops = self.inplanes * self.bottleneck + self.bottleneck * self.outplanes 205 | return flops 206 | -------------------------------------------------------------------------------- /OD/CLIP/model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: CLIP 2 | 3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying 
information about the multimodal model. 4 | 5 | ## Model Details 6 | 7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within. 8 | 9 | ### Model Date 10 | 11 | January 2021 12 | 13 | ### Model Type 14 | 15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer. 16 | 17 | ### Model Versions 18 | 19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50. 20 | 21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. In July 2021, we additionally released the RN50x16 and ViT-B/16 models, and in January 2022, the RN50x64 and ViT-L/14 models were released. Lastly, the ViT-L/14@336px model was released in April 2022. 22 | 23 | Please see the paper linked below for further details about their specification. 24 | 25 | ### Documents 26 | 27 | - [Blog Post](https://openai.com/blog/clip/) 28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020) 29 | 30 | 31 | 32 | ## Model Use 33 | 34 | ### Intended Use 35 | 36 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis. 37 | 38 | #### Primary intended uses 39 | 40 | The primary intended users of these models are AI researchers. 41 | 42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models. 43 | 44 | ### Out-of-Scope Use Cases 45 | 46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task specific testing especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful. 47 | 48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. 
This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use. 49 | 50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases. 51 | 52 | 53 | 54 | ## Data 55 | 56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users. 57 | 58 | ### Data Mission Statement 59 | 60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset. 61 | 62 | 63 | 64 | ## Performance and Limitations 65 | 66 | ### Performance 67 | 68 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets such as OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets: 69 | 70 | - Food101 71 | - CIFAR10 72 | - CIFAR100 73 | - Birdsnap 74 | - SUN397 75 | - Stanford Cars 76 | - FGVC Aircraft 77 | - VOC2007 78 | - DTD 79 | - Oxford-IIIT Pet dataset 80 | - Caltech101 81 | - Flowers102 82 | - MNIST 83 | - SVHN 84 | - IIIT5K 85 | - Hateful Memes 86 | - SST-2 87 | - UCF101 88 | - Kinetics700 89 | - Country211 90 | - CLEVR Counting 91 | - KITTI Distance 92 | - STL-10 93 | - RareAct 94 | - Flickr30 95 | - MSCOCO 96 | - ImageNet 97 | - ImageNet-A 98 | - ImageNet-R 99 | - ImageNet Sketch 100 | - ObjectNet (ImageNet Overlap) 101 | - Youtube-BB 102 | - ImageNet-Vid 103 | 104 | ## Limitations 105 | 106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine grained classification and counting objects. CLIP also poses issues with regards to fairness and bias which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation- in many cases we have used linear probes to evaluate the performance of CLIP and there is evidence suggesting that linear probes can underestimate model performance. 107 | 108 | ### Bias and Fairness 109 | 110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. 
Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper). 111 | 112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks. 113 | 114 | 115 | 116 | ## Feedback 117 | 118 | ### Where to send questions or comments about the model 119 | 120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9) 121 | -------------------------------------------------------------------------------- /OD/CLIP/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb) 4 | 5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision. 6 | 7 | 8 | 9 | ## Approach 10 | 11 | ![CLIP](CLIP.png) 12 | 13 | 14 | 15 | ## Usage 16 | 17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) (or later) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick: 18 | 19 | ```bash 20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0 21 | $ pip install ftfy regex tqdm 22 | $ pip install git+https://github.com/openai/CLIP.git 23 | ``` 24 | 25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU. 
26 | 27 | ```python 28 | import torch 29 | import clip 30 | from PIL import Image 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | model, preprocess = clip.load("ViT-B/32", device=device) 34 | 35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) 36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 37 | 38 | with torch.no_grad(): 39 | image_features = model.encode_image(image) 40 | text_features = model.encode_text(text) 41 | 42 | logits_per_image, logits_per_text = model(image, text) 43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy() 44 | 45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]] 46 | ``` 47 | 48 | 49 | ## API 50 | 51 | The CLIP module `clip` provides the following methods: 52 | 53 | #### `clip.available_models()` 54 | 55 | Returns the names of the available CLIP models. 56 | 57 | #### `clip.load(name, device=..., jit=False)` 58 | 59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. 60 | 61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. 62 | 63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` 64 | 65 | Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model 66 | 67 | --- 68 | 69 | The model returned by `clip.load()` supports the following methods: 70 | 71 | #### `model.encode_image(image: Tensor)` 72 | 73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model. 74 | 75 | #### `model.encode_text(text: Tensor)` 76 | 77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model. 78 | 79 | #### `model(image: Tensor, text: Tensor)` 80 | 81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100. 82 | 83 | 84 | 85 | ## More Examples 86 | 87 | ### Zero-Shot Prediction 88 | 89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset. 
90 | 91 | ```python 92 | import os 93 | import clip 94 | import torch 95 | from torchvision.datasets import CIFAR100 96 | 97 | # Load the model 98 | device = "cuda" if torch.cuda.is_available() else "cpu" 99 | model, preprocess = clip.load('ViT-B/32', device) 100 | 101 | # Download the dataset 102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False) 103 | 104 | # Prepare the inputs 105 | image, class_id = cifar100[3637] 106 | image_input = preprocess(image).unsqueeze(0).to(device) 107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) 108 | 109 | # Calculate features 110 | with torch.no_grad(): 111 | image_features = model.encode_image(image_input) 112 | text_features = model.encode_text(text_inputs) 113 | 114 | # Pick the top 5 most similar labels for the image 115 | image_features /= image_features.norm(dim=-1, keepdim=True) 116 | text_features /= text_features.norm(dim=-1, keepdim=True) 117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) 118 | values, indices = similarity[0].topk(5) 119 | 120 | # Print the result 121 | print("\nTop predictions:\n") 122 | for value, index in zip(values, indices): 123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%") 124 | ``` 125 | 126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device): 127 | 128 | ``` 129 | Top predictions: 130 | 131 | snake: 65.31% 132 | turtle: 12.29% 133 | sweet_pepper: 3.83% 134 | lizard: 1.88% 135 | crocodile: 1.75% 136 | ``` 137 | 138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs. 139 | 140 | 141 | ### Linear-probe evaluation 142 | 143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features. 144 | 145 | ```python 146 | import os 147 | import clip 148 | import torch 149 | 150 | import numpy as np 151 | from sklearn.linear_model import LogisticRegression 152 | from torch.utils.data import DataLoader 153 | from torchvision.datasets import CIFAR100 154 | from tqdm import tqdm 155 | 156 | # Load the model 157 | device = "cuda" if torch.cuda.is_available() else "cpu" 158 | model, preprocess = clip.load('ViT-B/32', device) 159 | 160 | # Load the dataset 161 | root = os.path.expanduser("~/.cache") 162 | train = CIFAR100(root, download=True, train=True, transform=preprocess) 163 | test = CIFAR100(root, download=True, train=False, transform=preprocess) 164 | 165 | 166 | def get_features(dataset): 167 | all_features = [] 168 | all_labels = [] 169 | 170 | with torch.no_grad(): 171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)): 172 | features = model.encode_image(images.to(device)) 173 | 174 | all_features.append(features) 175 | all_labels.append(labels) 176 | 177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy() 178 | 179 | # Calculate the image features 180 | train_features, train_labels = get_features(train) 181 | test_features, test_labels = get_features(test) 182 | 183 | # Perform logistic regression 184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) 185 | classifier.fit(train_features, train_labels) 186 | 187 | # Evaluate using the logistic regression classifier 188 | predictions = classifier.predict(test_features) 189 | accuracy = np.mean((test_labels == predictions).astype(float)) * 100. 
190 | print(f"Accuracy = {accuracy:.3f}") 191 | ``` 192 | 193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split. 194 | 195 | 196 | ## See Also 197 | 198 | * [OpenCLIP](https://github.com/mlfoundations/open_clip): includes larger and independently trained CLIP models up to ViT-G/14 199 | * [Hugging Face implementation of CLIP](https://huggingface.co/docs/transformers/model_doc/clip): for easier integration with the HF ecosystem 200 | -------------------------------------------------------------------------------- /OD/modeling/roi_head.py: -------------------------------------------------------------------------------- 1 | from pydoc import classname 2 | from typing import Dict, List, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | import torchvision.transforms as T 8 | 9 | 10 | from detectron2.layers import ShapeSpec 11 | from detectron2.data import MetadataCatalog 12 | 13 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 14 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou 15 | from .box_predictor import ClipFastRCNNOutputLayers 16 | 17 | def select_foreground_proposals( 18 | proposals: List[Instances], bg_label: int 19 | ) -> Tuple[List[Instances], List[torch.Tensor]]: 20 | """ 21 | Given a list of N Instances (for N images), each containing a `gt_classes` field, 22 | return a list of Instances that contain only instances with `gt_classes != -1 && 23 | gt_classes != bg_label`. 24 | Args: 25 | proposals (list[Instances]): A list of N Instances, where N is the number of 26 | images in the batch. 27 | bg_label: label index of background class. 28 | Returns: 29 | list[Instances]: N Instances, each contains only the selected foreground instances. 30 | list[Tensor]: N boolean vector, correspond to the selection mask of 31 | each Instances object. True for selected instances. 
32 | """ 33 | assert isinstance(proposals, (list, tuple)) 34 | assert isinstance(proposals[0], Instances) 35 | assert proposals[0].has("gt_classes") 36 | fg_proposals = [] 37 | fg_selection_masks = [] 38 | for proposals_per_image in proposals: 39 | gt_classes = proposals_per_image.gt_classes 40 | fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label) 41 | fg_idxs = fg_selection_mask.nonzero().squeeze(1) 42 | fg_proposals.append(proposals_per_image[fg_idxs]) 43 | fg_selection_masks.append(fg_selection_mask) 44 | return fg_proposals, fg_selection_masks 45 | 46 | 47 | @ROI_HEADS_REGISTRY.register() 48 | class ClipRes5ROIHeads(Res5ROIHeads): 49 | def __init__(self, cfg, input_shape) -> None: 50 | super().__init__(cfg, input_shape) 51 | clsnames = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).get("thing_classes").copy() 52 | 53 | # import pdb;pdb.set_trace() 54 | ##change the labels to represent the objects correctly 55 | for name in cfg.MODEL.RENAME: 56 | ind = clsnames.index(name[0]) 57 | clsnames[ind] = name[1] 58 | 59 | out_channels=cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * (2 ** 3) ### copied 60 | self.box_predictor = ClipFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1), clsnames) 61 | self.clip_im_predictor = self.box_predictor.cls_score # should call it properly 62 | self.device = cfg.MODEL.DEVICE 63 | def forward( 64 | self, 65 | images: ImageList, 66 | features: Dict[str, torch.Tensor], 67 | proposals: List[Instances], 68 | targets: Optional[List[Instances]] = None, 69 | crops: Optional[List[Tuple]] = None, 70 | ): 71 | """ 72 | See :meth:`ROIHeads.forward`. 73 | """ 74 | del images 75 | 76 | if self.training: 77 | assert targets 78 | proposals = self.label_and_sample_proposals(proposals, targets) 79 | # import pdb;pdb.set_trace() 80 | loss_crop_im = None 81 | if crops is not None: 82 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224 83 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4 84 | keep = torch.ones(len(crops)).bool() 85 | 86 | for ind,x in enumerate(crops): 87 | if len(x) == 0: 88 | keep[ind] = False 89 | continue 90 | crop_im.append(x[0]) 91 | crop_boxes.append(x[1].to(self.device)) 92 | 93 | c = self._shared_roi_transform( 94 | [features[f][keep] for f in self.in_features], crop_boxes) #(b*crops)x2048x7x7 95 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features.mean(dim=[2, 3])) 96 | 97 | del targets 98 | 99 | proposal_boxes = [x.proposal_boxes for x in proposals] 100 | box_features = self._shared_roi_transform( 101 | [features[f] for f in self.in_features], proposal_boxes 102 | ) 103 | predictions = self.box_predictor(box_features.mean(dim=[2, 3])) 104 | # import pdb;pdb.set_trace() 105 | if self.training: 106 | del features 107 | losses = self.box_predictor.losses(predictions, proposals) 108 | if self.mask_on: 109 | proposals, fg_selection_masks = select_foreground_proposals( 110 | proposals, self.num_classes 111 | ) 112 | # Since the ROI feature transform is shared between boxes and masks, 113 | # we don't need to recompute features. The mask loss is only defined 114 | # on foreground proposals, so we need to select out the foreground 115 | # features. 
116 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 117 | del box_features 118 | losses.update(self.mask_head(mask_features, proposals)) 119 | 120 | if loss_crop_im is not None: 121 | losses.update(loss_crop_im) 122 | return [], losses 123 | else: 124 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 125 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 126 | return pred_instances, {} 127 | 128 | 129 | @ROI_HEADS_REGISTRY.register() 130 | class ClipRes5ROIHeadsAttn(ClipRes5ROIHeads): 131 | def __init__(self, cfg, input_shape) -> None: 132 | super().__init__(cfg, input_shape) 133 | 134 | def _shared_roi_transform(self, features, txt_emb, boxes): 135 | x = self.pooler(features, boxes) 136 | batch_num, _, _, _ = x.shape 137 | norm1 = torch.zeros(1, batch_num+1).to(x.device) 138 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 139 | flops = torch.zeros(1, batch_num+2).to(x.device) 140 | 141 | return self.fwdres5(x,txt_emb,norm1,norm2,flops) 142 | 143 | 144 | 145 | def forward( 146 | self, 147 | images: ImageList, 148 | txt_emb: torch.Tensor, 149 | features: Dict[str, torch.Tensor], 150 | proposals: List[Instances], 151 | targets: Optional[List[Instances]] = None, 152 | crops: Optional[List[Tuple]] = None, 153 | backbone = None 154 | ): 155 | """ 156 | See :meth:`ROIHeads.forward`. 157 | """ 158 | del images 159 | 160 | self.fwdres5 = backbone.forward_res5 161 | 162 | if self.training: 163 | assert targets 164 | proposals = self.label_and_sample_proposals(proposals, targets) 165 | # import pdb;pdb.set_trace() 166 | loss_crop_im = None 167 | if crops is not None: 168 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224 169 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4 170 | keep = torch.ones(len(crops)).bool() 171 | 172 | for ind,x in enumerate(crops): 173 | if len(x) == 0: 174 | keep[ind] = False 175 | continue 176 | crop_im.append(x[0]) 177 | crop_boxes.append(x[1].to(self.device)) 178 | 179 | crops_features = self._shared_roi_transform( 180 | [features[f][keep] for f in self.in_features], txt_emb, crop_boxes) #(b*crops)x2048x7x7 181 | crops_features = backbone.attention_global_pool(crops_features) 182 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features) 183 | 184 | del targets 185 | 186 | proposal_boxes = [x.proposal_boxes for x in proposals] 187 | box_features = self._shared_roi_transform( 188 | [features[f] for f in self.in_features], txt_emb, proposal_boxes 189 | ) 190 | 191 | attn_feat = backbone.attention_global_pool(box_features) 192 | predictions = self.box_predictor([attn_feat,box_features.mean(dim=(2,3))]) 193 | 194 | if self.training: 195 | del features 196 | 197 | losses = self.box_predictor.losses(predictions, proposals) 198 | 199 | if self.mask_on: 200 | proposals, fg_selection_masks = select_foreground_proposals( 201 | proposals, self.num_classes 202 | ) 203 | # Since the ROI feature transform is shared between boxes and masks, 204 | # we don't need to recompute features. The mask loss is only defined 205 | # on foreground proposals, so we need to select out the foreground 206 | # features. 
207 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 208 | del box_features 209 | losses.update(self.mask_head(mask_features, proposals)) 210 | 211 | if loss_crop_im is not None: 212 | losses.update(loss_crop_im) 213 | return [], losses 214 | else: 215 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 216 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 217 | return pred_instances, {} 218 | 219 | 220 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/roi_head-checkpoint.py: -------------------------------------------------------------------------------- 1 | from pydoc import classname 2 | from typing import Dict, List, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | import torchvision.transforms as T 8 | 9 | 10 | from detectron2.layers import ShapeSpec 11 | from detectron2.data import MetadataCatalog 12 | 13 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 14 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou 15 | from .box_predictor import ClipFastRCNNOutputLayers 16 | 17 | def select_foreground_proposals( 18 | proposals: List[Instances], bg_label: int 19 | ) -> Tuple[List[Instances], List[torch.Tensor]]: 20 | """ 21 | Given a list of N Instances (for N images), each containing a `gt_classes` field, 22 | return a list of Instances that contain only instances with `gt_classes != -1 && 23 | gt_classes != bg_label`. 24 | Args: 25 | proposals (list[Instances]): A list of N Instances, where N is the number of 26 | images in the batch. 27 | bg_label: label index of background class. 28 | Returns: 29 | list[Instances]: N Instances, each contains only the selected foreground instances. 30 | list[Tensor]: N boolean vector, correspond to the selection mask of 31 | each Instances object. True for selected instances. 
32 | """ 33 | assert isinstance(proposals, (list, tuple)) 34 | assert isinstance(proposals[0], Instances) 35 | assert proposals[0].has("gt_classes") 36 | fg_proposals = [] 37 | fg_selection_masks = [] 38 | for proposals_per_image in proposals: 39 | gt_classes = proposals_per_image.gt_classes 40 | fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label) 41 | fg_idxs = fg_selection_mask.nonzero().squeeze(1) 42 | fg_proposals.append(proposals_per_image[fg_idxs]) 43 | fg_selection_masks.append(fg_selection_mask) 44 | return fg_proposals, fg_selection_masks 45 | 46 | 47 | @ROI_HEADS_REGISTRY.register() 48 | class ClipRes5ROIHeads(Res5ROIHeads): 49 | def __init__(self, cfg, input_shape) -> None: 50 | super().__init__(cfg, input_shape) 51 | clsnames = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).get("thing_classes").copy() 52 | 53 | # import pdb;pdb.set_trace() 54 | ##change the labels to represent the objects correctly 55 | for name in cfg.MODEL.RENAME: 56 | ind = clsnames.index(name[0]) 57 | clsnames[ind] = name[1] 58 | 59 | out_channels=cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * (2 ** 3) ### copied 60 | self.box_predictor = ClipFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1), clsnames) 61 | self.clip_im_predictor = self.box_predictor.cls_score # should call it properly 62 | self.device = cfg.MODEL.DEVICE 63 | def forward( 64 | self, 65 | images: ImageList, 66 | features: Dict[str, torch.Tensor], 67 | proposals: List[Instances], 68 | targets: Optional[List[Instances]] = None, 69 | crops: Optional[List[Tuple]] = None, 70 | ): 71 | """ 72 | See :meth:`ROIHeads.forward`. 73 | """ 74 | del images 75 | 76 | if self.training: 77 | assert targets 78 | proposals = self.label_and_sample_proposals(proposals, targets) 79 | # import pdb;pdb.set_trace() 80 | loss_crop_im = None 81 | if crops is not None: 82 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224 83 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4 84 | keep = torch.ones(len(crops)).bool() 85 | 86 | for ind,x in enumerate(crops): 87 | if len(x) == 0: 88 | keep[ind] = False 89 | continue 90 | crop_im.append(x[0]) 91 | crop_boxes.append(x[1].to(self.device)) 92 | 93 | c = self._shared_roi_transform( 94 | [features[f][keep] for f in self.in_features], crop_boxes) #(b*crops)x2048x7x7 95 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features.mean(dim=[2, 3])) 96 | 97 | del targets 98 | 99 | proposal_boxes = [x.proposal_boxes for x in proposals] 100 | box_features = self._shared_roi_transform( 101 | [features[f] for f in self.in_features], proposal_boxes 102 | ) 103 | predictions = self.box_predictor(box_features.mean(dim=[2, 3])) 104 | # import pdb;pdb.set_trace() 105 | if self.training: 106 | del features 107 | losses = self.box_predictor.losses(predictions, proposals) 108 | if self.mask_on: 109 | proposals, fg_selection_masks = select_foreground_proposals( 110 | proposals, self.num_classes 111 | ) 112 | # Since the ROI feature transform is shared between boxes and masks, 113 | # we don't need to recompute features. The mask loss is only defined 114 | # on foreground proposals, so we need to select out the foreground 115 | # features. 
116 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 117 | del box_features 118 | losses.update(self.mask_head(mask_features, proposals)) 119 | 120 | if loss_crop_im is not None: 121 | losses.update(loss_crop_im) 122 | return [], losses 123 | else: 124 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 125 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 126 | return pred_instances, {} 127 | 128 | 129 | @ROI_HEADS_REGISTRY.register() 130 | class ClipRes5ROIHeadsAttn(ClipRes5ROIHeads): 131 | def __init__(self, cfg, input_shape) -> None: 132 | super().__init__(cfg, input_shape) 133 | 134 | def _shared_roi_transform(self, features, txt_emb, boxes): 135 | x = self.pooler(features, boxes) 136 | batch_num, _, _, _ = x.shape 137 | norm1 = torch.zeros(1, batch_num+1).to(x.device) 138 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 139 | flops = torch.zeros(1, batch_num+2).to(x.device) 140 | 141 | return self.fwdres5(x,txt_emb,norm1,norm2,flops) 142 | 143 | 144 | 145 | def forward( 146 | self, 147 | images: ImageList, 148 | txt_emb: torch.Tensor, 149 | features: Dict[str, torch.Tensor], 150 | proposals: List[Instances], 151 | targets: Optional[List[Instances]] = None, 152 | crops: Optional[List[Tuple]] = None, 153 | backbone = None 154 | ): 155 | """ 156 | See :meth:`ROIHeads.forward`. 157 | """ 158 | del images 159 | 160 | self.fwdres5 = backbone.forward_res5 161 | 162 | if self.training: 163 | assert targets 164 | proposals = self.label_and_sample_proposals(proposals, targets) 165 | # import pdb;pdb.set_trace() 166 | loss_crop_im = None 167 | if crops is not None: 168 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224 169 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4 170 | keep = torch.ones(len(crops)).bool() 171 | 172 | for ind,x in enumerate(crops): 173 | if len(x) == 0: 174 | keep[ind] = False 175 | continue 176 | crop_im.append(x[0]) 177 | crop_boxes.append(x[1].to(self.device)) 178 | 179 | crops_features = self._shared_roi_transform( 180 | [features[f][keep] for f in self.in_features], txt_emb, crop_boxes) #(b*crops)x2048x7x7 181 | crops_features = backbone.attention_global_pool(crops_features) 182 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features) 183 | 184 | del targets 185 | 186 | proposal_boxes = [x.proposal_boxes for x in proposals] 187 | box_features = self._shared_roi_transform( 188 | [features[f] for f in self.in_features], txt_emb, proposal_boxes 189 | ) 190 | 191 | attn_feat = backbone.attention_global_pool(box_features) 192 | predictions = self.box_predictor([attn_feat,box_features.mean(dim=(2,3))]) 193 | 194 | if self.training: 195 | del features 196 | 197 | losses = self.box_predictor.losses(predictions, proposals) 198 | 199 | if self.mask_on: 200 | proposals, fg_selection_masks = select_foreground_proposals( 201 | proposals, self.num_classes 202 | ) 203 | # Since the ROI feature transform is shared between boxes and masks, 204 | # we don't need to recompute features. The mask loss is only defined 205 | # on foreground proposals, so we need to select out the foreground 206 | # features. 
207 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 208 | del box_features 209 | losses.update(self.mask_head(mask_features, proposals)) 210 | 211 | if loss_crop_im is not None: 212 | losses.update(loss_crop_im) 213 | return [], losses 214 | else: 215 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 216 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 217 | return pred_instances, {} 218 | 219 | 220 | -------------------------------------------------------------------------------- /OD/CLIP/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as 
source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def _node_get(node: torch._C.Node, key: str): 149 | """Gets attributes of a node which is polymorphic over return type. 
150 | 151 | From https://github.com/pytorch/pytorch/pull/82628 152 | """ 153 | sel = node.kindOf(key) 154 | return getattr(node, sel)(key) 155 | 156 | def patch_device(module): 157 | try: 158 | graphs = [module.graph] if hasattr(module, "graph") else [] 159 | except RuntimeError: 160 | graphs = [] 161 | 162 | if hasattr(module, "forward1"): 163 | graphs.append(module.forward1.graph) 164 | 165 | for graph in graphs: 166 | for node in graph.findAllNodes("prim::Constant"): 167 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): 168 | node.copyAttributes(device_node) 169 | 170 | model.apply(patch_device) 171 | patch_device(model.encode_image) 172 | patch_device(model.encode_text) 173 | 174 | # patch dtype to float32 on CPU 175 | if str(device) == "cpu": 176 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 177 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 178 | float_node = float_input.node() 179 | 180 | def patch_float(module): 181 | try: 182 | graphs = [module.graph] if hasattr(module, "graph") else [] 183 | except RuntimeError: 184 | graphs = [] 185 | 186 | if hasattr(module, "forward1"): 187 | graphs.append(module.forward1.graph) 188 | 189 | for graph in graphs: 190 | for node in graph.findAllNodes("aten::to"): 191 | inputs = list(node.inputs()) 192 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 193 | if _node_get(inputs[i].node(), "value") == 5: 194 | inputs[i].node().copyAttributes(float_node) 195 | 196 | model.apply(patch_float) 197 | patch_float(model.encode_image) 198 | patch_float(model.encode_text) 199 | 200 | model.float() 201 | 202 | return model, _transform(model.input_resolution.item()) 203 | 204 | 205 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 206 | """ 207 | Returns the tokenized representation of given input string(s) 208 | 209 | Parameters 210 | ---------- 211 | texts : Union[str, List[str]] 212 | An input string or a list of input strings to tokenize 213 | 214 | context_length : int 215 | The context length to use; all CLIP models use 77 as the context length 216 | 217 | truncate: bool 218 | Whether to truncate the text in case its encoding is longer than the context length 219 | 220 | Returns 221 | ------- 222 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 223 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
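
    Example
    -------
    An illustrative call (shapes follow the description above):

    >>> tokens = tokenize(["a photo of a dog", "a photo of a cat"])
    >>> tokens.shape
    torch.Size([2, 77])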
224 | """ 225 | if isinstance(texts, str): 226 | texts = [texts] 227 | 228 | sot_token = _tokenizer.encoder["<|startoftext|>"] 229 | eot_token = _tokenizer.encoder["<|endoftext|>"] 230 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 231 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 232 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 233 | else: 234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 235 | 236 | for i, tokens in enumerate(all_tokens): 237 | if len(tokens) > context_length: 238 | if truncate: 239 | tokens = tokens[:context_length] 240 | tokens[-1] = eot_token 241 | else: 242 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 243 | result[i, :len(tokens)] = torch.tensor(tokens) 244 | 245 | return result 246 | -------------------------------------------------------------------------------- /OD/modeling/meta_arch.py: -------------------------------------------------------------------------------- 1 | from ast import mod 2 | import math 3 | import numpy as np 4 | import cv2 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torchvision.transforms as T 11 | 12 | from typing import Dict,List,Optional 13 | 14 | from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN 15 | from detectron2.structures import ImageList, Instances, pairwise_iou 16 | from detectron2.utils.events import get_event_storage 17 | from detectron2.layers import batched_nms 18 | from detectron2.data.detection_utils import convert_image_to_rgb 19 | from detectron2.utils.visualizer import Visualizer 20 | 21 | 22 | def show_cam_on_image(img, mask): 23 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 24 | heatmap = np.float32(heatmap) / 255 25 | cam = heatmap + np.float32(img) 26 | cam = cam / np.max(cam) 27 | return cam 28 | 29 | 30 | def generate_visualization(feat, input_img_tensor, scale_factor=60, size=(224, 224)): 31 | """ 32 | 33 | :param feat: 3D ,4D 34 | :param input_img_tensor: 35 | :param size: 36 | :return: 37 | """ 38 | input_img_tensor = input_img_tensor[:, :size[0], :size[1]] 39 | feat = torch.nn.functional.interpolate(feat, scale_factor=scale_factor, mode='bilinear') 40 | feat = feat.reshape(size).cuda().data.cpu().numpy() 41 | feat = (feat - feat.min()) / ( 42 | feat.max() - feat.min()) 43 | image_feat = input_img_tensor.permute(1, 2, 0).data.cpu().numpy() 44 | image_feat = (image_feat - image_feat.min()) / ( 45 | image_feat.max() - image_feat.min()) 46 | vis = show_cam_on_image(image_feat, feat) 47 | vis = np.uint8(255 * vis) 48 | return vis 49 | 50 | 51 | @META_ARCH_REGISTRY.register() 52 | class ClipRCNNWithClipBackbone(GeneralizedRCNN): 53 | 54 | def __init__(self,cfg) -> None: 55 | super().__init__(cfg) 56 | self.cfg = cfg 57 | self.colors = self.generate_colors(7) 58 | self.backbone.set_backbone_model(self.roi_heads.box_predictor.cls_score.visual_enc) 59 | 60 | # txt 61 | domain_text = {'day': 'an image taken during the day'} 62 | with open('prunedprompts.txt','r') as f: 63 | for ind,l in enumerate(f): 64 | domain_text.update({str(ind):l.strip()}) 65 | self.offsets = nn.Parameter(torch.zeros(len(domain_text)-1,1024,14,14)) #skip day 66 | 67 | import clip 68 | self.domain_tk = dict([(k,clip.tokenize(t)) for k,t in domain_text.items()]) 69 | self.apply_aug = cfg.AUG_PROB 70 | 71 | day_text_embed_list = [] 72 | for i,val in 
enumerate(self.domain_tk.items()): 73 | name , dtk = val 74 | if name == 'day': 75 | continue 76 | with torch.no_grad(): 77 | 78 | day_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(self.domain_tk['day'].cuda()) #day 79 | day_text_embed = day_text_embed/day_text_embed.norm(dim=-1,keepdim=True) 80 | day_text_embed_list.append(day_text_embed) 81 | self.day_text_embeds = torch.cat(day_text_embed_list,0).cuda() 82 | 83 | def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]): 84 | """ 85 | Normalize, pad and batch the input images. 86 | """ 87 | clip_images = [x["image"].to(self.pixel_mean.device) for x in batched_inputs] 88 | mean=[0.48145466, 0.4578275, 0.40821073] 89 | std=[0.26862954, 0.26130258, 0.27577711] 90 | 91 | 92 | clip_images = [ T.functional.normalize(ci.flip(0)/255, mean,std) for ci in clip_images] 93 | clip_images = ImageList.from_tensors( 94 | [i for i in clip_images]) 95 | return clip_images 96 | 97 | 98 | def forward(self, batched_inputs): 99 | 100 | if not self.training: 101 | return self.inference(batched_inputs) 102 | 103 | images = self.preprocess_image(batched_inputs) 104 | b = images.tensor.shape[0]#batchsize 105 | 106 | if "instances" in batched_inputs[0]: 107 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 108 | 109 | features = self.backbone(images.tensor) 110 | 111 | if self.proposal_generator is not None: 112 | if self.training: 113 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 114 | else: 115 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 116 | else: 117 | assert "proposals" in batched_inputs[0] 118 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 119 | proposal_losses = {} 120 | 121 | try: 122 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone) 123 | except Exception as e: 124 | print(e) 125 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None) 126 | 127 | if self.vis_period > 0: 128 | storage = get_event_storage() 129 | if storage.iter % self.vis_period == 0: 130 | self.visualize_training(batched_inputs, proposals) 131 | with torch.no_grad(): 132 | ogimage = batched_inputs[0]['image'] 133 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format) 134 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image() 135 | 136 | vis_img = o_pred.transpose(2, 0, 1) 137 | storage.put_image('og-tfimage', vis_img) 138 | 139 | losses = {} 140 | losses.update(detector_losses) 141 | losses.update(proposal_losses) 142 | return losses 143 | 144 | def generate_colors(self,N): 145 | import colorsys 146 | ''' 147 | Generate random colors. 148 | To get visually distinct colors, generate them in HSV space then 149 | convert to RGB. 150 | ''' 151 | brightness = 0.7 152 | hsv = [(i / N, 1, brightness) for i in range(N)] 153 | colors = list(map(lambda c: tuple(round(i * 255) for i in colorsys.hsv_to_rgb(*c)), hsv)) 154 | perm = np.arange(7) 155 | colors = [colors[idx] for idx in perm] 156 | return colors 157 | 158 | 159 | def inference( 160 | self, 161 | batched_inputs: List[Dict[str, torch.Tensor]], 162 | detected_instances: Optional[List[Instances]] = None, 163 | do_postprocess: bool = True, 164 | ): 165 | """ 166 | Run inference on the given inputs. 
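        Images are channel-flipped (BGR to RGB for CLIP), scaled to [0, 1] and normalized
        with the CLIP mean/std by :meth:`preprocess_image` before the backbone is run.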
167 | Args: 168 | batched_inputs (list[dict]): same as in :meth:`forward` 169 | detected_instances (None or list[Instances]): if not None, it 170 | contains an `Instances` object per image. The `Instances` 171 | object contains "pred_boxes" and "pred_classes" which are 172 | known boxes in the image. 173 | The inference will then skip the detection of bounding boxes, 174 | and only predict other per-ROI outputs. 175 | do_postprocess (bool): whether to apply post-processing on the outputs. 176 | Returns: 177 | When do_postprocess=True, same as in :meth:`forward`. 178 | Otherwise, a list[Instances] containing raw network outputs. 179 | """ 180 | assert not self.training 181 | 182 | images = self.preprocess_image(batched_inputs) 183 | features = self.backbone(images.tensor) 184 | 185 | 186 | ###############save feat vis################### 187 | feat_vis = False 188 | if feat_vis: 189 | scale_size = 16 190 | out_dir_explain = os.path.join('./output', 'featmap_vis') 191 | explain_rpn_feat_i, _ = torch.max(features['res4'][0], 0) 192 | explain_rpn_feat_i = explain_rpn_feat_i.unsqueeze(0).unsqueeze(0) 193 | 194 | size = (explain_rpn_feat_i.shape[-2] * scale_size , explain_rpn_feat_i.shape[-1] *scale_size) 195 | visual = generate_visualization(explain_rpn_feat_i, images.tensor[0], scale_factor=scale_size, size=size) 196 | 197 | name_sp_list = batched_inputs[0]['file_name'].split('/')[-1].rsplit('.', 1) 198 | save_file_name = name_sp_list[0] + '.' + name_sp_list[1] 199 | explain_out_file = os.path.join(out_dir_explain, save_file_name) 200 | cv2.imwrite(explain_out_file, visual) 201 | 202 | ###############save feat vis ################### 203 | 204 | if detected_instances is None: 205 | if self.proposal_generator is not None: 206 | logits,proposals, _ = self.proposal_generator(images, features, None) 207 | else: 208 | assert "proposals" in batched_inputs[0] 209 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 210 | 211 | 212 | try: 213 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None, self.backbone) 214 | except: 215 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None) 216 | else: 217 | detected_instances = [x.to(self.device) for x in detected_instances] 218 | results = self.roi_heads.forward_with_given_boxes(features, detected_instances) 219 | 220 | 221 | if do_postprocess: 222 | assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." 
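            # GeneralizedRCNN._postprocess rescales the predicted boxes from the
            # resized network input back to each image's original resolution.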
223 | 224 | allresults = GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) 225 | 226 | 227 | return allresults 228 | else: 229 | return results 230 | 231 | 232 | @META_ARCH_REGISTRY.register() 233 | class ClipRCNNWithClipBackboneTrainable(ClipRCNNWithClipBackbone): 234 | def __init__(self,cfg) -> None: 235 | super().__init__(cfg) 236 | 237 | def forward(self, batched_inputs): 238 | 239 | if not self.training: 240 | return self.inference(batched_inputs) 241 | 242 | images = self.preprocess_image(batched_inputs) 243 | b = images.tensor.shape[0]#batchsize 244 | 245 | if "instances" in batched_inputs[0]: 246 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 247 | 248 | features = self.backbone(images.tensor) 249 | 250 | if self.proposal_generator is not None: 251 | if self.training: 252 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 253 | else: 254 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 255 | else: 256 | assert "proposals" in batched_inputs[0] 257 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 258 | proposal_losses = {} 259 | 260 | _, detector_losses = self.roi_heads(images, self.day_text_embeds, features, proposals, gt_instances, None, self.backbone) 261 | 262 | losses = {} 263 | losses.update(detector_losses) 264 | losses.update(proposal_losses) 265 | return losses 266 | 267 | 268 | -------------------------------------------------------------------------------- /OD/modeling/.ipynb_checkpoints/meta_arch-checkpoint.py: -------------------------------------------------------------------------------- 1 | from ast import mod 2 | import math 3 | import numpy as np 4 | import cv2 5 | import os 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torchvision.transforms as T 11 | 12 | from typing import Dict,List,Optional 13 | 14 | from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN 15 | from detectron2.structures import ImageList, Instances, pairwise_iou 16 | from detectron2.utils.events import get_event_storage 17 | from detectron2.layers import batched_nms 18 | from detectron2.data.detection_utils import convert_image_to_rgb 19 | from detectron2.utils.visualizer import Visualizer 20 | # from .regularization import * 21 | 22 | def show_cam_on_image(img, mask): 23 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 24 | heatmap = np.float32(heatmap) / 255 25 | cam = heatmap + np.float32(img) 26 | cam = cam / np.max(cam) 27 | return cam 28 | 29 | 30 | def generate_visualization(feat, input_img_tensor, scale_factor=60, size=(224, 224)): 31 | """ 32 | 33 | :param feat: 3D ,4D 34 | :param input_img_tensor: 35 | :param size: 36 | :return: 37 | """ 38 | input_img_tensor = input_img_tensor[:, :size[0], :size[1]] 39 | feat = torch.nn.functional.interpolate(feat, scale_factor=scale_factor, mode='bilinear') 40 | feat = feat.reshape(size).cuda().data.cpu().numpy() 41 | feat = (feat - feat.min()) / ( 42 | feat.max() - feat.min()) 43 | image_feat = input_img_tensor.permute(1, 2, 0).data.cpu().numpy() 44 | image_feat = (image_feat - image_feat.min()) / ( 45 | image_feat.max() - image_feat.min()) 46 | vis = show_cam_on_image(image_feat, feat) 47 | vis = np.uint8(255 * vis) 48 | return vis 49 | 50 | 51 | @META_ARCH_REGISTRY.register() 52 | class ClipRCNNWithClipBackbone(GeneralizedRCNN): 53 | 54 | def __init__(self,cfg) -> None: 55 | 
super().__init__(cfg) 56 | self.cfg = cfg 57 | self.colors = self.generate_colors(7) 58 | self.backbone.set_backbone_model(self.roi_heads.box_predictor.cls_score.visual_enc) 59 | 60 | # txt 61 | domain_text = {'day': 'an image taken during the day'} 62 | with open('prunedprompts2.txt','r') as f: 63 | for ind,l in enumerate(f): 64 | domain_text.update({str(ind):l.strip()}) 65 | self.offsets = nn.Parameter(torch.zeros(len(domain_text)-1,1024,14,14)) #skip day 66 | 67 | import clip 68 | self.domain_tk = dict([(k,clip.tokenize(t)) for k,t in domain_text.items()]) 69 | self.apply_aug = cfg.AUG_PROB 70 | 71 | day_text_embed_list = [] 72 | for i,val in enumerate(self.domain_tk.items()): 73 | name , dtk = val 74 | if name == 'day': 75 | continue 76 | with torch.no_grad(): 77 | 78 | day_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(self.domain_tk['day'].cuda()) #day 79 | day_text_embed = day_text_embed/day_text_embed.norm(dim=-1,keepdim=True) 80 | day_text_embed_list.append(day_text_embed) 81 | self.day_text_embeds = torch.cat(day_text_embed_list,0).cuda() 82 | 83 | def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]): 84 | """ 85 | Normalize, pad and batch the input images. 86 | """ 87 | clip_images = [x["image"].to(self.pixel_mean.device) for x in batched_inputs] 88 | mean=[0.48145466, 0.4578275, 0.40821073] 89 | std=[0.26862954, 0.26130258, 0.27577711] 90 | 91 | 92 | clip_images = [ T.functional.normalize(ci.flip(0)/255, mean,std) for ci in clip_images] 93 | clip_images = ImageList.from_tensors( 94 | [i for i in clip_images]) 95 | return clip_images 96 | 97 | 98 | def forward(self, batched_inputs): 99 | 100 | if not self.training: 101 | return self.inference(batched_inputs) 102 | 103 | images = self.preprocess_image(batched_inputs) 104 | b = images.tensor.shape[0]#batchsize 105 | 106 | if "instances" in batched_inputs[0]: 107 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 108 | 109 | features = self.backbone(images.tensor) 110 | 111 | if self.proposal_generator is not None: 112 | if self.training: 113 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 114 | else: 115 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 116 | else: 117 | assert "proposals" in batched_inputs[0] 118 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 119 | proposal_losses = {} 120 | 121 | try: 122 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone) 123 | except Exception as e: 124 | print(e) 125 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None) 126 | 127 | if self.vis_period > 0: 128 | storage = get_event_storage() 129 | if storage.iter % self.vis_period == 0: 130 | self.visualize_training(batched_inputs, proposals) 131 | with torch.no_grad(): 132 | ogimage = batched_inputs[0]['image'] 133 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format) 134 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image() 135 | 136 | vis_img = o_pred.transpose(2, 0, 1) 137 | storage.put_image('og-tfimage', vis_img) 138 | 139 | losses = {} 140 | losses.update(detector_losses) 141 | losses.update(proposal_losses) 142 | return losses 143 | 144 | def generate_colors(self,N): 145 | import colorsys 146 | ''' 147 | Generate random colors. 148 | To get visually distinct colors, generate them in HSV space then 149 | convert to RGB. 
150 | ''' 151 | brightness = 0.7 152 | hsv = [(i / N, 1, brightness) for i in range(N)] 153 | colors = list(map(lambda c: tuple(round(i * 255) for i in colorsys.hsv_to_rgb(*c)), hsv)) 154 | perm = np.arange(7) 155 | colors = [colors[idx] for idx in perm] 156 | return colors 157 | 158 | 159 | def inference( 160 | self, 161 | batched_inputs: List[Dict[str, torch.Tensor]], 162 | detected_instances: Optional[List[Instances]] = None, 163 | do_postprocess: bool = True, 164 | ): 165 | """ 166 | Run inference on the given inputs. 167 | Args: 168 | batched_inputs (list[dict]): same as in :meth:`forward` 169 | detected_instances (None or list[Instances]): if not None, it 170 | contains an `Instances` object per image. The `Instances` 171 | object contains "pred_boxes" and "pred_classes" which are 172 | known boxes in the image. 173 | The inference will then skip the detection of bounding boxes, 174 | and only predict other per-ROI outputs. 175 | do_postprocess (bool): whether to apply post-processing on the outputs. 176 | Returns: 177 | When do_postprocess=True, same as in :meth:`forward`. 178 | Otherwise, a list[Instances] containing raw network outputs. 179 | """ 180 | assert not self.training 181 | 182 | images = self.preprocess_image(batched_inputs) 183 | features = self.backbone(images.tensor) 184 | 185 | 186 | ###############save feat vis################### 187 | feat_vis = False 188 | if feat_vis: 189 | scale_size = 16 190 | out_dir_explain = os.path.join('./output', 'featmap_vis') 191 | explain_rpn_feat_i, _ = torch.max(features['res4'][0], 0) 192 | explain_rpn_feat_i = explain_rpn_feat_i.unsqueeze(0).unsqueeze(0) 193 | 194 | size = (explain_rpn_feat_i.shape[-2] * scale_size , explain_rpn_feat_i.shape[-1] *scale_size) 195 | visual = generate_visualization(explain_rpn_feat_i, images.tensor[0], scale_factor=scale_size, size=size) 196 | 197 | name_sp_list = batched_inputs[0]['file_name'].split('/')[-1].rsplit('.', 1) 198 | save_file_name = name_sp_list[0] + '.' + name_sp_list[1] 199 | explain_out_file = os.path.join(out_dir_explain, save_file_name) 200 | cv2.imwrite(explain_out_file, visual) 201 | 202 | ###############save feat vis ################### 203 | 204 | if detected_instances is None: 205 | if self.proposal_generator is not None: 206 | logits,proposals, _ = self.proposal_generator(images, features, None) 207 | else: 208 | assert "proposals" in batched_inputs[0] 209 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 210 | 211 | 212 | try: 213 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None, self.backbone) 214 | except: 215 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None) 216 | else: 217 | detected_instances = [x.to(self.device) for x in detected_instances] 218 | results = self.roi_heads.forward_with_given_boxes(features, detected_instances) 219 | 220 | 221 | if do_postprocess: 222 | assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." 
223 | 224 | allresults = GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) 225 | 226 | 227 | return allresults 228 | else: 229 | return results 230 | 231 | 232 | @META_ARCH_REGISTRY.register() 233 | class ClipRCNNWithClipBackboneTrainable(ClipRCNNWithClipBackbone): 234 | def __init__(self,cfg) -> None: 235 | super().__init__(cfg) 236 | 237 | def forward(self, batched_inputs): 238 | 239 | if not self.training: 240 | return self.inference(batched_inputs) 241 | 242 | images = self.preprocess_image(batched_inputs) 243 | b = images.tensor.shape[0]#batchsize 244 | 245 | if "instances" in batched_inputs[0]: 246 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 247 | 248 | features = self.backbone(images.tensor) 249 | 250 | if self.proposal_generator is not None: 251 | if self.training: 252 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 253 | else: 254 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 255 | else: 256 | assert "proposals" in batched_inputs[0] 257 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 258 | proposal_losses = {} 259 | 260 | _, detector_losses = self.roi_heads(images, self.day_text_embeds, features, proposals, gt_instances, None, self.backbone) 261 | 262 | losses = {} 263 | losses.update(detector_losses) 264 | losses.update(proposal_losses) 265 | return losses 266 | 267 | 268 | -------------------------------------------------------------------------------- /OD/train.py: -------------------------------------------------------------------------------- 1 | from cgi import parse_multipart 2 | import os 3 | import logging 4 | import time 5 | from collections import OrderedDict, Counter 6 | import copy 7 | 8 | import numpy as np 9 | 10 | import torch 11 | from torch import autograd 12 | import torch.utils.data as torchdata 13 | 14 | from detectron2 import model_zoo 15 | from detectron2.config import get_cfg 16 | from detectron2.layers.batch_norm import FrozenBatchNorm2d 17 | from detectron2.engine import DefaultPredictor, DefaultTrainer, default_setup 18 | from detectron2.engine import default_argument_parser, hooks, HookBase 19 | from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping, build_lr_scheduler 20 | from detectron2.checkpoint import DetectionCheckpointer 21 | from detectron2.data import build_detection_train_loader, build_detection_test_loader, get_detection_dataset_dicts 22 | from detectron2.data.common import DatasetFromList, MapDataset 23 | from detectron2.data.samplers import InferenceSampler 24 | from detectron2.utils.events import get_event_storage 25 | 26 | from detectron2.utils import comm 27 | from detectron2.evaluation import COCOEvaluator, verify_results, inference_on_dataset, print_csv_format 28 | 29 | from detectron2.solver import LRMultiplier 30 | from detectron2.modeling import build_model 31 | from detectron2.structures import ImageList, Instances, pairwise_iou, Boxes 32 | 33 | from fvcore.common.param_scheduler import ParamScheduler 34 | from fvcore.common.checkpoint import Checkpointer 35 | 36 | from data.datasets import builtin 37 | 38 | from detectron2.evaluation import PascalVOCDetectionEvaluator, COCOEvaluator, inference_on_dataset 39 | 40 | from detectron2.data import build_detection_train_loader, MetadataCatalog 41 | import torch.utils.data as data 42 | from detectron2.data.dataset_mapper import DatasetMapper 43 | import 
detectron2.data.detection_utils as utils 44 | import detectron2.data.transforms as detT 45 | 46 | import torchvision.transforms as T 47 | import torchvision.transforms.functional as tF 48 | 49 | from modeling import add_stn_config 50 | from modeling import CustomPascalVOCDetectionEvaluator 51 | import ipdb 52 | 53 | logger = logging.getLogger("detectron2") 54 | 55 | def setup(args): 56 | cfg = get_cfg() 57 | add_stn_config(cfg) 58 | #hack to add base yaml 59 | cfg.merge_from_file(args.config_file) 60 | cfg.merge_from_file(model_zoo.get_config_file(cfg.BASE_YAML)) 61 | cfg.merge_from_file(args.config_file) 62 | cfg.merge_from_list(args.opts) 63 | default_setup(cfg, args) 64 | return cfg 65 | 66 | class CustomDatasetMapper(DatasetMapper): 67 | def __init__(self,cfg,is_train) -> None: 68 | super().__init__(cfg,is_train) 69 | self.with_crops = cfg.INPUT.CLIP_WITH_IMG 70 | self.with_random_clip_crops = cfg.INPUT.CLIP_RANDOM_CROPS 71 | self.with_jitter = cfg.INPUT.IMAGE_JITTER 72 | self.cropfn = T.RandomCrop#T.RandomCrop([224,224]) 73 | self.aug = T.ColorJitter(brightness=.5, hue=.3) 74 | self.crop_size = cfg.INPUT.RANDOM_CROP_SIZE 75 | 76 | def __call__(self,dataset_dict): 77 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 78 | # USER: Write your own image loading if it's not from a file 79 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format) 80 | utils.check_image_size(dataset_dict, image) 81 | 82 | # USER: Remove if you don't do semantic/panoptic segmentation. 83 | if "sem_seg_file_name" in dataset_dict: 84 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 85 | else: 86 | sem_seg_gt = None 87 | 88 | aug_input = detT.AugInput(image, sem_seg=sem_seg_gt) 89 | transforms = self.augmentations(aug_input) 90 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 91 | 92 | image_shape = image.shape[:2] # h, w 93 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 94 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 95 | # Therefore it's important to use torch.Tensor. 96 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 97 | if sem_seg_gt is not None: 98 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 99 | 100 | # USER: Remove if you don't use pre-computed proposals. 101 | # Most users would not need this feature. 102 | if self.proposal_topk is not None: 103 | utils.transform_proposals( 104 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk 105 | ) 106 | 107 | if not self.is_train: 108 | # USER: Modify this if you want to keep them for some reason. 
109 | dataset_dict.pop("annotations", None) 110 | dataset_dict.pop("sem_seg_file_name", None) 111 | return dataset_dict 112 | 113 | if "annotations" in dataset_dict: 114 | self._transform_annotations(dataset_dict, transforms, image_shape) 115 | 116 | if self.with_jitter: 117 | dataset_dict["jitter_image"] = self.aug(dataset_dict["image"]) 118 | 119 | if self.with_crops: 120 | bbox = dataset_dict['instances'].gt_boxes.tensor 121 | csx = (bbox[:,0] + bbox[:,2])*0.5 122 | csy = (bbox[:,1] + bbox[:,3])*0.5 123 | maxwh = torch.maximum(bbox[:,2]-bbox[:,0],bbox[:,3]-bbox[:,1]) 124 | crops = list() 125 | gt_boxes = list() 126 | mean=[0.48145466, 0.4578275, 0.40821073] 127 | std=[0.26862954, 0.26130258, 0.27577711] 128 | for cx,cy,maxdim,label,box in zip(csx,csy,maxwh,dataset_dict['instances'].gt_classes, bbox): 129 | 130 | if int(maxdim) < 10: 131 | continue 132 | x0 = torch.maximum(cx-maxdim*0.5,torch.tensor(0)) 133 | y0 = torch.maximum(cy-maxdim*0.5,torch.tensor(0)) 134 | try: 135 | imcrop = T.functional.resized_crop(dataset_dict['image'],top=int(y0),left=int(x0),height=int(maxdim),width=int(maxdim),size=224) 136 | imcrop = imcrop.flip(0)/255 # bgr --> rgb for clip 137 | imcrop = T.functional.normalize(imcrop,mean,std) 138 | # print(x0,y0,x0+maxdim,y0+maxdim,dataset_dict['image'].shape) 139 | # print(imcrop.min(),imcrop.max() ) 140 | gt_boxes.append(box.reshape(1,-1)) 141 | except Exception as e: 142 | print(e) 143 | print('crops:',x0,y0,maxdim) 144 | exit() 145 | # crops.append((imcrop,label)) 146 | crops.append(imcrop.unsqueeze(0)) 147 | 148 | if len(crops) == 0: 149 | dataset_dict['crops'] = [] 150 | else: 151 | dataset_dict['crops'] = [torch.cat(crops,0),Boxes(torch.cat(gt_boxes,0))] 152 | 153 | if self.with_random_clip_crops: 154 | crops = [] 155 | rbboxs = [] 156 | 157 | for i in range(15): 158 | p = self.cropfn.get_params(dataset_dict['image'],[self.crop_size,self.crop_size]) 159 | c = tF.crop(dataset_dict['image'],*p) 160 | if self.crop_size != 224: 161 | c = tF.resize(img=c,size=224) 162 | crops.append(c) 163 | rbboxs.append(p) 164 | 165 | crops = torch.stack(crops) 166 | dataset_dict['randomcrops'] = crops 167 | 168 | #apply same crop bbox to the jittered image 169 | if self.with_jitter: 170 | jitter_crops = [] 171 | for p in rbboxs: 172 | jc = tF.crop(dataset_dict['jitter_image'],*p) 173 | if self.crop_size != 224: 174 | jc = tF.resize(img=jc,size=224) 175 | jitter_crops.append(jc) 176 | 177 | jcrops = torch.stack(jitter_crops) 178 | dataset_dict['jitter_randomcrops'] = jcrops 179 | return dataset_dict 180 | 181 | class CombineLoaders(data.IterableDataset): 182 | def __init__(self,loaders): 183 | self.loaders = loaders 184 | 185 | def __iter__(self,): 186 | dd = iter(self.loaders[1]) 187 | for d1 in self.loaders[0]: 188 | try: 189 | d2 = next(dd) 190 | except: 191 | dd=iter(self.loaders[1]) 192 | d2 = next(dd) 193 | 194 | list_out_dict=[] 195 | for v1,v2 in zip(d1,d2): 196 | out_dict = {} 197 | for k in v1.keys(): 198 | out_dict[k] = (v1[k],v2[k]) 199 | list_out_dict.append(out_dict) 200 | 201 | yield list_out_dict 202 | 203 | 204 | class Trainer(DefaultTrainer): 205 | 206 | def __init__(self,cfg) -> None: 207 | super().__init__(cfg) 208 | self.teach_model = None 209 | self.off_opt_interval = np.arange(0,cfg.SOLVER.MAX_ITER,cfg.OFFSET_OPT_INTERVAL[0]).tolist() 210 | self.off_opt_iters = cfg.OFFSET_OPT_ITERS 211 | 212 | @classmethod 213 | def build_model(cls, cfg): 214 | """ 215 | Returns: 216 | torch.nn.Module: 217 | It now calls :func:`detectron2.modeling.build_model`. 
218 | Overwrite it if you'd like a different model. 219 | """ 220 | model = build_model(cfg) 221 | 222 | 223 | logger = logging.getLogger(__name__) 224 | logger.info("Model:\n{}".format(model)) 225 | 226 | return model 227 | 228 | @classmethod 229 | def build_train_loader(cls,cfg): 230 | original = cfg.DATASETS.TRAIN 231 | print(original) 232 | cfg.DATASETS.TRAIN=(original[0],) 233 | data_loader1 = build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True)) 234 | return data_loader1 235 | 236 | @classmethod 237 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 238 | if output_folder is None: 239 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 240 | if MetadataCatalog.get(dataset_name).evaluator_type == 'pascal_voc': 241 | return CustomPascalVOCDetectionEvaluator(dataset_name) 242 | else: 243 | return COCOEvaluator(dataset_name, output_dir=output_folder) 244 | 245 | @classmethod 246 | def build_optimizer(cls,cfg,model): 247 | 248 | trainable = {'others':[],'offset':[]} 249 | 250 | for name,val in model.named_parameters(): 251 | head = name.split('.')[0] 252 | #previously was setting all params to be true 253 | if val.requires_grad == True: 254 | print(name) 255 | if 'offset' in name: 256 | trainable['offset'].append(val) 257 | else: 258 | trainable['others'].append(val) 259 | 260 | optimizer1 = torch.optim.SGD( 261 | trainable['others'], 262 | cfg.SOLVER.BASE_LR, 263 | momentum=cfg.SOLVER.MOMENTUM, 264 | nesterov=cfg.SOLVER.NESTEROV, 265 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 266 | ) 267 | 268 | optimizer2 = torch.optim.Adam( 269 | trainable['offset'], 270 | 0.01, 271 | ) 272 | return (maybe_add_gradient_clipping(cfg, optimizer1),maybe_add_gradient_clipping(cfg, optimizer2)) 273 | 274 | 275 | def run_step(self): 276 | """ 277 | Implement the standard training logic described above. 278 | """ 279 | 280 | 281 | assert self.model.training, "[SimpleTrainer] model was changed to eval mode!" 282 | start = time.perf_counter() 283 | """ 284 | If you want to do something with the data, you can wrap the dataloader. 285 | """ 286 | data = next(self._trainer._data_loader_iter) 287 | data_time = time.perf_counter() - start 288 | 289 | """ 290 | If you want to do something with the losses, you can wrap the model. 291 | """ 292 | data_s = data 293 | 294 | opt_phase = False 295 | loss_dict_s = self.model(data_s) 296 | loss_dict = {} 297 | 298 | loss = 0 299 | for k,v in loss_dict_s.items(): 300 | loss += v 301 | 302 | 303 | """ 304 | If you need to accumulate gradients or do something similar, you can 305 | wrap the optimizer with your custom `zero_grad()` method. 306 | """ 307 | self.optimizer[0].zero_grad() 308 | self.optimizer[1].zero_grad() 309 | 310 | loss.backward() 311 | 312 | if not opt_phase: 313 | self.optimizer[0].step() 314 | else: 315 | self.optimizer[1].step() 316 | 317 | self.optimizer[0].zero_grad() 318 | self.optimizer[1].zero_grad() 319 | 320 | for k,v in loss_dict_s.items(): 321 | loss_dict.update({k:v}) 322 | 323 | # print(loss_di ct) 324 | self._trainer._write_metrics(loss_dict, data_time) 325 | """ 326 | If you need gradient clipping/scaling or other processing, you can 327 | wrap the optimizer with your custom `step()` method. But it is 328 | suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4 329 | """ 330 | 331 | def build_hooks(self): 332 | """ 333 | Build a list of default hooks, including timing, evaluation, 334 | checkpointing, lr scheduling, precise BN, writing events. 
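        In addition to the defaults, this trainer registers a second EvalHook that runs
        CustomPascalVOCDetectionEvaluator on every dataset in cfg.DATASETS.TEST and logs
        "<dataset_name>_AP50" to the event storage (consumed by BestCheckpointer in main).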
335 | Returns: 336 | list[HookBase]: 337 | """ 338 | cfg = self.cfg.clone() 339 | cfg.defrost() 340 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN 341 | 342 | ret = [ 343 | hooks.IterationTimer(), 344 | LRScheduler(), 345 | hooks.PreciseBN( 346 | # Run at the same freq as (but before) evaluation. 347 | cfg.TEST.EVAL_PERIOD, 348 | self.model, 349 | # Build a new data loader to not affect training 350 | self.build_train_loader(cfg), 351 | cfg.TEST.PRECISE_BN.NUM_ITER, 352 | ) 353 | if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) 354 | else None, 355 | ] 356 | 357 | # Do PreciseBN before checkpointer, because it updates the model and need to 358 | # be saved by checkpointer. 359 | # This is not always the best: if checkpointing has a different frequency, 360 | # some checkpoints may have more precise statistics than others. 361 | if comm.is_main_process(): 362 | ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) 363 | 364 | def test_and_save_results(): 365 | self._last_eval_results = self.test(self.cfg, self.model) 366 | return self._last_eval_results 367 | 368 | def do_test_st(flag): 369 | if flag == 'st': 370 | model = self.model 371 | else: 372 | print("Error in the flag") 373 | 374 | results = OrderedDict() 375 | for dataset_name in self.cfg.DATASETS.TEST: 376 | data_loader = build_detection_test_loader(self.cfg, dataset_name) 377 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name) 378 | results_i = inference_on_dataset(model, data_loader, evaluator) 379 | results[dataset_name] = results_i 380 | if comm.is_main_process(): 381 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 382 | print_csv_format(results_i) 383 | storage = get_event_storage() 384 | storage.put_scalar(f'{dataset_name}_AP50', results_i['bbox']['AP50'],smoothing_hint=False) 385 | if len(results) == 1: 386 | results = list(results.values())[0] 387 | return results 388 | 389 | 390 | # Do evaluation after checkpointer, because then if it fails, 391 | # we can use the saved checkpoint to debug. 392 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) 393 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_SAVE_PERIOD, lambda flag='st': do_test_st(flag))) 394 | 395 | if comm.is_main_process(): 396 | # Here the default print/log frequency of each writer is used. 397 | # run writers in the end, so that evaluation metrics are written 398 | ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) 399 | return ret 400 | 401 | @classmethod 402 | def build_lr_scheduler(cls, cfg, optimizer): 403 | """ 404 | It now calls :func:`detectron2.solver.build_lr_scheduler`. 405 | Overwrite it if you'd like a different scheduler. 406 | """ 407 | 408 | return build_lr_scheduler(cfg, optimizer[0]) 409 | 410 | def state_dict(self): 411 | ret = super().state_dict() 412 | ret["optimizer1"] = self.optimizer[0].state_dict() 413 | ret["optimizer2"] = self.optimizer[1].state_dict() 414 | return ret 415 | 416 | def load_state_dict(self, state_dict): 417 | super().load_state_dict(state_dict) 418 | self.optimizer[0].load_state_dict(state_dict["optimizer1"]) 419 | self.optimizer[1].load_state_dict(state_dict["optimizer2"]) 420 | 421 | 422 | 423 | class LRScheduler(HookBase): 424 | """ 425 | A hook which executes a torch builtin LR scheduler and summarizes the LR. 426 | It is executed after every iteration. 
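    This variant expects the pair of optimizers returned by Trainer.build_optimizer
    (SGD for the detector parameters, Adam for the offset parameters) and logs their
    learning rates to the event storage as "lr1" and "lr2".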
427 | """ 428 | 429 | def __init__(self, optimizer=None, scheduler=None): 430 | """ 431 | Args: 432 | optimizer (torch.optim.Optimizer): 433 | scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler): 434 | if a :class:`ParamScheduler` object, it defines the multiplier over the base LR 435 | in the optimizer. 436 | If any argument is not given, will try to obtain it from the trainer. 437 | """ 438 | self._optimizer = optimizer 439 | self._scheduler = scheduler 440 | 441 | def before_train(self): 442 | self._optimizer = self._optimizer or self.trainer.optimizer 443 | if isinstance(self.scheduler, ParamScheduler): 444 | self._scheduler = LRMultiplier( 445 | self._optimizer, 446 | self.scheduler, 447 | self.trainer.max_iter, 448 | last_iter=self.trainer.iter - 1, 449 | ) 450 | self._best_param_group_id1 = LRScheduler.get_best_param_group_id(self._optimizer[0]) 451 | self._best_param_group_id2 = LRScheduler.get_best_param_group_id(self._optimizer[1]) 452 | 453 | 454 | @staticmethod 455 | def get_best_param_group_id(optimizer): 456 | # NOTE: some heuristics on what LR to summarize 457 | # summarize the param group with most parameters 458 | largest_group = max(len(g["params"]) for g in optimizer.param_groups) 459 | 460 | if largest_group == 1: 461 | # If all groups have one parameter, 462 | # then find the most common initial LR, and use it for summary 463 | lr_count = Counter([g["lr"] for g in optimizer.param_groups]) 464 | lr = lr_count.most_common()[0][0] 465 | for i, g in enumerate(optimizer.param_groups): 466 | if g["lr"] == lr: 467 | return i 468 | else: 469 | for i, g in enumerate(optimizer.param_groups): 470 | if len(g["params"]) == largest_group: 471 | return i 472 | 473 | def after_step(self): 474 | lr1 = self._optimizer[0].param_groups[self._best_param_group_id1]["lr"] 475 | self.trainer.storage.put_scalar("lr1", lr1, smoothing_hint=False) 476 | 477 | lr2 = self._optimizer[1].param_groups[self._best_param_group_id2]["lr"] 478 | self.trainer.storage.put_scalar("lr2", lr2, smoothing_hint=False) 479 | 480 | self.scheduler.step() 481 | 482 | @property 483 | def scheduler(self): 484 | return self._scheduler or self.trainer.scheduler 485 | 486 | def state_dict(self): 487 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler): 488 | return self.scheduler.state_dict() 489 | return {} 490 | 491 | def load_state_dict(self, state_dict): 492 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler): 493 | logger = logging.getLogger(__name__) 494 | logger.info("Loading scheduler from state_dict ...") 495 | self.scheduler.load_state_dict(state_dict) 496 | 497 | def custom_build_detection_test_loader(cfg,dataset_name,mapper=None): 498 | 499 | if isinstance(dataset_name, str): 500 | dataset_name = [dataset_name] 501 | 502 | dataset = get_detection_dataset_dicts( 503 | dataset_name, 504 | filter_empty=False, 505 | proposal_files=[ 506 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name 507 | ] 508 | if cfg.MODEL.LOAD_PROPOSALS 509 | else None, 510 | ) 511 | if mapper is None: 512 | mapper = DatasetMapper(cfg, False) 513 | 514 | if isinstance(dataset, list): 515 | dataset = DatasetFromList(dataset, copy=False) 516 | if mapper is not None: 517 | dataset = MapDataset(dataset, mapper) 518 | 519 | sampler = None 520 | if isinstance(dataset, torchdata.IterableDataset): 521 | assert sampler is None, "sampler must be None if dataset is IterableDataset" 522 | else: 523 | if sampler is None: 524 | sampler = 
InferenceSampler(len(dataset)) 525 | collate_fn = None 526 | 527 | def trivial_batch_collator(batch): 528 | """ 529 | A batch collator that does nothing. 530 | """ 531 | return batch 532 | 533 | return torchdata.DataLoader( 534 | dataset, 535 | batch_size=1, 536 | sampler=sampler, 537 | drop_last=False, 538 | num_workers=cfg.DATALOADER.NUM_WORKERS, 539 | collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, 540 | ) 541 | 542 | 543 | def do_test(cfg, model, model_type=''): 544 | results = OrderedDict() 545 | for dataset_name in cfg.DATASETS.TEST: 546 | data_loader = build_detection_test_loader(cfg, dataset_name) 547 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name) 548 | results_i = inference_on_dataset(model, data_loader, evaluator) 549 | results[dataset_name] = results_i 550 | if comm.is_main_process(): 551 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 552 | print_csv_format(results_i) 553 | 554 | if len(results) == 1: 555 | results = list(results.values())[0] 556 | return results 557 | 558 | def main(args): 559 | cfg = setup(args) 560 | if args.eval_only: 561 | model = Trainer.build_model(cfg) 562 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 563 | cfg.MODEL.WEIGHTS, resume=args.resume 564 | ) 565 | return do_test(cfg,model) 566 | trainer = Trainer(cfg) 567 | trainer.resume_or_load(resume=args.resume) 568 | for dataset_name in cfg.DATASETS.TEST: 569 | if 'daytime_clear_test' in dataset_name : 570 | trainer.register_hooks([ 571 | hooks.BestCheckpointer(cfg.TEST.EVAL_SAVE_PERIOD,trainer.checkpointer,f'{dataset_name}_AP50',file_prefix='daytime_clear_model_best'), 572 | ]) 573 | 574 | trainer.train() 575 | 576 | 577 | if __name__ == "__main__": 578 | args = default_argument_parser().parse_args() 579 | cfg = setup(args) 580 | print("Command Line Args:", args) 581 | main(args) 582 | -------------------------------------------------------------------------------- /OD/CLIP/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | import ipdb 9 | from .mask import Mask_s, Mask_c 10 | import math 11 | 12 | 13 | def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False): 14 | if ceil_mode: 15 | return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 16 | else: 17 | return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 18 | 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | expansion = 4 23 | 24 | def __init__(self, inplanes, planes, h=0, w=0, eta=8, stride=1, base_width=64, dilation=1): 25 | 26 | super().__init__() 27 | 28 | 29 | self.height_1, self.width_1 = h, w 30 | if self.height_1>0: 31 | self.dynamic = True 32 | else: 33 | self.dynamic = False 34 | 35 | if self.dynamic: 36 | width = int(planes * (base_width / 64.)) # * groups 37 | # spatial gating module 38 | 39 | self.height_2 = conv2d_out_dim(h, 3, dilation, 2, dilation) 40 | self.width_2 = conv2d_out_dim(w, 3, dilation, 2, dilation) 41 | self.mask_s = Mask_s(self.height_2, self.width_2, inplanes, eta, eta) 42 | self.upsample_1 = nn.Upsample(size=(self.height_2, self.width_2), mode='nearest') 43 | self.upsample_2 = nn.Upsample(size=(self.height_2, self.width_2), mode='nearest') 44 | 45 | self.mask_c1 = Mask_c(inplanes, width) 46 | 
            self.mask_c2 = Mask_c(width, width)
47 | 
48 | 
49 |         # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
50 |         self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
51 |         self.bn1 = nn.BatchNorm2d(planes)
52 |         self.relu1 = nn.ReLU(inplace=True)
53 | 
54 |         # conv 2
55 |         self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
56 |         self.bn2 = nn.BatchNorm2d(planes)
57 |         self.relu2 = nn.ReLU(inplace=True)
58 | 
59 | 
60 |         self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61 | 
62 |         self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63 |         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64 |         self.relu3 = nn.ReLU(inplace=True)
65 | 
66 |         self.downsample = None
67 |         self.stride = stride
68 | 
69 |         if stride > 1 or inplanes != planes * Bottleneck.expansion:
70 |             # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71 |             self.downsample = nn.Sequential(OrderedDict([
72 |                 ("-1", nn.AvgPool2d(stride)),
73 |                 ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
74 |                 ("1", nn.BatchNorm2d(planes * self.expansion))
75 |             ]))
76 | 
77 |         if self.dynamic:
78 |             self.inplanes, self.width, self.planes = inplanes, width, planes * self.expansion
79 | 
80 |             flops_conv1_full = torch.Tensor([self.height_1 * self.width_1 * width * inplanes])
81 |             flops_conv2_full = torch.Tensor([9 * self.height_2 * self.width_2 * width * width])
82 |             flops_conv3_full = torch.Tensor([self.height_2 * self.width_2 * width * planes*self.expansion])
83 |             self.flops_downsample = torch.Tensor([self.height_2*self.width_2*planes*self.expansion*inplanes]
84 |             ) if self.downsample is not None else torch.Tensor([0])
85 |             self.flops_full = flops_conv1_full+flops_conv2_full+flops_conv3_full+self.flops_downsample
86 |             # mask flops
87 |             flops_mask_s = self.mask_s.get_flops()
88 |             flops_mask_c1 = self.mask_c1.get_flops()
89 |             flops_mask_c2 = self.mask_c2.get_flops()
90 |             self.flops_mask = torch.Tensor([flops_mask_s + flops_mask_c1 + flops_mask_c2])
91 | 
92 |     def forward(self, input: torch.Tensor):
93 |         if self.dynamic:
94 |             x, txt_emb, norm_1, norm_2, flops = input
95 |             # spatial mask
96 |             mask_s_m, norm_s, norm_s_t = self.mask_s(x)  # [N, 1, h, w]
97 |             mask_c1, norm_c1, norm_c1_t = self.mask_c1(x, txt_emb)
98 |             mask_s1 = self.upsample_1(mask_s_m)  # [N, 1, H1, W1]
99 |             mask_s = self.upsample_2(mask_s_m)  # [N, 1, H2, W2]
100 | 
101 |         else:
102 |             x = input
103 |         identity = x
104 | 
105 |         out = self.relu1(self.bn1(self.conv1(x)))
106 | 
107 |         if self.dynamic:
108 |             out = out * mask_c1 * mask_s1  # apply the channel (c1) and upsampled spatial (s1) gates
109 |             mask_c2, norm_c2, norm_c2_t = self.mask_c2(out, txt_emb)
110 | 
111 |         out = self.relu2(self.bn2(self.conv2(out)))
112 | 
113 |         if self.dynamic:
114 |             out = out * mask_c2 * mask_s  # apply the second channel gate and the spatial gate
115 | 
116 |         out = self.avgpool(out)
117 |         out = self.bn3(self.conv3(out))
118 | 
119 |         if self.downsample is not None:
120 |             identity = self.downsample(x)
121 | 
122 |         out += identity
123 |         out = self.relu3(out)
124 | 
125 |         if self.dynamic:
126 |             # norm
127 |             norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0)))
128 |             norm_2 = torch.cat((norm_2, torch.cat((norm_c1, norm_c1_t)).unsqueeze(0)))
129 |             norm_2 = torch.cat((norm_2, torch.cat((norm_c2, norm_c2_t)).unsqueeze(0)))
130 |             flops_blk = self.get_flops(mask_s, mask_s1, mask_c1, mask_c2)
131 |             flops = torch.cat((flops, flops_blk.unsqueeze(0)))
132 | 
133 |             return (out, txt_emb,
norm_1, norm_2, flops) 134 | return out 135 | 136 | def get_flops(self, mask_s, mask_s1, mask_c1, mask_c2): 137 | s_sum = mask_s.sum((1,2,3)) 138 | c1_sum, c2_sum = mask_c1.sum((1,2,3)), mask_c2.sum((1,2,3)) 139 | # conv 140 | s_sum_1 = mask_s1.sum((1,2,3)) 141 | flops_conv1 = s_sum_1 * c1_sum * self.inplanes 142 | flops_conv2 = 9 * s_sum * c2_sum * c1_sum 143 | flops_conv3 = s_sum * self.planes * c2_sum 144 | # total 145 | flops = flops_conv1+flops_conv2+flops_conv3+self.flops_downsample.to(flops_conv1.device) 146 | return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device))) 147 | 148 | 149 | class AttentionPool2d(nn.Module): 150 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 151 | super().__init__() 152 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 153 | self.k_proj = nn.Linear(embed_dim, embed_dim) 154 | self.q_proj = nn.Linear(embed_dim, embed_dim) 155 | self.v_proj = nn.Linear(embed_dim, embed_dim) 156 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 157 | self.num_heads = num_heads 158 | 159 | def forward(self, x): 160 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 161 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 162 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 163 | x, _ = F.multi_head_attention_forward( 164 | query=x[:1], key=x, value=x, 165 | embed_dim_to_check=x.shape[-1], 166 | num_heads=self.num_heads, 167 | q_proj_weight=self.q_proj.weight, 168 | k_proj_weight=self.k_proj.weight, 169 | v_proj_weight=self.v_proj.weight, 170 | in_proj_weight=None, 171 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 172 | bias_k=None, 173 | bias_v=None, 174 | add_zero_attn=False, 175 | dropout_p=0, 176 | out_proj_weight=self.c_proj.weight, 177 | out_proj_bias=self.c_proj.bias, 178 | use_separate_proj_weight=True, 179 | training=self.training, 180 | need_weights=False 181 | ) 182 | return x.squeeze(0) 183 | 184 | 185 | class ModifiedResNet(nn.Module): 186 | """ 187 | A ResNet class that is similar to torchvision's but contains the following changes: 188 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
189 |     - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
190 |     - The final pooling layer is a QKV attention instead of an average pool
191 |     """
192 | 
193 |     def __init__(self, layers, output_dim, heads, h=75, w=133, input_resolution=224, width=64):
194 |         super().__init__()
195 | 
196 |         self.height, self.width = h, w
197 | 
198 |         self.output_dim = output_dim
199 |         self.input_resolution = input_resolution
200 | 
201 |         # the 3-layer stem
202 |         self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
203 |         self.bn1 = nn.BatchNorm2d(width // 2)
204 |         self.relu1 = nn.ReLU(inplace=True)
205 |         self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
206 |         self.bn2 = nn.BatchNorm2d(width // 2)
207 |         self.relu2 = nn.ReLU(inplace=True)
208 |         self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
209 |         self.bn3 = nn.BatchNorm2d(width)
210 |         self.relu3 = nn.ReLU(inplace=True)
211 |         self.avgpool = nn.AvgPool2d(2)
212 | 
213 | 
214 | 
215 |         # residual layers
216 |         self._inplanes = width  # this is a *mutable* variable used during construction
217 |         self.layer1 = self._make_layer(width, layers[0])
218 |         self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
219 |         self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
220 |         self.layer4, h, w = self._make_layer_dynamic(width * 8, layers[3], 14*2, 14*2, stride=2)
221 | 
222 |         embed_dim = width * 32  # the ResNet feature dimension
223 |         self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
224 | 
225 |     def _make_layer(self, planes, blocks, stride=1):
226 | 
227 |         layers = [Bottleneck(self._inplanes, planes, 0, 0, 8, stride)]
228 |         self._inplanes = planes * Bottleneck.expansion
229 |         for _ in range(1, blocks):
230 |             layers.append(Bottleneck(self._inplanes, planes))
231 |         return nn.Sequential(*layers)
232 | 
233 |     def _make_layer_dynamic(self, planes, blocks, h, w, stride=1):
234 | 
235 |         layers = [Bottleneck(self._inplanes, planes, h, w, 8, stride)]
236 | 
237 |         h = conv2d_out_dim(h, kernel_size=1, stride=stride, padding=0)
238 |         w = conv2d_out_dim(w, kernel_size=1, stride=stride, padding=0)
239 |         self._inplanes = planes * Bottleneck.expansion
240 |         for _ in range(1, blocks):
241 |             layers.append(Bottleneck(self._inplanes, planes, h, w))
242 | 
243 |         return nn.Sequential(*layers), h, w
244 | 
245 | 
246 | 
247 |     def forward(self, x):
248 |         def stem(x):
249 |             x = self.relu1(self.bn1(self.conv1(x)))
250 |             x = self.relu2(self.bn2(self.conv2(x)))
251 |             x = self.relu3(self.bn3(self.conv3(x)))
252 |             x = self.avgpool(x)
253 |             return x
254 | 
255 |         x = x.type(self.conv1.weight.dtype)
256 |         x = stem(x)
257 |         x = self.layer1(x)
258 |         x = self.layer2(x)
259 |         x = self.layer3(x)
260 |         x = self.layer4(x)
261 |         x = self.attnpool(x)
262 | 
263 |         return x
264 | 
265 | 
266 | class LayerNorm(nn.LayerNorm):
267 |     """Subclass torch's LayerNorm to handle fp16."""
268 | 
269 |     def forward(self, x: torch.Tensor):
270 |         orig_type = x.dtype
271 |         ret = super().forward(x.type(torch.float32))
272 |         return ret.type(orig_type)
273 | 
274 | 
275 | class QuickGELU(nn.Module):
276 |     def forward(self, x: torch.Tensor):
277 |         return x * torch.sigmoid(1.702 * x)
278 | 
279 | 
280 | class ResidualAttentionBlock(nn.Module):
281 |     def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
282 |         super().__init__()
283 | 
284 |         self.attn = nn.MultiheadAttention(d_model, n_head)
285 |         self.ln_1 = LayerNorm(d_model)
286 |         self.mlp = nn.Sequential(OrderedDict([
287 |             ("c_fc", nn.Linear(d_model, d_model * 4)),
288 |             ("gelu", QuickGELU()),
289 |             ("c_proj", nn.Linear(d_model * 4, d_model))
290 |         ]))
291 |         self.ln_2 = LayerNorm(d_model)
292 |         self.attn_mask = attn_mask
293 | 
294 |     def attention(self, x: torch.Tensor):
295 |         self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
296 |         return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
297 | 
298 |     def forward(self, x: torch.Tensor):
299 |         x = x + self.attention(self.ln_1(x))
300 |         x = x + self.mlp(self.ln_2(x))
301 |         return x
302 | 
303 | 
304 | class Transformer(nn.Module):
305 |     def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
306 |         super().__init__()
307 |         self.width = width
308 |         self.layers = layers
309 |         self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
310 | 
311 |     def forward(self, x: torch.Tensor):
312 |         return self.resblocks(x)
313 | 
314 | 
315 | class VisionTransformer(nn.Module):
316 |     def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
317 |         super().__init__()
318 |         self.input_resolution = input_resolution
319 |         self.output_dim = output_dim
320 |         self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
321 | 
322 |         scale = width ** -0.5
323 |         self.class_embedding = nn.Parameter(scale * torch.randn(width))
324 |         self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
325 |         self.ln_pre = LayerNorm(width)
326 | 
327 |         self.transformer = Transformer(width, layers, heads)
328 | 
329 |         self.ln_post = LayerNorm(width)
330 |         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
331 | 
332 |     def forward(self, x: torch.Tensor):
333 |         x = self.conv1(x)  # shape = [*, width, grid, grid]
334 |         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
335 |         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
336 |         x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
337 |         x = x + self.positional_embedding.to(x.dtype)
338 |         x = self.ln_pre(x)
339 | 
340 |         x = x.permute(1, 0, 2)  # NLD -> LND
341 |         x = self.transformer(x)
342 |         x = x.permute(1, 0, 2)  # LND -> NLD
343 | 
344 |         x = self.ln_post(x[:, 0, :])
345 | 
346 |         if self.proj is not None:
347 |             x = x @ self.proj
348 | 
349 |         return x
350 | 
351 | 
352 | class CLIP(nn.Module):
353 |     def __init__(self,
354 |                  embed_dim: int,
355 |                  # vision
356 |                  image_resolution: int,
357 |                  vision_layers: Union[Tuple[int, int, int, int], int],
358 |                  vision_width: int,
359 |                  vision_patch_size: int,
360 |                  # text
361 |                  context_length: int,
362 |                  vocab_size: int,
363 |                  transformer_width: int,
364 |                  transformer_heads: int,
365 |                  transformer_layers: int
366 |                  ):
367 |         super().__init__()
368 | 
369 |         self.context_length = context_length
370 | 
371 |         if isinstance(vision_layers, (tuple, list)):
372 |             vision_heads = vision_width * 32 // 64
373 |             self.visual = ModifiedResNet(
374 |                 layers=vision_layers,
375 |                 output_dim=embed_dim,
376 |                 heads=vision_heads,
377 |                 input_resolution=image_resolution,
378 |                 width=vision_width
379 |             )
380 |         else:
381 |             vision_heads = vision_width // 64
382 |             self.visual = VisionTransformer(
383 |                 input_resolution=image_resolution,
384 |                 patch_size=vision_patch_size,
385 |                 width=vision_width,
386 |                 layers=vision_layers,
387 |                 heads=vision_heads,
388 |                 output_dim=embed_dim
389 |             )
390 | 
391 |         self.transformer = Transformer(
392 |             width=transformer_width,
393 |             layers=transformer_layers,
394 |             heads=transformer_heads,
395 |             attn_mask=self.build_attention_mask()
396 |         )
397 | 
398 |         self.vocab_size = vocab_size
399 |         self.token_embedding = nn.Embedding(vocab_size, transformer_width)
400 |         self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
401 |         self.ln_final = LayerNorm(transformer_width)
402 | 
403 |         self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
404 |         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
405 | 
406 |         self.initialize_parameters()
407 | 
408 |     def initialize_parameters(self):
409 |         nn.init.normal_(self.token_embedding.weight, std=0.02)
410 |         nn.init.normal_(self.positional_embedding, std=0.01)
411 | 
412 |         if isinstance(self.visual, ModifiedResNet):
413 |             if self.visual.attnpool is not None:
414 |                 std = self.visual.attnpool.c_proj.in_features ** -0.5
415 |                 nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
416 |                 nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
417 |                 nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
418 |                 nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
419 | 
420 |             for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
421 |                 for name, param in resnet_block.named_parameters():
422 |                     if name.endswith("bn3.weight"):
423 |                         nn.init.zeros_(param)
424 | 
425 |         proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
426 |         attn_std = self.transformer.width ** -0.5
427 |         fc_std = (2 * self.transformer.width) ** -0.5
428 |         for block in self.transformer.resblocks:
429 |             nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
430 |             nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
431 |             nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
432 |             nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
433 | 
434 |         if self.text_projection is not None:
435 |             nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
436 | 
437 |     def build_attention_mask(self):
438 |         # lazily create causal attention mask, with full attention between the vision tokens
439 |         # pytorch uses additive attention mask; fill with -inf
440 |         mask = torch.empty(self.context_length, self.context_length)
441 |         mask.fill_(float("-inf"))
442 |         mask.triu_(1)  # zero out the lower diagonal
443 |         return mask
444 | 
445 |     @property
446 |     def dtype(self):
447 |         return self.visual.conv1.weight.dtype
448 | 
449 |     def encode_image(self, image):
450 |         return self.visual(image.type(self.dtype))
451 | 
452 |     def encode_text(self, text):
453 |         x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
454 | 
455 |         x = x + self.positional_embedding.type(self.dtype)
456 |         x = x.permute(1, 0, 2)  # NLD -> LND
457 |         x = self.transformer(x)
458 |         x = x.permute(1, 0, 2)  # LND -> NLD
459 |         x = self.ln_final(x).type(self.dtype)
460 | 
461 |         # take features from the eot embedding (eot_token is the highest number in each sequence)
462 |         x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
463 | 
464 |         return x
465 | 
466 |     def forward(self, image, text):
467 |         image_features = self.encode_image(image)
468 |         text_features = self.encode_text(text)
469 | 
470 |         # normalized features
471 |         image_features = image_features / image_features.norm(dim=1, keepdim=True)
472 |         text_features = text_features / text_features.norm(dim=1, keepdim=True)
473 | 
474 |         # cosine similarity as logits
475 |         logit_scale = self.logit_scale.exp()
476 |         logits_per_image = logit_scale * image_features @ text_features.t()
477 |         logits_per_text = logits_per_image.t()
478 | 
479 |         # shape = [global_batch_size, global_batch_size]
480 |         return logits_per_image, logits_per_text
481 | 
482 | 
483 | def convert_weights(model: nn.Module):
484 |     """Convert applicable model parameters to fp16"""
485 | 
486 |     def _convert_weights_to_fp16(l):
487 |         if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
488 |             l.weight.data = l.weight.data.half()
489 |             if l.bias is not None:
490 |                 l.bias.data = l.bias.data.half()
491 | 
492 |         if isinstance(l, nn.MultiheadAttention):
493 |             for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
494 |                 tensor = getattr(l, attr)
495 |                 if tensor is not None:
496 |                     tensor.data = tensor.data.half()
497 | 
498 |         for name in ["text_projection", "proj"]:
499 |             if hasattr(l, name):
500 |                 attr = getattr(l, name)
501 |                 if attr is not None:
502 |                     attr.data = attr.data.half()
503 | 
504 |     model.apply(_convert_weights_to_fp16)
505 | 
506 | 
507 | def build_model(state_dict: dict):
508 |     vit = "visual.proj" in state_dict
509 | 
510 |     if vit:
511 |         vision_width = state_dict["visual.conv1.weight"].shape[0]
512 |         vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
513 |         vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
514 |         grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
515 |         image_resolution = vision_patch_size * grid_size
516 |     else:
517 |         counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
518 |         vision_layers = tuple(counts)
519 |         vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
520 |         output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
521 |         vision_patch_size = None
522 |         assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
523 |         image_resolution = output_width * 32
524 | 
525 |     embed_dim = state_dict["text_projection"].shape[1]
526 |     context_length = state_dict["positional_embedding"].shape[0]
527 |     vocab_size = state_dict["token_embedding.weight"].shape[0]
528 |     transformer_width = state_dict["ln_final.weight"].shape[0]
529 |     transformer_heads = transformer_width // 64
530 |     transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
531 | 
532 |     model = CLIP(
533 |         embed_dim,
534 |         image_resolution, vision_layers, vision_width, vision_patch_size,
535 |         context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
536 |     )
537 | 
538 |     for key in ["input_resolution", "context_length", "vocab_size"]:
539 |         if key in state_dict:
540 |             del state_dict[key]
541 | 
542 |     convert_weights(model)
543 |     model.load_state_dict(state_dict, strict=False)
544 |     return model.eval()
545 | 
--------------------------------------------------------------------------------
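Note (illustrative, not part of the repository): `build_model` above rebuilds the full CLIP module from a pretrained state dict and loads it with `strict=False`, so the `Mask_s`/`Mask_c` modules added to the ResNet bottlenecks simply keep their fresh initialization while the original CLIP parameters are filled in. A minimal sketch of that flow, assuming the locally installed `clip` package (see run_install.sh) and a downloaded OpenAI `RN50.pt` checkpoint (path is a placeholder), is shown below; it exercises only model construction and the text tower.

# minimal_build_sketch.py -- hypothetical helper, not a file in this repo
import torch
from clip import tokenize            # tokenizer exported by clip/clip.py
from clip.model import build_model   # the build_model defined above

# OpenAI distributes CLIP weights as a TorchScript archive; its state_dict feeds build_model.
jit_archive = torch.jit.load("RN50.pt", map_location="cpu")   # placeholder path
model = build_model(jit_archive.state_dict())                 # strict=False: new mask modules stay randomly initialized
model.float()   # convert_weights() casts to fp16; cast back to fp32 for a CPU-only check

with torch.no_grad():
    text = tokenize(["a photo of a car", "a photo of a bus"])
    text_features = model.encode_text(text)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

print(text_features.shape)  # e.g. torch.Size([2, 1024]) for the RN50 weights

Image encoding through the dynamic `layer4` additionally depends on the text-conditioned mask inputs set up by the detector code under OD/modeling, so it is intentionally not exercised in this sketch.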