├── OD
│ ├── CLIP
│ │ ├── clip
│ │ │ ├── __init__.py
│ │ │ ├── bpe_simple_vocab_16e6.txt.gz
│ │ │ ├── simple_tokenizer.py
│ │ │ ├── mask.py
│ │ │ ├── clip.py
│ │ │ └── model.py
│ │ ├── run_install.sh
│ │ ├── MANIFEST.in
│ │ ├── requirements.txt
│ │ ├── CLIP.png
│ │ ├── dist
│ │ │ └── clip-1.0-py3.8.egg
│ │ ├── .gitignore
│ │ ├── setup.py
│ │ ├── data
│ │ │ ├── rendered-sst2.md
│ │ │ ├── yfcc100m.md
│ │ │ └── country211.md
│ │ ├── tests
│ │ │ └── test_consistency.py
│ │ ├── .github
│ │ │ └── workflows
│ │ │ │ └── test.yml
│ │ ├── LICENSE
│ │ ├── hubconf.py
│ │ ├── model-card.md
│ │ └── README.md
│ ├── .DS_Store
│ ├── modeling
│ │ ├── .DS_Store
│ │ ├── __init__.py
│ │ ├── .ipynb_checkpoints
│ │ │ ├── __init__-checkpoint.py
│ │ │ ├── config-checkpoint.py
│ │ │ ├── box_predictor-checkpoint.py
│ │ │ ├── custom_pascal_evaluation-checkpoint.py
│ │ │ ├── regularization-checkpoint.py
│ │ │ ├── backbone-checkpoint.py
│ │ │ ├── rpn-checkpoint.py
│ │ │ ├── clip-checkpoint.py
│ │ │ ├── roi_head-checkpoint.py
│ │ │ └── meta_arch-checkpoint.py
│ │ ├── config.py
│ │ ├── box_predictor.py
│ │ ├── custom_pascal_evaluation.py
│ │ ├── regularization.py
│ │ ├── backbone.py
│ │ ├── rpn.py
│ │ ├── clip.py
│ │ ├── roi_head.py
│ │ └── meta_arch.py
│ ├── datasets
│ │ └── diverseWeather
│ │ │ └── .DS_Store
│ ├── data
│ │ └── datasets
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-36.pyc
│ │ │ │ ├── __init__.cpython-37.pyc
│ │ │ │ ├── __init__.cpython-38.pyc
│ │ │ │ ├── builtin.cpython-36.pyc
│ │ │ │ ├── builtin.cpython-37.pyc
│ │ │ │ ├── builtin.cpython-38.pyc
│ │ │ │ ├── diverse_weather.cpython-36.pyc
│ │ │ │ ├── diverse_weather.cpython-37.pyc
│ │ │ │ ├── diverse_weather.cpython-38.pyc
│ │ │ │ ├── pascal_voc_adaptation.cpython-36.pyc
│ │ │ │ ├── pascal_voc_adaptation.cpython-37.pyc
│ │ │ │ ├── pascal_voc_adaptation.cpython-38.pyc
│ │ │ │ ├── comic_water_adaptation.cpython-36.pyc
│ │ │ │ ├── comic_water_adaptation.cpython-37.pyc
│ │ │ │ └── comic_water_adaptation.cpython-38.pyc
│ │ │ ├── __init__.py
│ │ │ ├── builtin.py
│ │ │ ├── pascal_voc_adaptation.py
│ │ │ └── diverse_weather.py
│ ├── prunedprompts.txt
│ ├── configs
│ │ ├── diverse_weather_dusk_rainy_test.yaml
│ │ ├── diverse_weather_foggy_test.yaml
│ │ ├── diverse_weather_night_rainy_test.yaml
│ │ ├── diverse_weather_night_sunny_test.yaml
│ │ └── diverse_weather.yaml
│ └── train.py
├── Framework.png
├── Introduction.jpg
├── Introduction.png
└── README.md
/OD/CLIP/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/OD/CLIP/run_install.sh:
--------------------------------------------------------------------------------
1 | python setup.py develop
--------------------------------------------------------------------------------
/OD/CLIP/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include clip/bpe_simple_vocab_16e6.txt.gz
2 |
--------------------------------------------------------------------------------
/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/Framework.png
--------------------------------------------------------------------------------
/OD/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/.DS_Store
--------------------------------------------------------------------------------
/OD/CLIP/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 | torch
5 | torchvision
6 |
--------------------------------------------------------------------------------
/Introduction.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/Introduction.jpg
--------------------------------------------------------------------------------
/Introduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/Introduction.png
--------------------------------------------------------------------------------
/OD/CLIP/CLIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/CLIP/CLIP.png
--------------------------------------------------------------------------------
/OD/modeling/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/modeling/.DS_Store
--------------------------------------------------------------------------------
/OD/CLIP/dist/clip-1.0-py3.8.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/CLIP/dist/clip-1.0-py3.8.egg
--------------------------------------------------------------------------------
/OD/datasets/diverseWeather/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/datasets/diverseWeather/.DS_Store
--------------------------------------------------------------------------------
/OD/CLIP/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/CLIP/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/builtin.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/builtin.cpython-36.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/builtin.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/builtin.cpython-37.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/builtin.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/builtin.cpython-38.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/diverse_weather.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/diverse_weather.cpython-36.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/diverse_weather.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/diverse_weather.cpython-37.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/diverse_weather.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/diverse_weather.cpython-38.pyc
--------------------------------------------------------------------------------
/OD/CLIP/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | *.egg-info
5 | .pytest_cache
6 | .ipynb_checkpoints
7 |
8 | thumbs.db
9 | .DS_Store
10 | .idea
11 |
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-36.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-37.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/pascal_voc_adaptation.cpython-38.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-36.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-37.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel00008/PDOC/HEAD/OD/data/datasets/__pycache__/comic_water_adaptation.cpython-38.pyc
--------------------------------------------------------------------------------
/OD/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | from . import builtin # ensure the builtin datasets are registered
4 |
5 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")]
6 |
--------------------------------------------------------------------------------
/OD/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .rpn import SBRPN
2 | from .backbone import ClipRN101
3 | from .meta_arch import ClipRCNNWithClipBackbone
4 | from .roi_head import ClipRes5ROIHeads
5 | from .config import add_stn_config
6 | from .custom_pascal_evaluation import CustomPascalVOCDetectionEvaluator
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/__init__-checkpoint.py:
--------------------------------------------------------------------------------
1 | from .rpn import SBRPN
2 | from .backbone import ClipRN101
3 | from .meta_arch import ClipRCNNWithClipBackbone
4 | from .roi_head import ClipRes5ROIHeads
5 | from .config import add_stn_config
6 | from .custom_pascal_evaluation import CustomPascalVOCDetectionEvaluator
--------------------------------------------------------------------------------
/OD/prunedprompts.txt:
--------------------------------------------------------------------------------
1 | an image taken on a snow night
2 | an image taken on a fog night
3 | an image taken on a cloudy night
4 | an image taken on a rain night
5 | an image taken on a stormy night
6 | an image taken on a snow day
7 | an image taken on a fog day
8 | an image taken on a cloudy day
9 | an image taken on a rain day
10 | an image taken on a stormy day
11 | an image taken on a snow evening
12 | an image taken on a fog evening
13 | an image taken on a cloudy evening
14 | an image taken on a rain evening
15 | an image taken on a stormy evening
16 |
--------------------------------------------------------------------------------
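The prompts above are plain text, one per line. A minimal sketch of embedding them with the bundled CLIP package (not part of the repo's own training code, which is not reproduced in this excerpt; the working directory is assumed to be `OD/` with CLIP installed via `run_install.sh`):

```python
import torch
import clip  # installed from OD/CLIP

device = "cuda" if torch.cuda.is_available() else "cpu"
# "RN101" matches MODEL.CLIP_IMAGE_ENCODER_NAME in the configs further below
model, _ = clip.load("RN101", device=device)

with open("prunedprompts.txt") as f:
    prompts = [line.strip() for line in f if line.strip()]

with torch.no_grad():
    tokens = clip.tokenize(prompts).to(device)
    text_features = model.encode_text(tokens)
    # L2-normalize so the prompt embeddings can be compared by cosine similarity
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

print(text_features.shape)  # (15, 512) for RN101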
/OD/data/datasets/builtin.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 |
3 | from .diverse_weather import register_dataset as register_diverse_weather
4 | from .pascal_voc_adaptation import register_all_pascal_voc as register_pascal_voc
5 | from .comic_water_adaptation import register_dataset as register_comic_water
6 | import os
7 |
8 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets"))
9 | DEFAULT_DATASETS_ROOT = "data/"
10 |
11 |
12 | register_diverse_weather(_root)
13 | register_pascal_voc(_root)
14 | register_comic_water(_root)
15 |
--------------------------------------------------------------------------------
/OD/CLIP/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pkg_resources
4 | from setuptools import setup, find_packages
5 |
6 | setup(
7 | name="clip",
8 | py_modules=["clip"],
9 | version="1.0",
10 | description="",
11 | author="OpenAI",
12 | packages=find_packages(exclude=["tests*"]),
13 | install_requires=[
14 | str(r)
15 | for r in pkg_resources.parse_requirements(
16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
17 | )
18 | ],
19 | include_package_data=True,
20 | extras_require={'dev': ['pytest']},
21 | )
22 |
--------------------------------------------------------------------------------
/OD/CLIP/data/rendered-sst2.md:
--------------------------------------------------------------------------------
1 | # The Rendered SST2 Dataset
2 |
3 | In the paper, we used an image classification dataset called Rendered SST2 to evaluate the model's capability on optical character recognition. To do so, we rendered the sentences in the [Stanford Sentiment Treebank v2](https://nlp.stanford.edu/sentiment/treebank.html) dataset and used those as the input to the CLIP image encoder.
4 |
5 | The following command will download a 131MB archive containing the images and extract them into a subdirectory `rendered-sst2`:
6 |
7 | ```bash
8 | wget https://openaipublic.azureedge.net/clip/data/rendered-sst2.tgz
9 | tar zxvf rendered-sst2.tgz
10 | ```
11 |
12 |
--------------------------------------------------------------------------------
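As a usage illustration only (not an official evaluation script), the extracted folder can be scored zero-shot with CLIP; the `test/{negative,positive}` subdirectory layout and the prompt template are assumptions here:

```python
import itertools
import torch
import clip
from torchvision.datasets import ImageFolder

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

dataset = ImageFolder("rendered-sst2/test", transform=preprocess)  # classes sorted: ['negative', 'positive']
text = clip.tokenize([f"a {c} review of a movie" for c in dataset.classes]).to(device)

correct, total = 0, 0
with torch.no_grad():
    for image, label in itertools.islice(dataset, 200):  # small subset; batching omitted for brevity
        logits_per_image, _ = model(image.unsqueeze(0).to(device), text)
        correct += int(logits_per_image.argmax(dim=-1).item() == label)
        total += 1
print(correct / total)
```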
/OD/modeling/config.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import CfgNode as CN
2 |
3 | def add_stn_config(cfg):
4 | cfg.OFFSET_DOMAIN = ''
5 | cfg.OFFSET_FROZENBN = False
6 | cfg.OFFSET_DOMAIN_TEXT = ''
7 | cfg.OFFSET_NAME = 0
8 | cfg.OFFSET_OPT_INTERVAL = [10]
9 | cfg.OFFSET_OPT_ITERS = 0
10 | cfg.AUG_PROB = 0.5
11 | cfg.DOMAIN_NAME = ""
12 | cfg.TEST.EVAL_SAVE_PERIOD = 5000
13 | cfg.INPUT.CLIP_WITH_IMG = False
14 | cfg.INPUT.CLIP_RANDOM_CROPS = False
15 | cfg.INPUT.IMAGE_JITTER = False
16 | cfg.INPUT.RANDOM_CROP_SIZE = 224
17 | cfg.MODEL.GLOBAL_GND = False
18 | cfg.BASE_YAML = "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml"
19 | cfg.MODEL.RENAME = list()
20 | cfg.MODEL.CLIP_IMAGE_ENCODER_NAME = 'ViT-B/32'
21 | cfg.MODEL.BACKBONE.UNFREEZE = ['layer3','layer4','attnpool']
22 | cfg.MODEL.USE_PROJ = True
23 |
--------------------------------------------------------------------------------
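A minimal sketch of how `add_stn_config` is meant to be used; `train.py` is listed in the tree but not reproduced in this excerpt, so the exact wiring is an assumption (detectron2 and `ipdb` must be installed and `OD/` must be on the import path):

```python
from detectron2.config import get_cfg
from modeling.config import add_stn_config

cfg = get_cfg()
add_stn_config(cfg)  # the custom keys must exist before merging a project YAML
cfg.merge_from_file("configs/diverse_weather.yaml")

print(cfg.MODEL.CLIP_IMAGE_ENCODER_NAME, cfg.MODEL.BACKBONE.UNFREEZE)
```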
/OD/modeling/.ipynb_checkpoints/config-checkpoint.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import CfgNode as CN
2 |
3 | def add_stn_config(cfg):
4 | cfg.OFFSET_DOMAIN = ''
5 | cfg.OFFSET_FROZENBN = False
6 | cfg.OFFSET_DOMAIN_TEXT = ''
7 | cfg.OFFSET_NAME = 0
8 | cfg.OFFSET_OPT_INTERVAL = [10]
9 | cfg.OFFSET_OPT_ITERS = 0
10 | cfg.AUG_PROB = 0.5
11 | cfg.DOMAIN_NAME = ""
12 | cfg.TEST.EVAL_SAVE_PERIOD = 5000
13 | cfg.INPUT.CLIP_WITH_IMG = False
14 | cfg.INPUT.CLIP_RANDOM_CROPS = False
15 | cfg.INPUT.IMAGE_JITTER = False
16 | cfg.INPUT.RANDOM_CROP_SIZE = 224
17 | cfg.MODEL.GLOBAL_GND = False
18 | cfg.BASE_YAML = "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml"
19 | cfg.MODEL.RENAME = list()
20 | cfg.MODEL.CLIP_IMAGE_ENCODER_NAME = 'ViT-B/32'
21 | cfg.MODEL.BACKBONE.UNFREEZE = ['layer3','layer4','attnpool']
22 | cfg.MODEL.USE_PROJ = True
23 |
--------------------------------------------------------------------------------
/OD/CLIP/tests/test_consistency.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from PIL import Image
5 |
6 | import clip
7 |
8 |
9 | @pytest.mark.parametrize('model_name', clip.available_models())
10 | def test_consistency(model_name):
11 | device = "cpu"
12 | jit_model, transform = clip.load(model_name, device=device, jit=True)
13 | py_model, _ = clip.load(model_name, device=device, jit=False)
14 |
15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
17 |
18 | with torch.no_grad():
19 | logits_per_image, _ = jit_model(image, text)
20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
21 |
22 | logits_per_image, _ = py_model(image, text)
23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
24 |
25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
26 |
--------------------------------------------------------------------------------
/OD/modeling/box_predictor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple
2 | import torch
3 |
4 | from detectron2.layers import cat, cross_entropy
5 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
6 | from .clip import ClipPredictor
7 |
8 | class ClipFastRCNNOutputLayers(FastRCNNOutputLayers):
9 |
10 | def __init__(self,cfg, input_shape, clsnames) -> None:
11 | super().__init__(cfg, input_shape)
12 | self.cls_score = ClipPredictor(cfg.MODEL.CLIP_IMAGE_ENCODER_NAME, input_shape.channels, cfg.MODEL.DEVICE,clsnames)
13 | def forward(self,x,gfeat=None):
14 |
15 | if isinstance(x,list):
16 | scores = self.cls_score(x[0],gfeat)
17 | proposal_deltas = self.bbox_pred(x[1])
18 | else:
19 | scores = self.cls_score(x,gfeat)
20 | proposal_deltas = self.bbox_pred(x)
21 |
22 | return scores, proposal_deltas
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/box_predictor-checkpoint.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple
2 | import torch
3 |
4 | from detectron2.layers import cat, cross_entropy
5 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers
6 | from .clip import ClipPredictor
7 |
8 | class ClipFastRCNNOutputLayers(FastRCNNOutputLayers):
9 |
10 | def __init__(self,cfg, input_shape, clsnames) -> None:
11 | super().__init__(cfg, input_shape)
12 | self.cls_score = ClipPredictor(cfg.MODEL.CLIP_IMAGE_ENCODER_NAME, input_shape.channels, cfg.MODEL.DEVICE,clsnames)
13 | def forward(self,x,gfeat=None):
14 |
15 | if isinstance(x,list):
16 | scores = self.cls_score(x[0],gfeat)
17 | proposal_deltas = self.bbox_pred(x[1])
18 | else:
19 | scores = self.cls_score(x,gfeat)
20 | proposal_deltas = self.bbox_pred(x)
21 |
22 | return scores, proposal_deltas
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/OD/CLIP/data/yfcc100m.md:
--------------------------------------------------------------------------------
1 | # The YFCC100M Subset
2 |
3 | In the paper, we performed a dataset ablation using a subset of the YFCC100M dataset and showed that the performance remained largely similar.
4 |
5 | The subset contains 14,829,396 images, about 15% of the full dataset, which have been filtered to only keep those with natural language titles and/or descriptions in English.
6 |
7 | We provide the list of (line number, photo identifier, photo hash) of each image contained in this subset. These correspond to the first three columns in the dataset's metadata TSV file.
8 |
9 | ```bash
10 | wget https://openaipublic.azureedge.net/clip/data/yfcc100m_subset_data.tsv.bz2
11 | bunzip2 yfcc100m_subset_data.tsv.bz2
12 | ```
13 |
14 | Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
--------------------------------------------------------------------------------
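A small sketch for iterating the decompressed list, whose three tab-separated columns are described above:

```python
import csv

rows = 0
with open("yfcc100m_subset_data.tsv", newline="") as f:
    for line_number, photo_id, photo_hash in csv.reader(f, delimiter="\t"):
        rows += 1  # e.g. collect photo_id values here to filter your own YFCC100M copy

print(rows)  # expected: 14829396, per the note above
```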
/OD/CLIP/data/country211.md:
--------------------------------------------------------------------------------
1 | # The Country211 Dataset
2 |
3 | In the paper, we used an image classification dataset called Country211 to evaluate the model's capability on geolocation. To do so, we filtered the YFCC100m dataset to images that have GPS coordinates corresponding to an [ISO-3166 country code](https://en.wikipedia.org/wiki/List_of_ISO_3166_country_codes) and created a balanced dataset by sampling 150 train images, 50 validation images, and 100 test images for each country.
4 |
5 | The following command will download an 11GB archive containing the images and extract them into a subdirectory `country211`:
6 |
7 | ```bash
8 | wget https://openaipublic.azureedge.net/clip/data/country211.tgz
9 | tar zxvf country211.tgz
10 | ```
11 |
12 | These images are a subset of the YFCC100m dataset. Use of the underlying media files is subject to the Creative Commons licenses chosen by their creators/uploaders. For more information about the YFCC100M dataset, visit [the official website](https://multimediacommons.wordpress.com/yfcc100m-core-dataset/).
--------------------------------------------------------------------------------
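A quick sanity check after extraction; the `train`/`valid`/`test` split names, per-country subfolders, and `.jpg` extension are assumptions about the archive layout:

```python
from pathlib import Path

root = Path("country211")
for split, expected in [("train", 150), ("valid", 50), ("test", 100)]:
    class_dirs = [d for d in (root / split).iterdir() if d.is_dir()]
    balanced = all(len(list(d.glob("*.jpg"))) == expected for d in class_dirs)
    print(f"{split}: {len(class_dirs)} country codes, {expected} images each -> {balanced}")
```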
/OD/configs/diverse_weather_dusk_rainy_test.yaml:
--------------------------------------------------------------------------------
1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml"
2 | DATASETS:
3 | TRAIN: ("daytime_clear_train",)
4 | TEST: ('dusk_rainy_train',)
5 | DATALOADER:
6 | NUM_WORKERS: 16
7 | INPUT:
8 | MIN_SIZE_TRAIN: (600,)
9 | MIN_SIZE_TEST: 600
10 | CLIP_RANDOM_CROPS: True
11 | RANDOM_CROP_SIZE: 400
12 |
13 | SOLVER:
14 | BASE_LR: 0.001
15 | MAX_ITER: 100000
16 | STEPS: [40000,]
17 | WARMUP_ITERS: 0
18 | IMS_PER_BATCH: 4
19 | CHECKPOINT_PERIOD: 1000000
20 | MODEL:
21 | BACKBONE:
22 | NAME: ClipRN101
23 | WEIGHTS: ""
24 | CLIP_IMAGE_ENCODER_NAME: 'RN101'
25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable'
26 |
27 | PROPOSAL_GENERATOR:
28 | NAME: 'SBRPN'
29 | ROI_HEADS:
30 | NAME: 'ClipRes5ROIHeadsAttn'
31 | NUM_CLASSES: 7
32 | TEST:
33 | EVAL_SAVE_PERIOD: 5000
34 | OUTPUT_DIR: "all_outs/diverse_weather_dusk_rainy"
35 | VIS_PERIOD: 5000
36 |
--------------------------------------------------------------------------------
/OD/configs/diverse_weather_foggy_test.yaml:
--------------------------------------------------------------------------------
1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml"
2 | DATASETS:
3 | TRAIN: ("daytime_clear_train",)
4 | TEST: ('daytime_foggy_train',)
5 | DATALOADER:
6 | NUM_WORKERS: 16
7 | INPUT:
8 | MIN_SIZE_TRAIN: (600,)
9 | MIN_SIZE_TEST: 600
10 | CLIP_RANDOM_CROPS: True
11 | RANDOM_CROP_SIZE: 400
12 |
13 | SOLVER:
14 | BASE_LR: 0.001
15 | MAX_ITER: 100000
16 | STEPS: [40000,]
17 | WARMUP_ITERS: 0
18 | IMS_PER_BATCH: 4
19 | CHECKPOINT_PERIOD: 1000000
20 | MODEL:
21 | BACKBONE:
22 | NAME: ClipRN101
23 | WEIGHTS: ""
24 | CLIP_IMAGE_ENCODER_NAME: 'RN101'
25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable'
26 |
27 | PROPOSAL_GENERATOR:
28 | NAME: 'SBRPN'
29 | ROI_HEADS:
30 | NAME: 'ClipRes5ROIHeadsAttn'
31 | NUM_CLASSES: 7
32 | TEST:
33 | EVAL_SAVE_PERIOD: 5000
34 | OUTPUT_DIR: "all_outs/diverse_weather_foggy"
35 | VIS_PERIOD: 5000
36 |
37 |
--------------------------------------------------------------------------------
/OD/configs/diverse_weather_night_rainy_test.yaml:
--------------------------------------------------------------------------------
1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml"
2 | DATASETS:
3 | TRAIN: ("daytime_clear_train",)
4 | TEST: ('night_rainy_train',)
5 | DATALOADER:
6 | NUM_WORKERS: 16
7 | INPUT:
8 | MIN_SIZE_TRAIN: (600,)
9 | MIN_SIZE_TEST: 600
10 | CLIP_RANDOM_CROPS: True
11 | RANDOM_CROP_SIZE: 400
12 |
13 | SOLVER:
14 | BASE_LR: 0.001
15 | MAX_ITER: 100000
16 | STEPS: [40000,]
17 | WARMUP_ITERS: 0
18 | IMS_PER_BATCH: 4
19 | CHECKPOINT_PERIOD: 1000000
20 | MODEL:
21 | BACKBONE:
22 | NAME: ClipRN101
23 | WEIGHTS: ""
24 | CLIP_IMAGE_ENCODER_NAME: 'RN101'
25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable'
26 |
27 | PROPOSAL_GENERATOR:
28 | NAME: 'SBRPN'
29 | ROI_HEADS:
30 | NAME: 'ClipRes5ROIHeadsAttn'
31 | NUM_CLASSES: 7
32 | TEST:
33 | EVAL_SAVE_PERIOD: 5000
34 | OUTPUT_DIR: "all_outs/diverse_weather_night_rainy"
35 | VIS_PERIOD: 5000
36 |
37 |
--------------------------------------------------------------------------------
/OD/configs/diverse_weather_night_sunny_test.yaml:
--------------------------------------------------------------------------------
1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml"
2 | DATASETS:
3 | TRAIN: ("daytime_clear_train",)
4 | TEST: ('night_sunny_train',)
5 | DATALOADER:
6 | NUM_WORKERS: 16
7 | INPUT:
8 | MIN_SIZE_TRAIN: (600,)
9 | MIN_SIZE_TEST: 600
10 | CLIP_RANDOM_CROPS: True
11 | RANDOM_CROP_SIZE: 400
12 |
13 | SOLVER:
14 | BASE_LR: 0.001
15 | MAX_ITER: 100000
16 | STEPS: [40000,]
17 | WARMUP_ITERS: 0
18 | IMS_PER_BATCH: 4
19 | CHECKPOINT_PERIOD: 1000000
20 | MODEL:
21 | BACKBONE:
22 | NAME: ClipRN101
23 | WEIGHTS: ""
24 | CLIP_IMAGE_ENCODER_NAME: 'RN101'
25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable'
26 |
27 | PROPOSAL_GENERATOR:
28 | NAME: 'SBRPN'
29 | ROI_HEADS:
30 | NAME: 'ClipRes5ROIHeadsAttn'
31 | NUM_CLASSES: 7
32 | TEST:
33 | EVAL_SAVE_PERIOD: 5000
34 | OUTPUT_DIR: "all_outs/diverse_weather_night_sunny"
35 | VIS_PERIOD: 5000
36 |
37 |
--------------------------------------------------------------------------------
/OD/CLIP/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | jobs:
10 | CLIP-test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: [3.8]
15 | pytorch-version: [1.7.1, 1.9.1, 1.10.1]
16 | include:
17 | - python-version: 3.8
18 | pytorch-version: 1.7.1
19 | torchvision-version: 0.8.2
20 | - python-version: 3.8
21 | pytorch-version: 1.9.1
22 | torchvision-version: 0.10.1
23 | - python-version: 3.8
24 | pytorch-version: 1.10.1
25 | torchvision-version: 0.11.2
26 | steps:
27 | - uses: conda-incubator/setup-miniconda@v2
28 | - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision=${{ matrix.torchvision-version }} cpuonly -c pytorch
29 | - uses: actions/checkout@v2
30 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
31 | - run: pip install pytest
32 | - run: pip install .
33 | - run: pytest
34 |
--------------------------------------------------------------------------------
/OD/CLIP/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/OD/configs/diverse_weather.yaml:
--------------------------------------------------------------------------------
1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml"
2 | DATASETS:
3 | TRAIN: ("daytime_clear_train",)
4 | TEST: ('daytime_clear_test',)
5 | DATALOADER:
6 | NUM_WORKERS: 16
7 | ASPECT_RATIO_GROUPING: True
8 | INPUT:
9 | MIN_SIZE_TRAIN: (600,)
10 | MIN_SIZE_TEST: 600
11 | CLIP_RANDOM_CROPS: True
12 | RANDOM_CROP_SIZE: 400
13 |
14 | SOLVER:
15 | BASE_LR: 0.001
16 | MAX_ITER: 100000
17 | STEPS: [40000,]
18 | WARMUP_ITERS: 0
19 | IMS_PER_BATCH: 4
20 | CHECKPOINT_PERIOD: 1000000
21 | MODEL:
22 | BACKBONE:
23 | NAME: ClipRN101
24 | FREEZE_AT: 2
25 | UNFREEZE:
26 | - layer3
27 | - layer4
28 | - attnpool
29 | WEIGHTS: ""
30 | CLIP_IMAGE_ENCODER_NAME: 'RN101'
31 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneTrainable'
32 |
33 | PROPOSAL_GENERATOR:
34 | NAME: 'SBRPN'
35 | ROI_HEADS:
36 | NAME: 'ClipRes5ROIHeadsAttn'
37 | NUM_CLASSES: 7
38 | TEST:
39 | EVAL_SAVE_PERIOD: 5000
40 |
41 | OUTPUT_DIR: "all_outs/diverse_weather"
42 | VIS_PERIOD: 5000
43 |
44 |
--------------------------------------------------------------------------------
/OD/CLIP/hubconf.py:
--------------------------------------------------------------------------------
1 | from clip.clip import tokenize as _tokenize, load as _load, available_models as _available_models
2 | import re
3 | import string
4 |
5 | dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
6 |
7 | # For compatibility (cannot include special characters in function name)
8 | model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
9 |
10 | def _create_hub_entrypoint(model):
11 | def entrypoint(**kwargs):
12 | return _load(model, **kwargs)
13 |
14 | entrypoint.__doc__ = f"""Loads the {model} CLIP model
15 |
16 | Parameters
17 | ----------
18 | device : Union[str, torch.device]
19 | The device to put the loaded model
20 |
21 | jit : bool
22 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
23 |
24 | download_root: str
25 | path to download the model files; by default, it uses "~/.cache/clip"
26 |
27 | Returns
28 | -------
29 | model : torch.nn.Module
30 | The {model} CLIP model
31 |
32 | preprocess : Callable[[PIL.Image], torch.Tensor]
33 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
34 | """
35 | return entrypoint
36 |
37 | def tokenize():
38 | return _tokenize
39 |
40 | _entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
41 |
42 | globals().update(_entrypoints)
--------------------------------------------------------------------------------
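Because punctuation in model names is replaced with underscores, `"ViT-B/32"` is exposed as the hub entrypoint `ViT_B_32`. A usage sketch, assuming the model is loaded from this local checkout:

```python
import torch

# source="local" makes torch.hub read hubconf.py from the given directory
model, preprocess = torch.hub.load("./OD/CLIP", "ViT_B_32", source="local", device="cpu")
tokenize = torch.hub.load("./OD/CLIP", "tokenize", source="local")

tokens = tokenize(["a diagram", "a dog", "a cat"])
print(tokens.shape)  # torch.Size([3, 77])
```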
/OD/modeling/custom_pascal_evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import tempfile
4 | from collections import OrderedDict, defaultdict
5 |
6 | from detectron2.utils import comm
7 | from detectron2.evaluation import PascalVOCDetectionEvaluator
8 | from detectron2.evaluation.pascal_voc_evaluation import voc_eval
9 |
10 | class CustomPascalVOCDetectionEvaluator(PascalVOCDetectionEvaluator):
11 | def __init__(self,dataset_name):
12 | super().__init__(dataset_name)
13 |
14 | def evaluate(self):
15 | """
16 | Returns:
17 | dict: has a key "bbox", whose value is a dict of "AP", "AP50", and "AP75".
18 | """
19 | all_predictions = comm.gather(self._predictions, dst=0)
20 | if not comm.is_main_process():
21 | return
22 | predictions = defaultdict(list)
23 | for predictions_per_rank in all_predictions:
24 | for clsid, lines in predictions_per_rank.items():
25 | predictions[clsid].extend(lines)
26 | del all_predictions
27 |
28 | self._logger.info(
29 | "Evaluating {} using {} metric. "
30 | "Note that results do not use the official Matlab API.".format(
31 | self._dataset_name, 2007 if self._is_2007 else 2012
32 | )
33 | )
34 |
35 |
36 | with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
37 | res_file_template = os.path.join(dirname, "{}.txt")
38 |
39 | aps = defaultdict(list) # iou -> ap per class
40 | for cls_id, cls_name in enumerate(self._class_names):
41 | lines = predictions.get(cls_id, [""])
42 |
43 | with open(res_file_template.format(cls_name), "w") as f:
44 | f.write("\n".join(lines))
45 |
46 | for thresh in range(50, 100, 5):
47 | rec, prec, ap = voc_eval(
48 | res_file_template,
49 | self._anno_file_template,
50 | self._image_set_path,
51 | cls_name,
52 | ovthresh=thresh / 100.0,
53 | use_07_metric=self._is_2007,
54 | )
55 | aps[thresh].append(ap * 100)
56 |
57 | ret = OrderedDict()
58 | mAP = {iou: np.mean(x) for iou, x in aps.items()}
59 |
60 | clsaps = ','.join(['{:.2f}'.format(a) for a in aps[50]])
61 | self._logger.info("classwise ap {}".format(clsaps))
62 |
63 | ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
64 | return ret
65 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/custom_pascal_evaluation-checkpoint.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import tempfile
4 | from collections import OrderedDict, defaultdict
5 |
6 | from detectron2.utils import comm
7 | from detectron2.evaluation import PascalVOCDetectionEvaluator
8 | from detectron2.evaluation.pascal_voc_evaluation import voc_eval
9 |
10 | class CustomPascalVOCDetectionEvaluator(PascalVOCDetectionEvaluator):
11 | def __init__(self,dataset_name):
12 | super().__init__(dataset_name)
13 |
14 | def evaluate(self):
15 | """
16 | Returns:
17 | dict: has a key "bbox", whose value is a dict of "AP", "AP50", and "AP75".
18 | """
19 | all_predictions = comm.gather(self._predictions, dst=0)
20 | if not comm.is_main_process():
21 | return
22 | predictions = defaultdict(list)
23 | for predictions_per_rank in all_predictions:
24 | for clsid, lines in predictions_per_rank.items():
25 | predictions[clsid].extend(lines)
26 | del all_predictions
27 |
28 | self._logger.info(
29 | "Evaluating {} using {} metric. "
30 | "Note that results do not use the official Matlab API.".format(
31 | self._dataset_name, 2007 if self._is_2007 else 2012
32 | )
33 | )
34 |
35 |
36 | with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
37 | res_file_template = os.path.join(dirname, "{}.txt")
38 |
39 | aps = defaultdict(list) # iou -> ap per class
40 | for cls_id, cls_name in enumerate(self._class_names):
41 | lines = predictions.get(cls_id, [""])
42 |
43 | with open(res_file_template.format(cls_name), "w") as f:
44 | f.write("\n".join(lines))
45 |
46 | for thresh in range(50, 100, 5):
47 | rec, prec, ap = voc_eval(
48 | res_file_template,
49 | self._anno_file_template,
50 | self._image_set_path,
51 | cls_name,
52 | ovthresh=thresh / 100.0,
53 | use_07_metric=self._is_2007,
54 | )
55 | aps[thresh].append(ap * 100)
56 |
57 | ret = OrderedDict()
58 | mAP = {iou: np.mean(x) for iou, x in aps.items()}
59 |
60 | clsaps = ','.join(['{:.2f}'.format(a) for a in aps[50]])
61 | self._logger.info("classwise ap {}".format(clsaps))
62 |
63 | ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
64 | return ret
65 |
--------------------------------------------------------------------------------
/OD/modeling/regularization.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class spar_loss(nn.Module):
8 | def __init__(self):
9 | super(spar_loss, self).__init__()
10 |
11 | def forward(self, flops_real, flops_mask, flops_ori, batch_size, den_target, lbda):
12 | # total sparsity
13 | flops_tensor, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2]
14 | # block flops
15 | flops_conv = flops_tensor[0:batch_size,:].mean(0).sum()
16 | flops_mask = flops_mask.mean(0).sum()
17 | flops_ori = flops_ori.mean(0).sum() + flops_conv1.mean() + flops_fc.mean()
18 | flops_real = flops_conv + flops_mask + flops_conv1.mean() + flops_fc.mean()
19 | # loss
20 | rloss = lbda * (flops_real / flops_ori - den_target)**2
21 | return rloss
22 |
23 |
24 | class blance_loss(nn.Module):
25 | def __init__(self):
26 | super(blance_loss, self).__init__()
27 |
28 | def forward(self, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size,
29 | den_target, gamma, p):
30 | norm_s = mask_norm_s
31 | norm_s_t = norm_s_t.mean(0)
32 | norm_c = mask_norm_c
33 | norm_c_t = norm_c_t.mean(0)
34 | den_s = norm_s[0:batch_size,:].mean(0) / norm_s_t
35 | den_c = norm_c[0:batch_size,:].mean(0) / norm_c_t
36 | den_tar = math.sqrt(den_target)
37 | bloss_s = get_bloss_basic(den_s, den_tar, batch_size, gamma, p)
38 | bloss_c = get_bloss_basic(den_c, den_tar, batch_size, gamma, p)
39 | bloss = bloss_s + bloss_c
40 | return bloss
41 |
42 |
43 | def get_bloss_basic(spar, spar_tar, batch_size, gamma, p):
44 | # bound
45 | bloss_l = (F.relu(p*spar_tar-spar)**2).mean()
46 | bloss_u = (F.relu(spar-1+p-p*spar_tar)**2).mean()
47 | bloss = gamma * (bloss_l + bloss_u)
48 | return bloss
49 |
50 |
51 | class Loss(nn.Module):
52 | def __init__(self):
53 | super(Loss, self).__init__()
54 | self.task_loss = nn.CrossEntropyLoss()
55 | self.spar_loss = spar_loss()
56 | self.balance_loss = blance_loss()
57 |
58 | def forward(self, output, targets, flops_real, flops_mask, flops_ori, batch_size,
59 | den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
60 | gamma, p):
61 | closs = self.task_loss(output, targets)
62 | sloss = self.spar_loss(flops_real, flops_mask, flops_ori, batch_size, den_target, lbda)
63 | bloss = self.balance_loss(mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size,
64 | den_target, gamma, p)
65 | return closs, sloss, bloss
66 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/regularization-checkpoint.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class spar_loss(nn.Module):
8 | def __init__(self):
9 | super(spar_loss, self).__init__()
10 |
11 | def forward(self, flops_real, flops_mask, flops_ori, batch_size, den_target, lbda):
12 | # total sparsity
13 | flops_tensor, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2]
14 | # block flops
15 | flops_conv = flops_tensor[0:batch_size,:].mean(0).sum()
16 | flops_mask = flops_mask.mean(0).sum()
17 | flops_ori = flops_ori.mean(0).sum() + flops_conv1.mean() + flops_fc.mean()
18 | flops_real = flops_conv + flops_mask + flops_conv1.mean() + flops_fc.mean()
19 | # loss
20 | rloss = lbda * (flops_real / flops_ori - den_target)**2
21 | return rloss
22 |
23 |
24 | class blance_loss(nn.Module):
25 | def __init__(self):
26 | super(blance_loss, self).__init__()
27 |
28 | def forward(self, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size,
29 | den_target, gamma, p):
30 | norm_s = mask_norm_s
31 | norm_s_t = norm_s_t.mean(0)
32 | norm_c = mask_norm_c
33 | norm_c_t = norm_c_t.mean(0)
34 | den_s = norm_s[0:batch_size,:].mean(0) / norm_s_t
35 | den_c = norm_c[0:batch_size,:].mean(0) / norm_c_t
36 | den_tar = math.sqrt(den_target)
37 | bloss_s = get_bloss_basic(den_s, den_tar, batch_size, gamma, p)
38 | bloss_c = get_bloss_basic(den_c, den_tar, batch_size, gamma, p)
39 | bloss = bloss_s + bloss_c
40 | return bloss
41 |
42 |
43 | def get_bloss_basic(spar, spar_tar, batch_size, gamma, p):
44 | # bound
45 | bloss_l = (F.relu(p*spar_tar-spar)**2).mean()
46 | bloss_u = (F.relu(spar-1+p-p*spar_tar)**2).mean()
47 | bloss = gamma * (bloss_l + bloss_u)
48 | return bloss
49 |
50 |
51 | class Loss(nn.Module):
52 | def __init__(self):
53 | super(Loss, self).__init__()
54 | self.task_loss = nn.CrossEntropyLoss()
55 | self.spar_loss = spar_loss()
56 | self.balance_loss = blance_loss()
57 |
58 | def forward(self, output, targets, flops_real, flops_mask, flops_ori, batch_size,
59 | den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
60 | gamma, p):
61 | closs = self.task_loss(output, targets)
62 | sloss = self.spar_loss(flops_real, flops_mask, flops_ori, batch_size, den_target, lbda)
63 | bloss = self.balance_loss(mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size,
64 | den_target, gamma, p)
65 | return closs, sloss, bloss
66 |
--------------------------------------------------------------------------------
/OD/modeling/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torchvision.transforms as T
6 |
7 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
8 | from ipdb import set_trace as stxx
9 |
10 | @BACKBONE_REGISTRY.register()
11 | class ClipRN101(Backbone):
12 | def __init__(self, cfg, clip_visual):
13 | super().__init__()
14 | self.enc = None
15 | self.unfreeze = cfg.MODEL.BACKBONE.UNFREEZE
16 | self.proj = nn.Linear(512,512)
17 | self.global_proj = nn.Linear(512,512)
18 | self.use_proj = cfg.MODEL.USE_PROJ
19 |
20 |
21 | def set_backbone_model(self,model):
22 | self.enc = model
23 | for name,val in self.enc.named_parameters():
24 | head = name.split('.')[0]
25 | if head not in self.unfreeze:
26 | val.requires_grad = False
27 | else:
28 | val.requires_grad = True
29 |
30 | self.backbone_unchanged = nn.Sequential(*self.enc.layer3[:19])
31 |
32 | def forward(self, image):
33 |
34 | x = image
35 |
36 | batch_num, _, _, _ = x.shape
37 |
38 |
39 | gate_activations = []
40 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x)))
41 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x)))
42 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x)))
43 | x = self.enc.avgpool(x)
44 |
45 | norm1 = torch.zeros(1, batch_num+1).to(x.device)
46 | norm2 = torch.zeros(1, batch_num+1).to(x.device)
47 | flops = torch.zeros(1, batch_num+2).to(x.device)
48 |
49 | x = self.enc.layer1(x)
50 | x = self.enc.layer2(x)
51 | x = self.enc.layer3(x)
52 | return {"res4": x}
53 |
54 |
55 | def forward_l12(self, image):
56 | x = image
57 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x)))
58 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x)))
59 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x)))
60 | x = self.enc.avgpool(x)
61 |
62 | x = self.enc.layer1(x)
63 | x = self.enc.layer2(x)
64 |
65 | return x
66 |
67 | def forward_l3(self, x):
68 | x = self.enc.layer3(x)
69 | return {"res4": x}
70 |
71 |
72 | def output_shape(self):
73 | return {"res4": ShapeSpec(channels=1024, stride=16)}
74 |
75 | # def forward_res5(self,x):
76 | # def forward_res5(self, x, norm1, norm2, flops):
77 | # #detectron used last resnet layer for roi heads
78 | # x, norm1, norm2, flops = self.enc.layer4((x, norm1, norm2, flops))
79 | # # x = self.enc.layer4(x)
80 | # return x
81 |
82 |
83 |
84 | def forward_res5(self, x, txt_emb, norm1, norm2, flops):
85 | #detectron used last resnet layer for roi heads
86 | x, txt_emb, norm1, norm2, flops = self.enc.layer4((x, txt_emb, norm1, norm2, flops))
87 | return x
88 |
89 |
90 |
91 | def attention_global_pool(self,input):
92 | x = input
93 | x = self.enc.attnpool(x)
94 | return x
95 |
96 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/backbone-checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import torchvision.transforms as T
6 |
7 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
8 | from ipdb import set_trace as stxx
9 |
10 | @BACKBONE_REGISTRY.register()
11 | class ClipRN101(Backbone):
12 | def __init__(self, cfg, clip_visual):
13 | super().__init__()
14 | self.enc = None
15 | self.unfreeze = cfg.MODEL.BACKBONE.UNFREEZE
16 | self.proj = nn.Linear(512,512)
17 | self.global_proj = nn.Linear(512,512)
18 | self.use_proj = cfg.MODEL.USE_PROJ
19 |
20 |
21 | def set_backbone_model(self,model):
22 | self.enc = model
23 | for name,val in self.enc.named_parameters():
24 | head = name.split('.')[0]
25 | if head not in self.unfreeze:
26 | val.requires_grad = False
27 | else:
28 | val.requires_grad = True
29 |
30 | self.backbone_unchanged = nn.Sequential(*self.enc.layer3[:19])
31 |
32 | def forward(self, image):
33 |
34 | x = image
35 |
36 | batch_num, _, _, _ = x.shape
37 |
38 |
39 | gate_activations = []
40 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x)))
41 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x)))
42 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x)))
43 | x = self.enc.avgpool(x)
44 |
45 | norm1 = torch.zeros(1, batch_num+1).to(x.device)
46 | norm2 = torch.zeros(1, batch_num+1).to(x.device)
47 | flops = torch.zeros(1, batch_num+2).to(x.device)
48 |
49 | x = self.enc.layer1(x)
50 | x = self.enc.layer2(x)
51 | x = self.enc.layer3(x)
52 | return {"res4": x}
53 |
54 |
55 | def forward_l12(self, image):
56 | x = image
57 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x)))
58 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x)))
59 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x)))
60 | x = self.enc.avgpool(x)
61 |
62 | x = self.enc.layer1(x)
63 | x = self.enc.layer2(x)
64 |
65 | return x
66 |
67 | def forward_l3(self, x):
68 | x = self.enc.layer3(x)
69 | return {"res4": x}
70 |
71 |
72 | def output_shape(self):
73 | return {"res4": ShapeSpec(channels=1024, stride=16)}
74 |
75 | # def forward_res5(self,x):
76 | # def forward_res5(self, x, norm1, norm2, flops):
77 | # #detectron used last resnet layer for roi heads
78 | # x, norm1, norm2, flops = self.enc.layer4((x, norm1, norm2, flops))
79 | # # x = self.enc.layer4(x)
80 | # return x
81 |
82 |
83 |
84 | def forward_res5(self, x, txt_emb, norm1, norm2, flops):
85 | #detectron used last resnet layer for roi heads
86 | x, txt_emb, norm1, norm2, flops = self.enc.layer4((x, txt_emb, norm1, norm2, flops))
87 | return x
88 |
89 |
90 |
91 | def attention_global_pool(self,input):
92 | x = input
93 | x = self.enc.attnpool(x)
94 | return x
95 |
96 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt-Driven Dynamic Object-Centric Learning for Single Domain Generalization
2 | This repo is the implementation of Prompt-Driven Dynamic Object-Centric Learning for Single Domain Generalization (CVPR 2024).
3 |
4 | > **Abstract:** *Single-domain generalization aims to learn a model from single-source domain data to achieve generalized performance on other unseen target domains. Existing works primarily focus on improving the generalization ability of static networks. However, static networks are unable to dynamically adapt to the diverse variations in different image scenes, leading to limited generalization capability. Different scenes exhibit varying levels of complexity, and the complexity of images further varies significantly in cross-domain scenarios. In this paper, we propose a dynamic object-centric perception network based on prompt learning, aiming to adapt to the variations in image complexity. Specifically, we propose an object-centric gating module based on prompt learning to focus attention on the object-centric features guided by the various scene prompts. Then, with the object-centric gating masks, the dynamic selective module dynamically selects highly correlated feature regions in both spatial and channel dimensions, enabling the model to adaptively perceive object-centric relevant features, thereby enhancing the generalization capability. Extensive experiments were conducted on single-domain generalization tasks in image classification and object detection. The experimental results demonstrate that our approach outperforms state-of-the-art methods, which validates the effectiveness and generality of our proposed method.*
5 |
6 | ## 1. Illustration of dynamic object-centric learning via prompts for single domain generalization.
7 |
8 |
9 |
10 |
11 |
12 | Object-centric features capture the essential information related to individual objects.
13 | Incorporating the given scene prompts to dynamically optimize the extraction of object-centric features is beneficial for improving the generalization performance of models.
14 |
15 | ## 2. Method
16 |
17 |
18 |
19 |
20 |
21 | The proposed prompt-based dynamic object-centric learning framework mainly includes a prompt-based object-centric gating module and a dynamic selective module. First, the Slot Attention multimodal fusion module extracts object-centric features and leverages various scene prompts to guide object-centric gating mask learning for inputs from different scenes. Next, the gating mask is used to dynamically select the relevant object-centric features to improve the
22 | generalization ability.
23 |
24 | ## 3. Usage
25 | ### 3.1 Prepare data
26 | #### Image Classification: PACS (Art Paintings, Cartoons, Photos, and Sketches)
27 | #### Object Detection: Diverse-Weather Dataset (Daytime-Sunny, Night-Sunny, Dusk-Rainy, Night-Rainy, and Daytime-Foggy)
28 |
29 |
30 |
31 | ### 3.2 Dependencies
32 |
33 | - Python: 3.8.10
34 | - PyTorch: 1.9.1
35 | - Pillow: 9.5.0
36 | - Torchvision: 0.8.2
37 | - CUDA: 11.8
38 | - NumPy: 1.22.4
39 | - PIL: 7.2.0
40 | - clip: 1.0
41 | - detectron2: 0.6
42 |
43 | ### 3.3 Train and Test
44 |
45 | - Train on source domain
46 | ```
47 | CUDA_VISIBLE_DEVICES=0 python train.py --config-file configs/diverse_weather.yaml
48 | ```
49 |
50 | - Test on target domain (Daytime-Foggy)
51 |
52 | ```
53 | python train.py --config-file configs/diverse_weather_foggy_test.yaml --eval-only MODEL.WEIGHTS all_outs/diverse_weather/model_best.pth > diverse_weather_foggy_test.log
54 | ```
55 |
56 | ## Acknowledgement
57 | Our code is based on the project [Detectron2](https://github.com/facebookresearch/detectron2).
58 |
--------------------------------------------------------------------------------
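The prompt-driven gating idea described in Section 2 of the README can be pictured with a toy snippet. This is an illustration only, not the repository's implementation (that lives in `OD/modeling/`); the shapes, temperature, and use of plain cosine similarity are all made up for clarity:

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 512, 25, 25
feat = torch.randn(B, C, H, W)   # visual feature map (e.g. from a CLIP backbone)
prompt_emb = torch.randn(B, C)   # embedding of a scene prompt, e.g. "an image taken on a fog night"

# cosine similarity between the prompt and every spatial location -> soft gating mask in (0, 1)
f = F.normalize(feat, dim=1)
p = F.normalize(prompt_emb, dim=1)[:, :, None, None]
mask = torch.sigmoid((f * p).sum(dim=1, keepdim=True) / 0.07)

gated = feat * mask              # re-weighted, "object-centric" features
print(mask.shape, gated.shape)   # torch.Size([2, 1, 25, 25]) torch.Size([2, 512, 25, 25])
```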
/OD/data/datasets/pascal_voc_adaptation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) Facebook, Inc. and its affiliates.
3 |
4 | import numpy as np
5 | import os
6 | import xml.etree.ElementTree as ET
7 | from typing import List, Tuple, Union
8 |
9 | from detectron2.data import DatasetCatalog, MetadataCatalog
10 | from detectron2.structures import BoxMode
11 | from detectron2.utils.file_io import PathManager
12 |
13 | __all__ = ["load_voc_instances", "register_all_pascal_voc"]
14 |
15 |
16 | # fmt: off
17 | CLASS_NAMES = (
18 | "bicycle", "bird", "car", "cat", "dog", "person",
19 | )
20 | # fmt: on
21 |
22 |
23 | def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]):
24 | """
25 | Load Pascal VOC detection annotations to Detectron2 format.
26 | Args:
27 | dirname: Contain "Annotations", "ImageSets", "JPEGImages"
28 | split (str): one of "train", "test", "val", "trainval"
29 | class_names: list or tuple of class names
30 | """
31 | with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
32 | fileids = np.loadtxt(f, dtype=str)
33 |
34 | # Needs to read many small annotation files. Makes sense at local
35 | annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
36 | dicts = []
37 | for fileid in fileids:
38 | anno_file = os.path.join(annotation_dirname, fileid + ".xml")
39 | jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg")
40 |
41 | with PathManager.open(anno_file) as f:
42 | tree = ET.parse(f)
43 |
44 | r = {
45 | "file_name": jpeg_file,
46 | "image_id": fileid,
47 | "height": int(tree.findall("./size/height")[0].text),
48 | "width": int(tree.findall("./size/width")[0].text),
49 | }
50 | instances = []
51 |
52 | for obj in tree.findall("object"):
53 | cls = obj.find("name").text
54 | if cls not in CLASS_NAMES:
55 | continue
56 | # We include "difficult" samples in training.
57 | # Based on limited experiments, they don't hurt accuracy.
58 | # difficult = int(obj.find("difficult").text)
59 | # if difficult == 1:
60 | # continue
61 | bbox = obj.find("bndbox")
62 | bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
63 | # Original annotations are integers in the range [1, W or H]
64 | # Assuming they mean 1-based pixel indices (inclusive),
65 | # a box with annotation (xmin=1, xmax=W) covers the whole image.
66 | # In coordinate space this is represented by (xmin=0, xmax=W)
67 | bbox[0] -= 1.0
68 | bbox[1] -= 1.0
69 | instances.append(
70 | {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
71 | )
72 | r["annotations"] = instances
73 | dicts.append(r)
74 | return dicts
75 |
76 |
77 | def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
78 | DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
79 | MetadataCatalog.get(name).set(
80 | thing_classes=list(class_names), dirname=dirname, year=year, split=split
81 | )
82 |
83 | def register_all_pascal_voc(root):
84 | SPLITS = [
85 | ("voc_adapt_2007_trainval", "VOC2007", "trainval"),
86 | ("voc_adapt_2007_train", "VOC2007", "train"),
87 | ("voc_adapt_2007_val", "VOC2007", "val"),
88 | ("voc_adapt_2007_test", "VOC2007", "test"),
89 | ("voc_adapt_2012_trainval", "VOC2012", "trainval"),
90 | ("voc_adapt_2012_train", "VOC2012", "train"),
91 | ("voc_adapt_2012_val", "VOC2012", "val"),
92 | ]
93 | for name, dirname, split in SPLITS:
94 | year = 2007 if "2007" in name else 2012
95 | register_pascal_voc(name, os.path.join(root, dirname), split, year)
96 | MetadataCatalog.get(name).evaluator_type = "pascal_voc"
--------------------------------------------------------------------------------
/OD/modeling/rpn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY, RPN_HEAD_REGISTRY
6 | from detectron2.modeling.proposal_generator.rpn import RPN, StandardRPNHead
7 | from detectron2.structures import ImageList
8 | from typing import List
9 | import time
10 |
11 |
12 | @PROPOSAL_GENERATOR_REGISTRY.register()
13 | class SBRPN(RPN):
14 |
15 | def forward(
16 | self,
17 | images,
18 | features,
19 | gt_instances= None,
20 | ):
21 | """
22 | Args:
23 | images (ImageList): input images of length `N`
24 | features (dict[str, Tensor]): input data as a mapping from feature
25 | map name to tensor. Axis 0 represents the number of images `N` in
26 | the input data; axes 1-3 are channels, height, and width, which may
27 | vary between feature maps (e.g., if a feature pyramid is used).
28 | gt_instances (list[Instances], optional): a length `N` list of `Instances`s.
29 | Each `Instances` stores ground-truth instances for the corresponding image.
30 | Returns:
31 | proposals: list[Instances]: contains fields "proposal_boxes", "objectness_logits"
32 | loss: dict[Tensor] or None
33 | """
34 | features = [features[f] for f in self.in_features]
35 | anchors = self.anchor_generator(features)
36 | val = self.rpn_head(features)
37 |
38 | if len(val) == 2:
39 | pred_objectness_logits, pred_anchor_deltas = val
40 | rep_clip = None
41 | else:
42 | pred_objectness_logits, pred_anchor_deltas, rep_clip = val
43 |
44 | # Transpose the Hi*Wi*A dimension to the middle:
45 | pred_objectness_logits= [
46 | # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
47 | score.permute(0, 2, 3, 1).flatten(1)
48 | for score in pred_objectness_logits
49 | ]
50 |
51 |
52 | pred_anchor_deltas = [
53 | # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B)
54 | x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1])
55 | .permute(0, 3, 4, 1, 2)
56 | .flatten(1, -2)
57 | for x in pred_anchor_deltas
58 | ]
59 |
60 | if self.training:
61 | #assert gt_instances is not None, "RPN requires gt_instances in training!"
62 | if gt_instances is not None:
63 | gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
64 | losses = self.losses(
65 | anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
66 | )
67 |                 if rep_clip is not None:
68 |                     if not isinstance(rep_clip, dict):
69 |                         ll = torch.stack(gt_labels)
70 |                         ll = ll.reshape(ll.shape[0], -1, 15)  # (N, Hi*Wi*A) -> (N, Hi*Wi, A) with A = 15 anchors per location
71 |                         valid_mask = ll.ge(0).sum(dim=-1).gt(0)  # keep locations with at least one non-ignored anchor
72 |                         ll = ll.eq(1).sum(dim=-1)  # number of positive anchors at each location
73 |                         ll = ll.gt(0).float()  # 1 if any object is present at this location
74 | 
75 |                         clip_loss = torch.nn.functional.binary_cross_entropy_with_logits(rep_clip[valid_mask], ll[valid_mask], reduction='sum')
76 |                         losses.update({'loss_rpn_cls_clip': clip_loss / (self.batch_size_per_image * ll.shape[0])})
77 | else:
78 | losses.update(rep_clip)
79 | else:
80 | losses = {}
81 | else:
82 | losses = {}
83 | proposals = self.predict_proposals(
84 | anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
85 | )
86 | out = [
87 | # (N, Hi*Wi*A) -> (N, Hi, Wi, A)
88 | score.reshape(features[ind].shape[0],features[ind].shape[-2],features[ind].shape[-1],-1)
89 | for ind, score in enumerate(pred_objectness_logits)
90 | ]
91 | return out, proposals, losses
92 |
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/rpn-checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY, RPN_HEAD_REGISTRY
6 | from detectron2.modeling.proposal_generator.rpn import RPN, StandardRPNHead
7 | from detectron2.structures import ImageList
8 | from typing import List
9 | import time
10 |
11 |
12 | @PROPOSAL_GENERATOR_REGISTRY.register()
13 | class SBRPN(RPN):
14 |
15 | def forward(
16 | self,
17 | images,
18 | features,
19 | gt_instances= None,
20 | ):
21 | """
22 | Args:
23 | images (ImageList): input images of length `N`
24 | features (dict[str, Tensor]): input data as a mapping from feature
25 | map name to tensor. Axis 0 represents the number of images `N` in
26 | the input data; axes 1-3 are channels, height, and width, which may
27 | vary between feature maps (e.g., if a feature pyramid is used).
28 | gt_instances (list[Instances], optional): a length `N` list of `Instances`s.
29 | Each `Instances` stores ground-truth instances for the corresponding image.
30 | Returns:
31 | proposals: list[Instances]: contains fields "proposal_boxes", "objectness_logits"
32 | loss: dict[Tensor] or None
33 | """
34 | features = [features[f] for f in self.in_features]
35 | anchors = self.anchor_generator(features)
36 | val = self.rpn_head(features)
37 |
38 | if len(val) == 2:
39 | pred_objectness_logits, pred_anchor_deltas = val
40 | rep_clip = None
41 | else:
42 | pred_objectness_logits, pred_anchor_deltas, rep_clip = val
43 |
44 | # Transpose the Hi*Wi*A dimension to the middle:
45 | pred_objectness_logits= [
46 | # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
47 | score.permute(0, 2, 3, 1).flatten(1)
48 | for score in pred_objectness_logits
49 | ]
50 |
51 |
52 | pred_anchor_deltas = [
53 | # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B)
54 | x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1])
55 | .permute(0, 3, 4, 1, 2)
56 | .flatten(1, -2)
57 | for x in pred_anchor_deltas
58 | ]
59 |
60 | if self.training:
61 | #assert gt_instances is not None, "RPN requires gt_instances in training!"
62 | if gt_instances is not None:
63 | gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
64 | losses = self.losses(
65 | anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
66 | )
67 | if rep_clip is not None :
68 | if not isinstance(rep_clip, dict):
69 | ll = torch.stack(gt_labels)
70 | ll = ll.reshape(ll.shape[0],-1,15)
71 | valid_mask = ll.ge(0).sum(dim=-1).gt(0) # remove ignored anchors
72 | ll=ll.eq(1).sum(dim=-1) # if an object is present at this location
73 | ll = ll.gt(0).float()
74 |
75 | clip_loss = torch.nn.functional.binary_cross_entropy_with_logits(rep_clip[valid_mask],ll[valid_mask],reduction='sum')
76 | losses.update({'loss_rpn_cls_clip':clip_loss/(self.batch_size_per_image*ll.shape[0])})
77 | else:
78 | losses.update(rep_clip)
79 | else:
80 | losses = {}
81 | else:
82 | losses = {}
83 | proposals = self.predict_proposals(
84 | anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
85 | )
86 | out = [
87 | # (N, Hi*Wi*A) -> (N, Hi, Wi, A)
88 | score.reshape(features[ind].shape[0],features[ind].shape[-2],features[ind].shape[-1],-1)
89 | for ind, score in enumerate(pred_objectness_logits)
90 | ]
91 | return out, proposals, losses
92 |
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/OD/modeling/clip.py:
--------------------------------------------------------------------------------
1 | import clip
2 |
3 | import torch
4 | import torch.nn as nn
5 | import time
6 | import numpy as np
7 | import copy
8 |
9 | class SlotAttention(nn.Module):
10 | def __init__(self, dim=768, iters=3, eps=1e-8, hidden_dim=512, drop_rate=0.4, feature_size=512):
11 | super().__init__()
12 | self.iters = iters
13 | self.eps = eps
14 | self.scale = dim ** -0.5
15 | self.feature_size = feature_size
16 |
17 | self.to_q = nn.Linear(dim, dim)
18 | slot_share_qk = False
19 | if slot_share_qk:
20 | self.to_k = self.to_q
21 | else:
22 | self.to_k = nn.Linear(dim, dim)
23 |
24 | self.to_v = nn.Linear(feature_size, feature_size)
25 |
26 | hidden_dim = max(dim, hidden_dim, feature_size)
27 |
28 | self.gru = nn.GRUCell(feature_size, feature_size)
29 | self.mlp = nn.Sequential(
30 | nn.Linear(feature_size, hidden_dim),
31 | nn.ReLU(inplace=True),
32 | nn.Linear(hidden_dim, feature_size)
33 | )
34 |
35 | self.norm_slots = nn.LayerNorm(feature_size)
36 | self.norm_pre_ff = nn.LayerNorm(feature_size)
37 | self.norm_input = nn.LayerNorm(feature_size)
38 |
39 | self.slot_dropout = nn.Dropout(drop_rate)
40 | self.input_dropout = nn.Dropout(drop_rate)
41 |
42 | def forward(self,cand_feat, pano_feat):
43 |
44 | b, d, device = *pano_feat.shape, pano_feat.device
45 | # original cand_feat as the initial slot
46 | slots = cand_feat.clone()
47 | slots = self.slot_dropout(slots)
48 | pano_feat = self.norm_input(pano_feat.clone())
49 | pano_feat = self.input_dropout(pano_feat)
50 | # (bs, num_ctx, hidden_size)
51 | k = self.to_k(slots)
52 | v = self.to_v(slots)
53 | attn_weights = []
54 | for t in range(self.iters):
55 | slots_prev = slots
56 | slots = self.norm_slots(slots.clone())
57 | # (bs, num_slots, hidden_size)
58 | q = self.to_q(pano_feat.clone())
59 | # (bs, num_slots, num_ctx)
60 | dots = torch.einsum('id,jd->ijd', k, q) * self.scale
61 |
62 | attn = dots.softmax(dim=1)
63 | attn_weights.append(attn) # for visualization
64 | # (bs, num_slots, feature_size)
65 | updates = torch.einsum('id,ijd->id', v, attn)
66 | gru_updates = self.gru(
67 | updates.reshape(-1, self.feature_size),
68 | slots_prev.clone().reshape(-1, self.feature_size)
69 | )
70 | gru_updates = gru_updates + self.mlp(self.norm_pre_ff(gru_updates))
71 | slots = gru_updates.clone()
72 | return slots
73 |
74 |
75 | class ClipPredictor(nn.Module):
76 | def __init__(self, clip_enocder_name,inshape, device, clsnames):
77 | super().__init__()
78 | self.model, self.preprocess = clip.load(clip_enocder_name, device)
79 | self.model.float()
80 | #freeze everything
81 | for name, val in self.model.named_parameters():
82 | val.requires_grad = False
83 | # this is only used for inference
84 | self.frozen_clip_model = copy.deepcopy(self.model)
85 |
86 | self.visual_enc = self.model.visual
87 | prompt = 'a photo of a {}'
88 | print(clsnames)
89 | with torch.no_grad():
90 | text_inputs = torch.cat([clip.tokenize(prompt.format(cls)) for cls in clsnames]).to(device)
91 | self.text_features = self.model.encode_text(text_inputs).float()
92 | self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
93 |
94 |
95 | self.projection = nn.Linear(inshape,512)
96 | self.projection_global = nn.Linear(inshape,512)
97 |
98 | self.slot_attention = SlotAttention(
99 | dim=512,
100 | iters=3,
101 | drop_rate=0,
102 | )
103 |
104 | def forward(self, feat, gfeat=None):
105 |
106 | if feat.shape[-1] > 512:
107 | feat = self.projection(feat)
108 | feat = 0.5* feat + 0.5* self.slot_attention(feat,self.text_features.detach())
109 | feat = feat/feat.norm(dim=-1,keepdim=True)
110 | if gfeat is not None:
111 |
112 | feat = feat-gfeat
113 | feat = feat/feat.norm(dim=-1,keepdim=True)
114 | scores = (100.0 * torch.matmul(feat,self.text_features.detach().T))
115 |
116 | # print(scores.min(),scores.max())
117 |         # append a zero logit for the background class
118 | scores = torch.cat([scores,torch.zeros(scores.shape[0],1,device=scores.device)],1)
119 | return scores
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/clip-checkpoint.py:
--------------------------------------------------------------------------------
1 | import clip
2 |
3 | import torch
4 | import torch.nn as nn
5 | import time
6 | import numpy as np
7 | import copy
8 |
9 | class SlotAttention(nn.Module):
10 | def __init__(self, dim=768, iters=3, eps=1e-8, hidden_dim=512, drop_rate=0.4, feature_size=512):
11 | super().__init__()
12 | self.iters = iters
13 | self.eps = eps
14 | self.scale = dim ** -0.5
15 | self.feature_size = feature_size
16 |
17 | self.to_q = nn.Linear(dim, dim)
18 | slot_share_qk = False
19 | if slot_share_qk:
20 | self.to_k = self.to_q
21 | else:
22 | self.to_k = nn.Linear(dim, dim)
23 |
24 | self.to_v = nn.Linear(feature_size, feature_size)
25 |
26 | hidden_dim = max(dim, hidden_dim, feature_size)
27 |
28 | self.gru = nn.GRUCell(feature_size, feature_size)
29 | self.mlp = nn.Sequential(
30 | nn.Linear(feature_size, hidden_dim),
31 | nn.ReLU(inplace=True),
32 | nn.Linear(hidden_dim, feature_size)
33 | )
34 |
35 | self.norm_slots = nn.LayerNorm(feature_size)
36 | self.norm_pre_ff = nn.LayerNorm(feature_size)
37 | self.norm_input = nn.LayerNorm(feature_size)
38 |
39 | self.slot_dropout = nn.Dropout(drop_rate)
40 | self.input_dropout = nn.Dropout(drop_rate)
41 |
42 | def forward(self,cand_feat, pano_feat):
43 |
44 | b, d, device = *pano_feat.shape, pano_feat.device
45 | # original cand_feat as the initial slot
46 | slots = cand_feat.clone()
47 | slots = self.slot_dropout(slots)
48 | pano_feat = self.norm_input(pano_feat.clone())
49 | pano_feat = self.input_dropout(pano_feat)
50 | # (bs, num_ctx, hidden_size)
51 | k = self.to_k(slots)
52 | v = self.to_v(slots)
53 | attn_weights = []
54 | for t in range(self.iters):
55 | slots_prev = slots
56 | slots = self.norm_slots(slots.clone())
57 | # (bs, num_slots, hidden_size)
58 | q = self.to_q(pano_feat.clone())
59 | # (bs, num_slots, num_ctx)
60 | dots = torch.einsum('id,jd->ijd', k, q) * self.scale
61 |
62 | attn = dots.softmax(dim=1)
63 | attn_weights.append(attn) # for visualization
64 | # (bs, num_slots, feature_size)
65 | updates = torch.einsum('id,ijd->id', v, attn)
66 | gru_updates = self.gru(
67 | updates.reshape(-1, self.feature_size),
68 | slots_prev.clone().reshape(-1, self.feature_size)
69 | )
70 | gru_updates = gru_updates + self.mlp(self.norm_pre_ff(gru_updates))
71 | slots = gru_updates.clone()
72 | return slots
73 |
74 |
75 | class ClipPredictor(nn.Module):
76 | def __init__(self, clip_enocder_name,inshape, device, clsnames):
77 | super().__init__()
78 | self.model, self.preprocess = clip.load(clip_enocder_name, device)
79 | self.model.float()
80 | #freeze everything
81 | for name, val in self.model.named_parameters():
82 | val.requires_grad = False
83 | # this is only used for inference
84 | self.frozen_clip_model = copy.deepcopy(self.model)
85 |
86 | self.visual_enc = self.model.visual
87 | prompt = 'a photo of a {}'
88 | print(clsnames)
89 | with torch.no_grad():
90 | text_inputs = torch.cat([clip.tokenize(prompt.format(cls)) for cls in clsnames]).to(device)
91 | self.text_features = self.model.encode_text(text_inputs).float()
92 | self.text_features /= self.text_features.norm(dim=-1, keepdim=True)
93 |
94 |
95 | self.projection = nn.Linear(inshape,512)
96 | self.projection_global = nn.Linear(inshape,512)
97 |
98 | self.slot_attention = SlotAttention(
99 | dim=512,
100 | iters=3,
101 | drop_rate=0,
102 | )
103 |
104 | def forward(self, feat, gfeat=None):
105 |
106 | if feat.shape[-1] > 512:
107 | feat = self.projection(feat)
108 | feat = 0.5* feat + 0.5* self.slot_attention(feat,self.text_features.detach())
109 | feat = feat/feat.norm(dim=-1,keepdim=True)
110 | if gfeat is not None:
111 |
112 | feat = feat-gfeat
113 | feat = feat/feat.norm(dim=-1,keepdim=True)
114 | scores = (100.0 * torch.matmul(feat,self.text_features.detach().T))
115 |
116 | # print(scores.min(),scores.max())
117 | # add for bkg class a score 0
118 | scores = torch.cat([scores,torch.zeros(scores.shape[0],1,device=scores.device)],1)
119 | return scores
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/OD/data/datasets/diverse_weather.py:
--------------------------------------------------------------------------------
1 | import os
2 | import errno
3 |
4 | from tqdm import tqdm
5 | import pickle as pkl
6 | import xml.etree.ElementTree as ET
7 |
8 | import cv2
9 | import numpy as np
10 | from pymage_size import get_image_size
11 |
12 | from detectron2.data import DatasetCatalog, MetadataCatalog
13 | from detectron2.structures import BoxMode
14 |
15 | all_class_name = ['bus' ,'bike', 'car', 'motor', 'person', 'rider' ,'truck']
16 |
17 | def get_annotation(root, image_id, ind):
18 | annotation_file = os.path.join(root,'VOC2007', "Annotations", "%s.xml" % image_id)
19 |
20 | et = ET.parse(annotation_file)
21 |
22 |
23 | objects = et.findall("object")
24 |
25 | record = {}
26 | record["file_name"] = os.path.join(root, 'VOC2007', "JPEGImages", "%s.jpg" % image_id)
27 | img_format = get_image_size(record["file_name"])
28 | w, h = img_format.get_dimensions()
29 |
30 |     record["image_id"] = image_id  # the pascal_voc evaluator needs the actual image name rather than the running index
31 | record["annotations"] = []
32 |
33 | for obj in objects:
34 | class_name = obj.find('name').text.lower().strip()
35 | if class_name not in all_class_name:
36 | print(class_name)
37 | continue
38 | if obj.find('pose') is None:
39 | obj.append(ET.Element('pose'))
40 | obj.find('pose').text = '0'
41 |
42 | if obj.find('truncated') is None:
43 | obj.append(ET.Element('truncated'))
44 | obj.find('truncated').text = '0'
45 |
46 | if obj.find('difficult') is None:
47 | obj.append(ET.Element('difficult'))
48 | obj.find('difficult').text = '0'
49 |
50 | bbox = obj.find('bndbox')
51 |         # VOC annotations follow MATLAB, in which indices start from 1, so convert to 0-based
52 | x1 = max(0,float(bbox.find('xmin').text) - 1) # fixing when -1 in anno
53 | y1 = max(0,float(bbox.find('ymin').text) - 1) # fixing when -1 in anno
54 | x2 = float(bbox.find('xmax').text) - 1
55 | y2 = float(bbox.find('ymax').text) - 1
56 | box = [x1, y1, x2, y2]
57 |
58 | #pascal voc evaluator requires int
59 | bbox.find('xmin').text = str(int(x1))
60 | bbox.find('ymin').text = str(int(y1))
61 | bbox.find('xmax').text = str(int(x2))
62 | bbox.find('ymax').text = str(int(y2))
63 |
64 |
65 | record_obj = {
66 | "bbox": box,
67 | "bbox_mode": BoxMode.XYXY_ABS,
68 | "category_id": all_class_name.index(class_name),
69 | }
70 | record["annotations"].append(record_obj)
71 |
72 | if len(record["annotations"]):
73 | #to convert float to int
74 | et.write(annotation_file)
75 | record["height"] = h
76 | record["width"] = w
77 | return record
78 |
79 | else:
80 | return None
81 |
82 | def files2dict(root,split):
83 |
84 | cache_dir = os.path.join(root, 'cache')
85 |
86 | pkl_filename = os.path.basename(root)+f'_{split}.pkl'
87 | pkl_path = os.path.join(cache_dir,pkl_filename)
88 |
89 | if os.path.exists(pkl_path):
90 | with open(pkl_path,'rb') as f:
91 | return pkl.load(f)
92 | else:
93 | try:
94 | os.makedirs(cache_dir)
95 | except OSError as e:
96 | if e.errno != errno.EEXIST:
97 | print(e)
98 | pass
99 |
100 | dataset_dicts = []
101 | image_sets_file = os.path.join( root,'VOC2007', "ImageSets", "Main", "%s.txt" % split)
102 |
103 | with open(image_sets_file) as f:
104 | count = 0
105 |
106 | for line in tqdm(f):
107 | record = get_annotation(root,line.rstrip(),count)
108 |
109 | if record is not None:
110 | dataset_dicts.append(record)
111 | count +=1
112 |
113 | with open(pkl_path, 'wb') as f:
114 | pkl.dump(dataset_dicts,f)
115 | return dataset_dicts
116 |
117 |
118 | def register_dataset(datasets_root):
119 | datasets_root = os.path.join(datasets_root,'diverseWeather')
120 | dataset_list = ['daytime_clear',
121 | 'daytime_foggy',
122 | 'night_sunny',
123 | 'night_rainy',
124 | 'dusk_rainy',
125 | ]
126 | settype = ['train','test']
127 |
128 | for name in dataset_list:
129 | for ind, d in enumerate(settype):
130 |
131 | DatasetCatalog.register(name+"_" + d, lambda datasets_root=datasets_root,name=name,d=d \
132 | : files2dict(os.path.join(datasets_root,name), d))
133 | MetadataCatalog.get(name+ "_" + d).set(thing_classes=all_class_name,evaluator_type='pascal_voc')
134 | MetadataCatalog.get(name+ "_" + d).set(dirname=datasets_root+f'/{name}/VOC2007')
135 | MetadataCatalog.get(name+ "_" + d).set(split=d)
136 | MetadataCatalog.get(name+ "_" + d).set(year=2007)
--------------------------------------------------------------------------------
/OD/CLIP/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a signficant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/OD/CLIP/clip/mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import ipdb
7 |
8 | class SlotAttention(nn.Module):
9 | def __init__(self, dim=768, iters=3, eps=1e-8, hidden_dim=512, drop_rate=0.4, feature_size=512):
10 | super().__init__()
11 | self.iters = iters
12 | self.eps = eps
13 | self.scale = dim ** -0.5
14 | self.feature_size = feature_size
15 |
16 | self.to_q = nn.Linear(dim, dim)
17 | slot_share_qk = False
18 | if slot_share_qk:
19 | self.to_k = self.to_q
20 | else:
21 | self.to_k = nn.Linear(dim, dim)
22 |
23 | self.to_v = nn.Linear(feature_size, feature_size)
24 |
25 | hidden_dim = max(dim, hidden_dim, feature_size)
26 |
27 | self.gru = nn.GRUCell(feature_size, feature_size)
28 | self.mlp = nn.Sequential(
29 | nn.Linear(feature_size, hidden_dim),
30 | nn.ReLU(inplace=True),
31 | nn.Linear(hidden_dim, feature_size)
32 | )
33 |
34 | self.norm_slots = nn.LayerNorm(feature_size)
35 | self.norm_pre_ff = nn.LayerNorm(feature_size)
36 | self.norm_input = nn.LayerNorm(feature_size)
37 |
38 | self.slot_dropout = nn.Dropout(drop_rate)
39 | self.input_dropout = nn.Dropout(drop_rate)
40 |
41 | def forward(self,cand_feat, pano_feat):
42 | b, d = pano_feat.shape
43 | # original cand_feat as the initial slot
44 | slots = cand_feat.clone()
45 | slots = self.slot_dropout(slots)
46 |
47 | pano_feat = self.norm_input(pano_feat.clone())
48 | pano_feat = self.input_dropout(pano_feat)
49 |
50 | # (bs, num_ctx, hidden_size)
51 | k = self.to_k(pano_feat)
52 | v = self.to_v(pano_feat)
53 | attn_weights = []
54 |
55 | for t in range(self.iters):
56 | slots_prev = slots
57 |
58 | slots = self.norm_slots(slots.clone())
59 |
60 | # (bs, num_slots, hidden_size)
61 | q = self.to_q(slots.clone())
62 |
63 | # (bs, num_slots, num_ctx)
64 | dots = torch.einsum('id,jd->ij', q, k) * self.scale
65 |
66 | attn = dots.softmax(dim=1)
67 |
68 | attn_weights.append(attn) # for visualization
69 |
70 | # (bs, num_slots, feature_size)
71 | updates = torch.einsum('jd,ij->id', v, attn)
72 |
73 | gru_updates = self.gru(
74 | updates.reshape(-1, self.feature_size),
75 | slots_prev.clone().reshape(-1, self.feature_size)
76 | )
77 | gru_updates = gru_updates + self.mlp(self.norm_pre_ff(gru_updates))
78 |
79 | slots = gru_updates.clone()
80 |
81 | return slots # , np.stack([a.cpu().detach().numpy() for a in attn_weights], 0)
82 |
83 |
84 | class GumbelSoftmax(nn.Module):
85 | '''
86 | gumbel softmax gate.
87 | '''
88 | def __init__(self, eps=1):
89 | super(GumbelSoftmax, self).__init__()
90 | self.eps = eps
91 | self.sigmoid = nn.Sigmoid()
92 |
93 | def gumbel_sample(self, template_tensor, eps=1e-8):
94 | uniform_samples_tensor = template_tensor.clone().uniform_()
95 | gumble_samples_tensor = torch.log(uniform_samples_tensor+eps)-torch.log(
96 | 1-uniform_samples_tensor+eps)
97 | return gumble_samples_tensor
98 |
99 | def gumbel_softmax(self, logits):
100 | """ Draw a sample from the Gumbel-Softmax distribution"""
101 | gsamples = self.gumbel_sample(logits.data)
102 | logits = logits + Variable(gsamples)
103 | soft_samples = self.sigmoid(logits / self.eps)
104 | return soft_samples, logits
105 |
106 | def forward(self, logits):
107 | if not self.training:
108 | out_hard = (logits>=0).float()
109 | return out_hard
110 | out_soft, prob_soft = self.gumbel_softmax(logits)
111 | out_hard = ((out_soft >= 0.5).float() - out_soft).detach() + out_soft
112 | return out_hard
113 |
114 |
115 |
116 | class Mask_s(nn.Module):
117 | '''
118 | Attention Mask spatial.
119 | '''
120 | def __init__(self, h, w, planes, block_w, block_h, eps=0.66667,
121 | bias=-1, **kwargs):
122 | super(Mask_s, self).__init__()
123 | # Parameter
124 | self.width, self.height, self.channel = w, h, planes
125 | self.mask_h, self.mask_w = int(np.ceil(h / block_h)), int(np.ceil(w / block_w))
126 | self.eleNum_s = torch.Tensor([self.mask_h*self.mask_w])
127 | # spatial attention
128 | self.atten_s = nn.Conv2d(planes, 1, kernel_size=3, stride=1, bias=bias>=0, padding=1)
129 | if bias>=0:
130 | nn.init.constant_(self.atten_s.bias, bias)
131 | # Gate
132 | self.gate_s = GumbelSoftmax(eps=eps)
133 | # Norm
134 | self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3))
135 |
136 | def forward(self, x):
137 |
138 | batch, channel, height, width = x.size() # torch.Size([256, 64, 56, 56])
139 | # Pooling
140 | input_ds = F.adaptive_avg_pool2d(input=x, output_size=(self.mask_h, self.mask_w)) # torch.Size([256, 64, 7, 7])
141 | # spatial attention
142 | s_in = self.atten_s(input_ds) # [N, 1, h, w]
143 |
144 | # spatial gate
145 | mask_s = self.gate_s(s_in) # [N, 1, h, w]
146 | # norm
147 | norm = self.norm(mask_s)
148 | norm_t = self.eleNum_s.to(x.device)
149 | return mask_s, norm, norm_t
150 |
151 | def get_flops(self):
152 | flops = self.mask_h * self.mask_w * self.channel * 9
153 | return flops
154 |
155 |
156 | class Mask_c(nn.Module):
157 | '''
158 | Attention Mask.
159 | '''
160 | def __init__(self, inplanes, outplanes, fc_reduction=4, eps=0.66667, bias=-1, **kwargs):
161 | super(Mask_c, self).__init__()
162 | # Parameter
163 | self.bottleneck = 512 # inplanes // fc_reduction
164 | self.inplanes, self.outplanes = inplanes, outplanes
165 | self.eleNum_c = torch.Tensor([outplanes])
166 | # channel attention
167 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
168 | self.atten_c_fc1 = nn.Conv2d(inplanes, 512, kernel_size=1)
169 | self.slot_attention = SlotAttention(
170 | dim=512,
171 | iters=3,
172 | drop_rate=0,
173 | )
174 |
175 | self.atten_c_bn = nn.BatchNorm2d(self.bottleneck)
176 | self.atten_c_act = nn.ReLU(inplace=True)
177 | self.atten_c_conv = nn.Conv2d(self.bottleneck, outplanes, kernel_size=1, stride=1, bias=bias>=0)
178 |
179 | if bias>=0:
180 | nn.init.constant_(self.atten_c_conv.bias, bias)
181 | # Gate
182 | self.gate_c = GumbelSoftmax(eps=eps)
183 | # Norm
184 | self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3))
185 |
186 | def forward(self, x, txt_emb):
187 | batch, channel, _, _ = x.size()
188 | context = self.avg_pool(x) # [N, C, 1, 1]
189 | context = self.atten_c_fc1(context)
190 | # transform
191 | c_in = context+self.slot_attention(context.squeeze(-1).squeeze(-1),txt_emb.detach()).unsqueeze(-1).unsqueeze(-1)
192 | c_in = self.atten_c_bn(c_in)
193 | c_in = self.atten_c_act(c_in)
194 | c_in = self.atten_c_conv(c_in)
195 |
196 | # channel gate
197 | mask_c = self.gate_c(c_in) # [N, C_out, 1, 1]
198 | # norm
199 | norm = self.norm(mask_c)
200 | norm_t = self.eleNum_c.to(x.device)
201 | return mask_c, norm, norm_t
202 |
203 | def get_flops(self):
204 | flops = self.inplanes * self.bottleneck + self.bottleneck * self.outplanes
205 | return flops
206 |
--------------------------------------------------------------------------------
/OD/CLIP/model-card.md:
--------------------------------------------------------------------------------
1 | # Model Card: CLIP
2 |
3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model.
4 |
5 | ## Model Details
6 |
7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within.
8 |
9 | ### Model Date
10 |
11 | January 2021
12 |
13 | ### Model Type
14 |
15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer.
16 |
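In concrete terms, the contrastive objective described above is a symmetric cross-entropy over the batch's image-text cosine-similarity matrix. A minimal PyTorch sketch, paraphrasing the pseudocode in the CLIP paper (the function name and signature are illustrative, not part of this repository):

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, logit_scale):
    # L2-normalize both sets of embeddings, then compare every image with every text in the batch
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    logits = logit_scale * image_features @ text_features.t()  # (N, N) scaled cosine similarities

    # The matching (image, text) pair for each row/column lies on the diagonal
    labels = torch.arange(logits.shape[0], device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2
```
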
17 | ### Model Versions
18 |
19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50.
20 |
21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. In July 2021, we additionally released the RN50x16 and ViT-B/16 models, and in January 2022, the RN50x64 and ViT-L/14 models were released. Lastly, the ViT-L/14@336px model was released in April 2022.
22 |
23 | Please see the paper linked below for further details about their specification.
24 |
25 | ### Documents
26 |
27 | - [Blog Post](https://openai.com/blog/clip/)
28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020)
29 |
30 |
31 |
32 | ## Model Use
33 |
34 | ### Intended Use
35 |
36 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis.
37 |
38 | #### Primary intended uses
39 |
40 | The primary intended users of these models are AI researchers.
41 |
42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models.
43 |
44 | ### Out-of-Scope Use Cases
45 |
46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases, such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task-specific testing, especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful.
47 |
48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use.
49 |
50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases.
51 |
52 |
53 |
54 | ## Data
55 |
56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet, which tend to skew towards more developed nations and younger, male users.
57 |
58 | ### Data Mission Statement
59 |
60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset.
61 |
62 |
63 |
64 | ## Performance and Limitations
65 |
66 | ### Performance
67 |
68 | We have evaluated the performance of CLIP on a wide range of benchmarks spanning a variety of computer vision tasks, from OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets:
69 |
70 | - Food101
71 | - CIFAR10
72 | - CIFAR100
73 | - Birdsnap
74 | - SUN397
75 | - Stanford Cars
76 | - FGVC Aircraft
77 | - VOC2007
78 | - DTD
79 | - Oxford-IIIT Pet dataset
80 | - Caltech101
81 | - Flowers102
82 | - MNIST
83 | - SVHN
84 | - IIIT5K
85 | - Hateful Memes
86 | - SST-2
87 | - UCF101
88 | - Kinetics700
89 | - Country211
90 | - CLEVR Counting
91 | - KITTI Distance
92 | - STL-10
93 | - RareAct
94 | - Flickr30
95 | - MSCOCO
96 | - ImageNet
97 | - ImageNet-A
98 | - ImageNet-R
99 | - ImageNet Sketch
100 | - ObjectNet (ImageNet Overlap)
101 | - Youtube-BB
102 | - ImageNet-Vid
103 |
104 | ## Limitations
105 |
106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with certain tasks such as fine-grained classification and counting objects. CLIP also poses issues with regard to fairness and bias, which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP also has an important limitation: in many cases we have used linear probes to evaluate the performance of CLIP, and there is evidence suggesting that linear probes can underestimate model performance.
107 |
108 | ### Bias and Fairness
109 |
110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper).
111 |
112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks.
113 |
114 |
115 |
116 | ## Feedback
117 |
118 | ### Where to send questions or comments about the model
119 |
120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9)
121 |
--------------------------------------------------------------------------------
/OD/CLIP/README.md:
--------------------------------------------------------------------------------
1 | # CLIP
2 |
3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb)
4 |
5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.
6 |
7 |
8 |
9 | ## Approach
10 |
11 | 
12 |
13 |
14 |
15 | ## Usage
16 |
17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) (or later) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:
18 |
19 | ```bash
20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
21 | $ pip install ftfy regex tqdm
22 | $ pip install git+https://github.com/openai/CLIP.git
23 | ```
24 |
25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU.
26 |
27 | ```python
28 | import torch
29 | import clip
30 | from PIL import Image
31 |
32 | device = "cuda" if torch.cuda.is_available() else "cpu"
33 | model, preprocess = clip.load("ViT-B/32", device=device)
34 |
35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
37 |
38 | with torch.no_grad():
39 | image_features = model.encode_image(image)
40 | text_features = model.encode_text(text)
41 |
42 | logits_per_image, logits_per_text = model(image, text)
43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy()
44 |
45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
46 | ```
47 |
48 |
49 | ## API
50 |
51 | The CLIP module `clip` provides the following methods:
52 |
53 | #### `clip.available_models()`
54 |
55 | Returns the names of the available CLIP models.
56 |
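For example (the exact list depends on the packaged release):

```python
import clip

print(clip.available_models())
# e.g. ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', ...]
```
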
57 | #### `clip.load(name, device=..., jit=False)`
58 |
59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint.
60 |
61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded.
62 |
63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
64 |
65 | Returns a LongTensor containing tokenized sequences of the given text input(s). This can be used as the input to the model.
66 |
67 | ---
68 |
69 | The model returned by `clip.load()` supports the following methods:
70 |
71 | #### `model.encode_image(image: Tensor)`
72 |
73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.
74 |
75 | #### `model.encode_text(text: Tensor)`
76 |
77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.
78 |
79 | #### `model(image: Tensor, text: Tensor)`
80 |
81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.
82 |
83 |
84 |
85 | ## More Examples
86 |
87 | ### Zero-Shot Prediction
88 |
89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
90 |
91 | ```python
92 | import os
93 | import clip
94 | import torch
95 | from torchvision.datasets import CIFAR100
96 |
97 | # Load the model
98 | device = "cuda" if torch.cuda.is_available() else "cpu"
99 | model, preprocess = clip.load('ViT-B/32', device)
100 |
101 | # Download the dataset
102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
103 |
104 | # Prepare the inputs
105 | image, class_id = cifar100[3637]
106 | image_input = preprocess(image).unsqueeze(0).to(device)
107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
108 |
109 | # Calculate features
110 | with torch.no_grad():
111 | image_features = model.encode_image(image_input)
112 | text_features = model.encode_text(text_inputs)
113 |
114 | # Pick the top 5 most similar labels for the image
115 | image_features /= image_features.norm(dim=-1, keepdim=True)
116 | text_features /= text_features.norm(dim=-1, keepdim=True)
117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
118 | values, indices = similarity[0].topk(5)
119 |
120 | # Print the result
121 | print("\nTop predictions:\n")
122 | for value, index in zip(values, indices):
123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
124 | ```
125 |
126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device):
127 |
128 | ```
129 | Top predictions:
130 |
131 | snake: 65.31%
132 | turtle: 12.29%
133 | sweet_pepper: 3.83%
134 | lizard: 1.88%
135 | crocodile: 1.75%
136 | ```
137 |
138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.
139 |
140 |
141 | ### Linear-probe evaluation
142 |
143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features.
144 |
145 | ```python
146 | import os
147 | import clip
148 | import torch
149 |
150 | import numpy as np
151 | from sklearn.linear_model import LogisticRegression
152 | from torch.utils.data import DataLoader
153 | from torchvision.datasets import CIFAR100
154 | from tqdm import tqdm
155 |
156 | # Load the model
157 | device = "cuda" if torch.cuda.is_available() else "cpu"
158 | model, preprocess = clip.load('ViT-B/32', device)
159 |
160 | # Load the dataset
161 | root = os.path.expanduser("~/.cache")
162 | train = CIFAR100(root, download=True, train=True, transform=preprocess)
163 | test = CIFAR100(root, download=True, train=False, transform=preprocess)
164 |
165 |
166 | def get_features(dataset):
167 | all_features = []
168 | all_labels = []
169 |
170 | with torch.no_grad():
171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
172 | features = model.encode_image(images.to(device))
173 |
174 | all_features.append(features)
175 | all_labels.append(labels)
176 |
177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
178 |
179 | # Calculate the image features
180 | train_features, train_labels = get_features(train)
181 | test_features, test_labels = get_features(test)
182 |
183 | # Perform logistic regression
184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
185 | classifier.fit(train_features, train_labels)
186 |
187 | # Evaluate using the logistic regression classifier
188 | predictions = classifier.predict(test_features)
189 | accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
190 | print(f"Accuracy = {accuracy:.3f}")
191 | ```
192 |
193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split.
194 |
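As a rough illustration of such a sweep (reusing `train_features` and `train_labels` from the snippet above; the grid and split below are arbitrary choices, not necessarily the procedure used in the paper):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Hold out part of the training set as a validation split
tr_x, val_x, tr_y, val_y = train_test_split(
    train_features, train_labels, test_size=0.2, random_state=0)

best_C, best_acc = None, -1.0
for C in np.logspace(-3, 3, 7):  # coarse grid; refine around the best value if needed
    clf = LogisticRegression(random_state=0, C=C, max_iter=1000)
    clf.fit(tr_x, tr_y)
    acc = clf.score(val_x, val_y)
    if acc > best_acc:
        best_C, best_acc = C, acc

print(f"Best C = {best_C:.3g} (validation accuracy {100 * best_acc:.2f}%)")
```
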
195 |
196 | ## See Also
197 |
198 | * [OpenCLIP](https://github.com/mlfoundations/open_clip): includes larger and independently trained CLIP models up to ViT-G/14
199 | * [Hugging Face implementation of CLIP](https://huggingface.co/docs/transformers/model_doc/clip): for easier integration with the HF ecosystem
200 |
--------------------------------------------------------------------------------
/OD/modeling/roi_head.py:
--------------------------------------------------------------------------------
1 | from pydoc import classname
2 | from typing import Dict, List, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | import torchvision.transforms as T
8 |
9 |
10 | from detectron2.layers import ShapeSpec
11 | from detectron2.data import MetadataCatalog
12 |
13 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
14 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
15 | from .box_predictor import ClipFastRCNNOutputLayers
16 |
17 | def select_foreground_proposals(
18 | proposals: List[Instances], bg_label: int
19 | ) -> Tuple[List[Instances], List[torch.Tensor]]:
20 | """
21 | Given a list of N Instances (for N images), each containing a `gt_classes` field,
22 | return a list of Instances that contain only instances with `gt_classes != -1 &&
23 | gt_classes != bg_label`.
24 | Args:
25 | proposals (list[Instances]): A list of N Instances, where N is the number of
26 | images in the batch.
27 | bg_label: label index of background class.
28 | Returns:
29 | list[Instances]: N Instances, each contains only the selected foreground instances.
30 | list[Tensor]: N boolean vector, correspond to the selection mask of
31 | each Instances object. True for selected instances.
32 | """
33 | assert isinstance(proposals, (list, tuple))
34 | assert isinstance(proposals[0], Instances)
35 | assert proposals[0].has("gt_classes")
36 | fg_proposals = []
37 | fg_selection_masks = []
38 | for proposals_per_image in proposals:
39 | gt_classes = proposals_per_image.gt_classes
40 | fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label)
41 | fg_idxs = fg_selection_mask.nonzero().squeeze(1)
42 | fg_proposals.append(proposals_per_image[fg_idxs])
43 | fg_selection_masks.append(fg_selection_mask)
44 | return fg_proposals, fg_selection_masks
45 |
46 |
47 | @ROI_HEADS_REGISTRY.register()
48 | class ClipRes5ROIHeads(Res5ROIHeads):
49 | def __init__(self, cfg, input_shape) -> None:
50 | super().__init__(cfg, input_shape)
51 | clsnames = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).get("thing_classes").copy()
52 |
53 | # import pdb;pdb.set_trace()
54 | ##change the labels to represent the objects correctly
55 | for name in cfg.MODEL.RENAME:
56 | ind = clsnames.index(name[0])
57 | clsnames[ind] = name[1]
58 |
59 | out_channels=cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * (2 ** 3) ### copied
60 | self.box_predictor = ClipFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1), clsnames)
61 | self.clip_im_predictor = self.box_predictor.cls_score # should call it properly
62 | self.device = cfg.MODEL.DEVICE
63 | def forward(
64 | self,
65 | images: ImageList,
66 | features: Dict[str, torch.Tensor],
67 | proposals: List[Instances],
68 | targets: Optional[List[Instances]] = None,
69 | crops: Optional[List[Tuple]] = None,
70 | ):
71 | """
72 | See :meth:`ROIHeads.forward`.
73 | """
74 | del images
75 |
76 | if self.training:
77 | assert targets
78 | proposals = self.label_and_sample_proposals(proposals, targets)
79 | # import pdb;pdb.set_trace()
80 | loss_crop_im = None
81 | if crops is not None:
82 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224
83 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4
84 | keep = torch.ones(len(crops)).bool()
85 |
86 | for ind,x in enumerate(crops):
87 | if len(x) == 0:
88 | keep[ind] = False
89 | continue
90 | crop_im.append(x[0])
91 | crop_boxes.append(x[1].to(self.device))
92 |
93 |                 crops_features = self._shared_roi_transform(
94 | [features[f][keep] for f in self.in_features], crop_boxes) #(b*crops)x2048x7x7
95 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features.mean(dim=[2, 3]))
96 |
97 | del targets
98 |
99 | proposal_boxes = [x.proposal_boxes for x in proposals]
100 | box_features = self._shared_roi_transform(
101 | [features[f] for f in self.in_features], proposal_boxes
102 | )
103 | predictions = self.box_predictor(box_features.mean(dim=[2, 3]))
104 | # import pdb;pdb.set_trace()
105 | if self.training:
106 | del features
107 | losses = self.box_predictor.losses(predictions, proposals)
108 | if self.mask_on:
109 | proposals, fg_selection_masks = select_foreground_proposals(
110 | proposals, self.num_classes
111 | )
112 | # Since the ROI feature transform is shared between boxes and masks,
113 | # we don't need to recompute features. The mask loss is only defined
114 | # on foreground proposals, so we need to select out the foreground
115 | # features.
116 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
117 | del box_features
118 | losses.update(self.mask_head(mask_features, proposals))
119 |
120 | if loss_crop_im is not None:
121 | losses.update(loss_crop_im)
122 | return [], losses
123 | else:
124 | pred_instances, _ = self.box_predictor.inference(predictions, proposals)
125 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
126 | return pred_instances, {}
127 |
128 |
129 | @ROI_HEADS_REGISTRY.register()
130 | class ClipRes5ROIHeadsAttn(ClipRes5ROIHeads):
131 | def __init__(self, cfg, input_shape) -> None:
132 | super().__init__(cfg, input_shape)
133 |
134 | def _shared_roi_transform(self, features, txt_emb, boxes):
135 | x = self.pooler(features, boxes)
136 | batch_num, _, _, _ = x.shape
137 | norm1 = torch.zeros(1, batch_num+1).to(x.device)
138 | norm2 = torch.zeros(1, batch_num+1).to(x.device)
139 | flops = torch.zeros(1, batch_num+2).to(x.device)
140 |
141 | return self.fwdres5(x,txt_emb,norm1,norm2,flops)
142 |
143 |
144 |
145 | def forward(
146 | self,
147 | images: ImageList,
148 | txt_emb: torch.Tensor,
149 | features: Dict[str, torch.Tensor],
150 | proposals: List[Instances],
151 | targets: Optional[List[Instances]] = None,
152 | crops: Optional[List[Tuple]] = None,
153 | backbone = None
154 | ):
155 | """
156 | See :meth:`ROIHeads.forward`.
157 | """
158 | del images
159 |
160 | self.fwdres5 = backbone.forward_res5
161 |
162 | if self.training:
163 | assert targets
164 | proposals = self.label_and_sample_proposals(proposals, targets)
165 | # import pdb;pdb.set_trace()
166 | loss_crop_im = None
167 | if crops is not None:
168 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224
169 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4
170 | keep = torch.ones(len(crops)).bool()
171 |
172 | for ind,x in enumerate(crops):
173 | if len(x) == 0:
174 | keep[ind] = False
175 | continue
176 | crop_im.append(x[0])
177 | crop_boxes.append(x[1].to(self.device))
178 |
179 | crops_features = self._shared_roi_transform(
180 | [features[f][keep] for f in self.in_features], txt_emb, crop_boxes) #(b*crops)x2048x7x7
181 | crops_features = backbone.attention_global_pool(crops_features)
182 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features)
183 |
184 | del targets
185 |
186 | proposal_boxes = [x.proposal_boxes for x in proposals]
187 | box_features = self._shared_roi_transform(
188 | [features[f] for f in self.in_features], txt_emb, proposal_boxes
189 | )
190 |
191 | attn_feat = backbone.attention_global_pool(box_features)
192 | predictions = self.box_predictor([attn_feat,box_features.mean(dim=(2,3))])
193 |
194 | if self.training:
195 | del features
196 |
197 | losses = self.box_predictor.losses(predictions, proposals)
198 |
199 | if self.mask_on:
200 | proposals, fg_selection_masks = select_foreground_proposals(
201 | proposals, self.num_classes
202 | )
203 | # Since the ROI feature transform is shared between boxes and masks,
204 | # we don't need to recompute features. The mask loss is only defined
205 | # on foreground proposals, so we need to select out the foreground
206 | # features.
207 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
208 | del box_features
209 | losses.update(self.mask_head(mask_features, proposals))
210 |
211 | if loss_crop_im is not None:
212 | losses.update(loss_crop_im)
213 | return [], losses
214 | else:
215 | pred_instances, _ = self.box_predictor.inference(predictions, proposals)
216 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
217 | return pred_instances, {}
218 |
219 |
220 |
--------------------------------------------------------------------------------
/OD/modeling/.ipynb_checkpoints/roi_head-checkpoint.py:
--------------------------------------------------------------------------------
1 | from pydoc import classname
2 | from typing import Dict, List, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | import torchvision.transforms as T
8 |
9 |
10 | from detectron2.layers import ShapeSpec
11 | from detectron2.data import MetadataCatalog
12 |
13 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads
14 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou
15 | from .box_predictor import ClipFastRCNNOutputLayers
16 |
17 | def select_foreground_proposals(
18 | proposals: List[Instances], bg_label: int
19 | ) -> Tuple[List[Instances], List[torch.Tensor]]:
20 | """
21 | Given a list of N Instances (for N images), each containing a `gt_classes` field,
22 | return a list of Instances that contain only instances with `gt_classes != -1 &&
23 | gt_classes != bg_label`.
24 | Args:
25 | proposals (list[Instances]): A list of N Instances, where N is the number of
26 | images in the batch.
27 | bg_label: label index of background class.
28 | Returns:
29 | list[Instances]: N Instances, each contains only the selected foreground instances.
30 | list[Tensor]: N boolean vector, correspond to the selection mask of
31 | each Instances object. True for selected instances.
32 | """
33 | assert isinstance(proposals, (list, tuple))
34 | assert isinstance(proposals[0], Instances)
35 | assert proposals[0].has("gt_classes")
36 | fg_proposals = []
37 | fg_selection_masks = []
38 | for proposals_per_image in proposals:
39 | gt_classes = proposals_per_image.gt_classes
40 | fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label)
41 | fg_idxs = fg_selection_mask.nonzero().squeeze(1)
42 | fg_proposals.append(proposals_per_image[fg_idxs])
43 | fg_selection_masks.append(fg_selection_mask)
44 | return fg_proposals, fg_selection_masks
45 |
46 |
47 | @ROI_HEADS_REGISTRY.register()
48 | class ClipRes5ROIHeads(Res5ROIHeads):
49 | def __init__(self, cfg, input_shape) -> None:
50 | super().__init__(cfg, input_shape)
51 | clsnames = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).get("thing_classes").copy()
52 |
53 | # import pdb;pdb.set_trace()
54 | ##change the labels to represent the objects correctly
55 | for name in cfg.MODEL.RENAME:
56 | ind = clsnames.index(name[0])
57 | clsnames[ind] = name[1]
58 |
59 | out_channels=cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * (2 ** 3) ### copied
60 | self.box_predictor = ClipFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1), clsnames)
61 |         self.clip_im_predictor = self.box_predictor.cls_score  # alias to the CLIP-based classification head of the box predictor
62 | self.device = cfg.MODEL.DEVICE
63 | def forward(
64 | self,
65 | images: ImageList,
66 | features: Dict[str, torch.Tensor],
67 | proposals: List[Instances],
68 | targets: Optional[List[Instances]] = None,
69 | crops: Optional[List[Tuple]] = None,
70 | ):
71 | """
72 | See :meth:`ROIHeads.forward`.
73 | """
74 | del images
75 |
76 | if self.training:
77 | assert targets
78 | proposals = self.label_and_sample_proposals(proposals, targets)
79 | # import pdb;pdb.set_trace()
80 | loss_crop_im = None
81 | if crops is not None:
82 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224
83 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4
84 | keep = torch.ones(len(crops)).bool()
85 |
86 | for ind,x in enumerate(crops):
87 | if len(x) == 0:
88 | keep[ind] = False
89 | continue
90 | crop_im.append(x[0])
91 | crop_boxes.append(x[1].to(self.device))
92 |
93 |             crops_features = self._shared_roi_transform(
94 |                 [features[f][keep] for f in self.in_features], crop_boxes)  # (b*crops) x 2048 x 7 x 7
95 |             loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im, crops_features.mean(dim=[2, 3]))
96 |
97 | del targets
98 |
99 | proposal_boxes = [x.proposal_boxes for x in proposals]
100 | box_features = self._shared_roi_transform(
101 | [features[f] for f in self.in_features], proposal_boxes
102 | )
103 | predictions = self.box_predictor(box_features.mean(dim=[2, 3]))
104 | # import pdb;pdb.set_trace()
105 | if self.training:
106 | del features
107 | losses = self.box_predictor.losses(predictions, proposals)
108 | if self.mask_on:
109 | proposals, fg_selection_masks = select_foreground_proposals(
110 | proposals, self.num_classes
111 | )
112 | # Since the ROI feature transform is shared between boxes and masks,
113 | # we don't need to recompute features. The mask loss is only defined
114 | # on foreground proposals, so we need to select out the foreground
115 | # features.
116 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
117 | del box_features
118 | losses.update(self.mask_head(mask_features, proposals))
119 |
120 | if loss_crop_im is not None:
121 | losses.update(loss_crop_im)
122 | return [], losses
123 | else:
124 | pred_instances, _ = self.box_predictor.inference(predictions, proposals)
125 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
126 | return pred_instances, {}
127 |
128 |
129 | @ROI_HEADS_REGISTRY.register()
130 | class ClipRes5ROIHeadsAttn(ClipRes5ROIHeads):
131 | def __init__(self, cfg, input_shape) -> None:
132 | super().__init__(cfg, input_shape)
133 |
134 | def _shared_roi_transform(self, features, txt_emb, boxes):
135 | x = self.pooler(features, boxes)
136 | batch_num, _, _, _ = x.shape
137 | norm1 = torch.zeros(1, batch_num+1).to(x.device)
138 | norm2 = torch.zeros(1, batch_num+1).to(x.device)
139 | flops = torch.zeros(1, batch_num+2).to(x.device)
140 |
141 | return self.fwdres5(x,txt_emb,norm1,norm2,flops)
142 |
143 |
144 |
145 | def forward(
146 | self,
147 | images: ImageList,
148 | txt_emb: torch.Tensor,
149 | features: Dict[str, torch.Tensor],
150 | proposals: List[Instances],
151 | targets: Optional[List[Instances]] = None,
152 | crops: Optional[List[Tuple]] = None,
153 | backbone = None
154 | ):
155 | """
156 | See :meth:`ROIHeads.forward`.
157 | """
158 | del images
159 |
160 | self.fwdres5 = backbone.forward_res5
161 |
162 | if self.training:
163 | assert targets
164 | proposals = self.label_and_sample_proposals(proposals, targets)
165 | # import pdb;pdb.set_trace()
166 | loss_crop_im = None
167 | if crops is not None:
168 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224
169 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4
170 | keep = torch.ones(len(crops)).bool()
171 |
172 | for ind,x in enumerate(crops):
173 | if len(x) == 0:
174 | keep[ind] = False
175 | continue
176 | crop_im.append(x[0])
177 | crop_boxes.append(x[1].to(self.device))
178 |
179 | crops_features = self._shared_roi_transform(
180 | [features[f][keep] for f in self.in_features], txt_emb, crop_boxes) #(b*crops)x2048x7x7
181 | crops_features = backbone.attention_global_pool(crops_features)
182 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features)
183 |
184 | del targets
185 |
186 | proposal_boxes = [x.proposal_boxes for x in proposals]
187 | box_features = self._shared_roi_transform(
188 | [features[f] for f in self.in_features], txt_emb, proposal_boxes
189 | )
190 |
191 | attn_feat = backbone.attention_global_pool(box_features)
192 | predictions = self.box_predictor([attn_feat,box_features.mean(dim=(2,3))])
193 |
194 | if self.training:
195 | del features
196 |
197 | losses = self.box_predictor.losses(predictions, proposals)
198 |
199 | if self.mask_on:
200 | proposals, fg_selection_masks = select_foreground_proposals(
201 | proposals, self.num_classes
202 | )
203 | # Since the ROI feature transform is shared between boxes and masks,
204 | # we don't need to recompute features. The mask loss is only defined
205 | # on foreground proposals, so we need to select out the foreground
206 | # features.
207 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
208 | del box_features
209 | losses.update(self.mask_head(mask_features, proposals))
210 |
211 | if loss_crop_im is not None:
212 | losses.update(loss_crop_im)
213 | return [], losses
214 | else:
215 | pred_instances, _ = self.box_predictor.inference(predictions, proposals)
216 | pred_instances = self.forward_with_given_boxes(features, pred_instances)
217 | return pred_instances, {}
218 |
219 |
220 |
--------------------------------------------------------------------------------
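Note: `select_foreground_proposals` above keeps only proposals whose `gt_classes` is neither -1 (ignored) nor the background label. A minimal, self-contained sketch of the same filtering rule on a toy `Instances` object (detectron2 is assumed to be installed; the class count and boxes are made up):

import torch
from detectron2.structures import Boxes, Instances

num_classes = 7  # hypothetical; detectron2 uses bg_label == num_classes
inst = Instances((480, 640))
inst.proposal_boxes = Boxes(torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.], [1., 1., 8., 8.]]))
inst.gt_classes = torch.tensor([2, num_classes, -1])  # foreground, background, ignored

# Same rule as select_foreground_proposals: keep gt_classes != -1 and gt_classes != bg_label.
fg_mask = (inst.gt_classes != -1) & (inst.gt_classes != num_classes)
fg_inst = inst[fg_mask.nonzero().squeeze(1)]

print(len(fg_inst))  # 1 -> only the class-2 proposal survives
print(fg_mask)       # tensor([ True, False, False])
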
/OD/CLIP/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95 | """Load a CLIP model
96 |
97 | Parameters
98 | ----------
99 | name : str
100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101 |
102 | device : Union[str, torch.device]
103 | The device to put the loaded model
104 |
105 | jit : bool
106 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
107 |
108 | download_root: str
109 | path to download the model files; by default, it uses "~/.cache/clip"
110 |
111 | Returns
112 | -------
113 | model : torch.nn.Module
114 | The CLIP model
115 |
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125 |
126 | with open(model_path, 'rb') as opened_file:
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(opened_file, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict()).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def _node_get(node: torch._C.Node, key: str):
149 | """Gets attributes of a node which is polymorphic over return type.
150 |
151 | From https://github.com/pytorch/pytorch/pull/82628
152 | """
153 | sel = node.kindOf(key)
154 | return getattr(node, sel)(key)
155 |
156 | def patch_device(module):
157 | try:
158 | graphs = [module.graph] if hasattr(module, "graph") else []
159 | except RuntimeError:
160 | graphs = []
161 |
162 | if hasattr(module, "forward1"):
163 | graphs.append(module.forward1.graph)
164 |
165 | for graph in graphs:
166 | for node in graph.findAllNodes("prim::Constant"):
167 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
168 | node.copyAttributes(device_node)
169 |
170 | model.apply(patch_device)
171 | patch_device(model.encode_image)
172 | patch_device(model.encode_text)
173 |
174 | # patch dtype to float32 on CPU
175 | if str(device) == "cpu":
176 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
177 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
178 | float_node = float_input.node()
179 |
180 | def patch_float(module):
181 | try:
182 | graphs = [module.graph] if hasattr(module, "graph") else []
183 | except RuntimeError:
184 | graphs = []
185 |
186 | if hasattr(module, "forward1"):
187 | graphs.append(module.forward1.graph)
188 |
189 | for graph in graphs:
190 | for node in graph.findAllNodes("aten::to"):
191 | inputs = list(node.inputs())
192 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
193 | if _node_get(inputs[i].node(), "value") == 5:
194 | inputs[i].node().copyAttributes(float_node)
195 |
196 | model.apply(patch_float)
197 | patch_float(model.encode_image)
198 | patch_float(model.encode_text)
199 |
200 | model.float()
201 |
202 | return model, _transform(model.input_resolution.item())
203 |
204 |
205 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
206 | """
207 | Returns the tokenized representation of given input string(s)
208 |
209 | Parameters
210 | ----------
211 | texts : Union[str, List[str]]
212 | An input string or a list of input strings to tokenize
213 |
214 | context_length : int
215 | The context length to use; all CLIP models use 77 as the context length
216 |
217 | truncate: bool
218 | Whether to truncate the text in case its encoding is longer than the context length
219 |
220 | Returns
221 | -------
222 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
223 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
224 | """
225 | if isinstance(texts, str):
226 | texts = [texts]
227 |
228 | sot_token = _tokenizer.encoder["<|startoftext|>"]
229 | eot_token = _tokenizer.encoder["<|endoftext|>"]
230 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
231 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
232 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
233 | else:
234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
235 |
236 | for i, tokens in enumerate(all_tokens):
237 | if len(tokens) > context_length:
238 | if truncate:
239 | tokens = tokens[:context_length]
240 | tokens[-1] = eot_token
241 | else:
242 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
243 | result[i, :len(tokens)] = torch.tensor(tokens)
244 |
245 | return result
246 |
--------------------------------------------------------------------------------
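Note: `load` and `tokenize` above are the package's public entry points. A minimal usage sketch, assuming the bundled package has been installed (e.g. via OD/CLIP/run_install.sh) and that "example.jpg" is a placeholder image path:

import torch
from PIL import Image
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("RN50", device=device)  # any name from clip.available_models()

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a photo of a car", "a photo of a bus"]).to(device)

with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1)

print(probs)  # similarity of the image to each prompt
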
/OD/modeling/meta_arch.py:
--------------------------------------------------------------------------------
1 | from ast import mod
2 | import math
3 | import numpy as np
4 | import cv2
5 | import os
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 | import torchvision.transforms as T
11 |
12 | from typing import Dict,List,Optional
13 |
14 | from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN
15 | from detectron2.structures import ImageList, Instances, pairwise_iou
16 | from detectron2.utils.events import get_event_storage
17 | from detectron2.layers import batched_nms
18 | from detectron2.data.detection_utils import convert_image_to_rgb
19 | from detectron2.utils.visualizer import Visualizer
20 |
21 |
22 | def show_cam_on_image(img, mask):
23 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
24 | heatmap = np.float32(heatmap) / 255
25 | cam = heatmap + np.float32(img)
26 | cam = cam / np.max(cam)
27 | return cam
28 |
29 |
30 | def generate_visualization(feat, input_img_tensor, scale_factor=60, size=(224, 224)):
31 | """
32 |     Upsample a feature map and overlay it on the input image as a heatmap.
33 |     :param feat: 4D feature map tensor of shape (1, 1, h, w), upsampled by `scale_factor`
34 |     :param input_img_tensor: input image tensor of shape (3, H, W), cropped to `size`
35 |     :param size: (height, width) of the produced visualization
36 |     :return: uint8 heatmap overlay of shape (height, width, 3)
37 | """
38 | input_img_tensor = input_img_tensor[:, :size[0], :size[1]]
39 | feat = torch.nn.functional.interpolate(feat, scale_factor=scale_factor, mode='bilinear')
40 | feat = feat.reshape(size).cuda().data.cpu().numpy()
41 | feat = (feat - feat.min()) / (
42 | feat.max() - feat.min())
43 | image_feat = input_img_tensor.permute(1, 2, 0).data.cpu().numpy()
44 | image_feat = (image_feat - image_feat.min()) / (
45 | image_feat.max() - image_feat.min())
46 | vis = show_cam_on_image(image_feat, feat)
47 | vis = np.uint8(255 * vis)
48 | return vis
49 |
50 |
51 | @META_ARCH_REGISTRY.register()
52 | class ClipRCNNWithClipBackbone(GeneralizedRCNN):
53 |
54 | def __init__(self,cfg) -> None:
55 | super().__init__(cfg)
56 | self.cfg = cfg
57 | self.colors = self.generate_colors(7)
58 | self.backbone.set_backbone_model(self.roi_heads.box_predictor.cls_score.visual_enc)
59 |
60 | # txt
61 | domain_text = {'day': 'an image taken during the day'}
62 | with open('prunedprompts.txt','r') as f:
63 | for ind,l in enumerate(f):
64 | domain_text.update({str(ind):l.strip()})
65 | self.offsets = nn.Parameter(torch.zeros(len(domain_text)-1,1024,14,14)) #skip day
66 |
67 | import clip
68 | self.domain_tk = dict([(k,clip.tokenize(t)) for k,t in domain_text.items()])
69 | self.apply_aug = cfg.AUG_PROB
70 |
71 | day_text_embed_list = []
72 | for i,val in enumerate(self.domain_tk.items()):
73 | name , dtk = val
74 | if name == 'day':
75 | continue
76 | with torch.no_grad():
77 |
78 | day_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(self.domain_tk['day'].cuda()) #day
79 | day_text_embed = day_text_embed/day_text_embed.norm(dim=-1,keepdim=True)
80 | day_text_embed_list.append(day_text_embed)
81 | self.day_text_embeds = torch.cat(day_text_embed_list,0).cuda()
82 |
83 | def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
84 | """
85 | Normalize, pad and batch the input images.
86 | """
87 | clip_images = [x["image"].to(self.pixel_mean.device) for x in batched_inputs]
88 | mean=[0.48145466, 0.4578275, 0.40821073]
89 | std=[0.26862954, 0.26130258, 0.27577711]
90 |
91 |
92 | clip_images = [ T.functional.normalize(ci.flip(0)/255, mean,std) for ci in clip_images]
93 | clip_images = ImageList.from_tensors(
94 | [i for i in clip_images])
95 | return clip_images
96 |
97 |
98 | def forward(self, batched_inputs):
99 |
100 | if not self.training:
101 | return self.inference(batched_inputs)
102 |
103 | images = self.preprocess_image(batched_inputs)
104 | b = images.tensor.shape[0]#batchsize
105 |
106 | if "instances" in batched_inputs[0]:
107 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
108 |
109 | features = self.backbone(images.tensor)
110 |
111 | if self.proposal_generator is not None:
112 | if self.training:
113 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
114 | else:
115 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
116 | else:
117 | assert "proposals" in batched_inputs[0]
118 | proposals = [x["proposals"].to(self.device) for x in batched_inputs]
119 | proposal_losses = {}
120 |
121 | try:
122 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone)
123 | except Exception as e:
124 | print(e)
125 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None)
126 |
127 | if self.vis_period > 0:
128 | storage = get_event_storage()
129 | if storage.iter % self.vis_period == 0:
130 | self.visualize_training(batched_inputs, proposals)
131 | with torch.no_grad():
132 | ogimage = batched_inputs[0]['image']
133 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format)
134 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image()
135 |
136 | vis_img = o_pred.transpose(2, 0, 1)
137 | storage.put_image('og-tfimage', vis_img)
138 |
139 | losses = {}
140 | losses.update(detector_losses)
141 | losses.update(proposal_losses)
142 | return losses
143 |
144 | def generate_colors(self,N):
145 | import colorsys
146 | '''
147 | Generate random colors.
148 | To get visually distinct colors, generate them in HSV space then
149 | convert to RGB.
150 | '''
151 | brightness = 0.7
152 | hsv = [(i / N, 1, brightness) for i in range(N)]
153 | colors = list(map(lambda c: tuple(round(i * 255) for i in colorsys.hsv_to_rgb(*c)), hsv))
154 | perm = np.arange(7)
155 | colors = [colors[idx] for idx in perm]
156 | return colors
157 |
158 |
159 | def inference(
160 | self,
161 | batched_inputs: List[Dict[str, torch.Tensor]],
162 | detected_instances: Optional[List[Instances]] = None,
163 | do_postprocess: bool = True,
164 | ):
165 | """
166 | Run inference on the given inputs.
167 | Args:
168 | batched_inputs (list[dict]): same as in :meth:`forward`
169 | detected_instances (None or list[Instances]): if not None, it
170 | contains an `Instances` object per image. The `Instances`
171 | object contains "pred_boxes" and "pred_classes" which are
172 | known boxes in the image.
173 | The inference will then skip the detection of bounding boxes,
174 | and only predict other per-ROI outputs.
175 | do_postprocess (bool): whether to apply post-processing on the outputs.
176 | Returns:
177 | When do_postprocess=True, same as in :meth:`forward`.
178 | Otherwise, a list[Instances] containing raw network outputs.
179 | """
180 | assert not self.training
181 |
182 | images = self.preprocess_image(batched_inputs)
183 | features = self.backbone(images.tensor)
184 |
185 |
186 | ###############save feat vis###################
187 | feat_vis = False
188 | if feat_vis:
189 | scale_size = 16
190 | out_dir_explain = os.path.join('./output', 'featmap_vis')
191 | explain_rpn_feat_i, _ = torch.max(features['res4'][0], 0)
192 | explain_rpn_feat_i = explain_rpn_feat_i.unsqueeze(0).unsqueeze(0)
193 |
194 | size = (explain_rpn_feat_i.shape[-2] * scale_size , explain_rpn_feat_i.shape[-1] *scale_size)
195 | visual = generate_visualization(explain_rpn_feat_i, images.tensor[0], scale_factor=scale_size, size=size)
196 |
197 | name_sp_list = batched_inputs[0]['file_name'].split('/')[-1].rsplit('.', 1)
198 | save_file_name = name_sp_list[0] + '.' + name_sp_list[1]
199 | explain_out_file = os.path.join(out_dir_explain, save_file_name)
200 | cv2.imwrite(explain_out_file, visual)
201 |
202 | ###############save feat vis ###################
203 |
204 | if detected_instances is None:
205 | if self.proposal_generator is not None:
206 | logits,proposals, _ = self.proposal_generator(images, features, None)
207 | else:
208 | assert "proposals" in batched_inputs[0]
209 | proposals = [x["proposals"].to(self.device) for x in batched_inputs]
210 |
211 |
212 | try:
213 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None, self.backbone)
214 | except:
215 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None)
216 | else:
217 | detected_instances = [x.to(self.device) for x in detected_instances]
218 | results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
219 |
220 |
221 | if do_postprocess:
222 | assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
223 |
224 | allresults = GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
225 |
226 |
227 | return allresults
228 | else:
229 | return results
230 |
231 |
232 | @META_ARCH_REGISTRY.register()
233 | class ClipRCNNWithClipBackboneTrainable(ClipRCNNWithClipBackbone):
234 | def __init__(self,cfg) -> None:
235 | super().__init__(cfg)
236 |
237 | def forward(self, batched_inputs):
238 |
239 | if not self.training:
240 | return self.inference(batched_inputs)
241 |
242 | images = self.preprocess_image(batched_inputs)
243 | b = images.tensor.shape[0]#batchsize
244 |
245 | if "instances" in batched_inputs[0]:
246 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
247 |
248 | features = self.backbone(images.tensor)
249 |
250 | if self.proposal_generator is not None:
251 | if self.training:
252 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
253 | else:
254 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
255 | else:
256 | assert "proposals" in batched_inputs[0]
257 | proposals = [x["proposals"].to(self.device) for x in batched_inputs]
258 | proposal_losses = {}
259 |
260 | _, detector_losses = self.roi_heads(images, self.day_text_embeds, features, proposals, gt_instances, None, self.backbone)
261 |
262 | losses = {}
263 | losses.update(detector_losses)
264 | losses.update(proposal_losses)
265 | return losses
266 |
267 |
268 |
--------------------------------------------------------------------------------
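Note: `preprocess_image` above converts detectron2-style BGR uint8 images into CLIP's expected input: flip the channel order to RGB, scale to [0, 1], then normalize with CLIP's mean/std. A minimal sketch of that transform on a random dummy image (the spatial size is arbitrary):

import torch
import torchvision.transforms as T

mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]

bgr_image = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)        # BGR, 0..255
clip_input = T.functional.normalize(bgr_image.flip(0).float() / 255, mean, std)
print(clip_input.shape, clip_input.mean().item())
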
/OD/modeling/.ipynb_checkpoints/meta_arch-checkpoint.py:
--------------------------------------------------------------------------------
1 | from ast import mod
2 | import math
3 | import numpy as np
4 | import cv2
5 | import os
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torchvision
10 | import torchvision.transforms as T
11 |
12 | from typing import Dict,List,Optional
13 |
14 | from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN
15 | from detectron2.structures import ImageList, Instances, pairwise_iou
16 | from detectron2.utils.events import get_event_storage
17 | from detectron2.layers import batched_nms
18 | from detectron2.data.detection_utils import convert_image_to_rgb
19 | from detectron2.utils.visualizer import Visualizer
20 | # from .regularization import *
21 |
22 | def show_cam_on_image(img, mask):
23 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
24 | heatmap = np.float32(heatmap) / 255
25 | cam = heatmap + np.float32(img)
26 | cam = cam / np.max(cam)
27 | return cam
28 |
29 |
30 | def generate_visualization(feat, input_img_tensor, scale_factor=60, size=(224, 224)):
31 | """
32 |     Upsample a feature map and overlay it on the input image as a heatmap.
33 |     :param feat: 4D feature map tensor of shape (1, 1, h, w), upsampled by `scale_factor`
34 |     :param input_img_tensor: input image tensor of shape (3, H, W), cropped to `size`
35 |     :param size: (height, width) of the produced visualization
36 |     :return: uint8 heatmap overlay of shape (height, width, 3)
37 | """
38 | input_img_tensor = input_img_tensor[:, :size[0], :size[1]]
39 | feat = torch.nn.functional.interpolate(feat, scale_factor=scale_factor, mode='bilinear')
40 | feat = feat.reshape(size).cuda().data.cpu().numpy()
41 | feat = (feat - feat.min()) / (
42 | feat.max() - feat.min())
43 | image_feat = input_img_tensor.permute(1, 2, 0).data.cpu().numpy()
44 | image_feat = (image_feat - image_feat.min()) / (
45 | image_feat.max() - image_feat.min())
46 | vis = show_cam_on_image(image_feat, feat)
47 | vis = np.uint8(255 * vis)
48 | return vis
49 |
50 |
51 | @META_ARCH_REGISTRY.register()
52 | class ClipRCNNWithClipBackbone(GeneralizedRCNN):
53 |
54 | def __init__(self,cfg) -> None:
55 | super().__init__(cfg)
56 | self.cfg = cfg
57 | self.colors = self.generate_colors(7)
58 | self.backbone.set_backbone_model(self.roi_heads.box_predictor.cls_score.visual_enc)
59 |
60 | # txt
61 | domain_text = {'day': 'an image taken during the day'}
62 | with open('prunedprompts2.txt','r') as f:
63 | for ind,l in enumerate(f):
64 | domain_text.update({str(ind):l.strip()})
65 | self.offsets = nn.Parameter(torch.zeros(len(domain_text)-1,1024,14,14)) #skip day
66 |
67 | import clip
68 | self.domain_tk = dict([(k,clip.tokenize(t)) for k,t in domain_text.items()])
69 | self.apply_aug = cfg.AUG_PROB
70 |
71 | day_text_embed_list = []
72 | for i,val in enumerate(self.domain_tk.items()):
73 | name , dtk = val
74 | if name == 'day':
75 | continue
76 | with torch.no_grad():
77 |
78 | day_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(self.domain_tk['day'].cuda()) #day
79 | day_text_embed = day_text_embed/day_text_embed.norm(dim=-1,keepdim=True)
80 | day_text_embed_list.append(day_text_embed)
81 | self.day_text_embeds = torch.cat(day_text_embed_list,0).cuda()
82 |
83 | def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
84 | """
85 | Normalize, pad and batch the input images.
86 | """
87 | clip_images = [x["image"].to(self.pixel_mean.device) for x in batched_inputs]
88 | mean=[0.48145466, 0.4578275, 0.40821073]
89 | std=[0.26862954, 0.26130258, 0.27577711]
90 |
91 |
92 | clip_images = [ T.functional.normalize(ci.flip(0)/255, mean,std) for ci in clip_images]
93 | clip_images = ImageList.from_tensors(
94 | [i for i in clip_images])
95 | return clip_images
96 |
97 |
98 | def forward(self, batched_inputs):
99 |
100 | if not self.training:
101 | return self.inference(batched_inputs)
102 |
103 | images = self.preprocess_image(batched_inputs)
104 | b = images.tensor.shape[0]#batchsize
105 |
106 | if "instances" in batched_inputs[0]:
107 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
108 |
109 | features = self.backbone(images.tensor)
110 |
111 | if self.proposal_generator is not None:
112 | if self.training:
113 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
114 | else:
115 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
116 | else:
117 | assert "proposals" in batched_inputs[0]
118 | proposals = [x["proposals"].to(self.device) for x in batched_inputs]
119 | proposal_losses = {}
120 |
121 | try:
122 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone)
123 | except Exception as e:
124 | print(e)
125 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None)
126 |
127 | if self.vis_period > 0:
128 | storage = get_event_storage()
129 | if storage.iter % self.vis_period == 0:
130 | self.visualize_training(batched_inputs, proposals)
131 | with torch.no_grad():
132 | ogimage = batched_inputs[0]['image']
133 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format)
134 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image()
135 |
136 | vis_img = o_pred.transpose(2, 0, 1)
137 | storage.put_image('og-tfimage', vis_img)
138 |
139 | losses = {}
140 | losses.update(detector_losses)
141 | losses.update(proposal_losses)
142 | return losses
143 |
144 | def generate_colors(self,N):
145 | import colorsys
146 | '''
147 | Generate random colors.
148 | To get visually distinct colors, generate them in HSV space then
149 | convert to RGB.
150 | '''
151 | brightness = 0.7
152 | hsv = [(i / N, 1, brightness) for i in range(N)]
153 | colors = list(map(lambda c: tuple(round(i * 255) for i in colorsys.hsv_to_rgb(*c)), hsv))
154 | perm = np.arange(7)
155 | colors = [colors[idx] for idx in perm]
156 | return colors
157 |
158 |
159 | def inference(
160 | self,
161 | batched_inputs: List[Dict[str, torch.Tensor]],
162 | detected_instances: Optional[List[Instances]] = None,
163 | do_postprocess: bool = True,
164 | ):
165 | """
166 | Run inference on the given inputs.
167 | Args:
168 | batched_inputs (list[dict]): same as in :meth:`forward`
169 | detected_instances (None or list[Instances]): if not None, it
170 | contains an `Instances` object per image. The `Instances`
171 | object contains "pred_boxes" and "pred_classes" which are
172 | known boxes in the image.
173 | The inference will then skip the detection of bounding boxes,
174 | and only predict other per-ROI outputs.
175 | do_postprocess (bool): whether to apply post-processing on the outputs.
176 | Returns:
177 | When do_postprocess=True, same as in :meth:`forward`.
178 | Otherwise, a list[Instances] containing raw network outputs.
179 | """
180 | assert not self.training
181 |
182 | images = self.preprocess_image(batched_inputs)
183 | features = self.backbone(images.tensor)
184 |
185 |
186 | ###############save feat vis###################
187 | feat_vis = False
188 | if feat_vis:
189 | scale_size = 16
190 | out_dir_explain = os.path.join('./output', 'featmap_vis')
191 | explain_rpn_feat_i, _ = torch.max(features['res4'][0], 0)
192 | explain_rpn_feat_i = explain_rpn_feat_i.unsqueeze(0).unsqueeze(0)
193 |
194 | size = (explain_rpn_feat_i.shape[-2] * scale_size , explain_rpn_feat_i.shape[-1] *scale_size)
195 | visual = generate_visualization(explain_rpn_feat_i, images.tensor[0], scale_factor=scale_size, size=size)
196 |
197 | name_sp_list = batched_inputs[0]['file_name'].split('/')[-1].rsplit('.', 1)
198 | save_file_name = name_sp_list[0] + '.' + name_sp_list[1]
199 | explain_out_file = os.path.join(out_dir_explain, save_file_name)
200 | cv2.imwrite(explain_out_file, visual)
201 |
202 | ###############save feat vis ###################
203 |
204 | if detected_instances is None:
205 | if self.proposal_generator is not None:
206 | logits,proposals, _ = self.proposal_generator(images, features, None)
207 | else:
208 | assert "proposals" in batched_inputs[0]
209 | proposals = [x["proposals"].to(self.device) for x in batched_inputs]
210 |
211 |
212 | try:
213 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None, self.backbone)
214 | except:
215 | results, _ = self.roi_heads(images,self.day_text_embeds, features, proposals, None, None)
216 | else:
217 | detected_instances = [x.to(self.device) for x in detected_instances]
218 | results = self.roi_heads.forward_with_given_boxes(features, detected_instances)
219 |
220 |
221 | if do_postprocess:
222 | assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
223 |
224 | allresults = GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes)
225 |
226 |
227 | return allresults
228 | else:
229 | return results
230 |
231 |
232 | @META_ARCH_REGISTRY.register()
233 | class ClipRCNNWithClipBackboneTrainable(ClipRCNNWithClipBackbone):
234 | def __init__(self,cfg) -> None:
235 | super().__init__(cfg)
236 |
237 | def forward(self, batched_inputs):
238 |
239 | if not self.training:
240 | return self.inference(batched_inputs)
241 |
242 | images = self.preprocess_image(batched_inputs)
243 | b = images.tensor.shape[0]#batchsize
244 |
245 | if "instances" in batched_inputs[0]:
246 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
247 |
248 | features = self.backbone(images.tensor)
249 |
250 | if self.proposal_generator is not None:
251 | if self.training:
252 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
253 | else:
254 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
255 | else:
256 | assert "proposals" in batched_inputs[0]
257 | proposals = [x["proposals"].to(self.device) for x in batched_inputs]
258 | proposal_losses = {}
259 |
260 | _, detector_losses = self.roi_heads(images, self.day_text_embeds, features, proposals, gt_instances, None, self.backbone)
261 |
262 | losses = {}
263 | losses.update(detector_losses)
264 | losses.update(proposal_losses)
265 | return losses
266 |
267 |
268 |
--------------------------------------------------------------------------------
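Note: `show_cam_on_image` / `generate_visualization` above overlay a normalized feature map on the input image as a JET heatmap. A minimal sketch of the overlay step alone, with random arrays standing in for a real image and feature map (the output filename is a placeholder):

import cv2
import numpy as np

img = np.random.rand(224, 224, 3).astype(np.float32)   # image already scaled to [0, 1]
mask = np.random.rand(224, 224).astype(np.float32)     # feature map normalized to [0, 1]

heatmap = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)) / 255
cam = (heatmap + img) / np.max(heatmap + img)
cv2.imwrite("featmap_vis.jpg", np.uint8(255 * cam))
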
/OD/train.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 | import time
5 | from collections import OrderedDict, Counter
6 | import copy
7 |
8 | import numpy as np
9 |
10 | import torch
11 | from torch import autograd
12 | import torch.utils.data as torchdata
13 |
14 | from detectron2 import model_zoo
15 | from detectron2.config import get_cfg
16 | from detectron2.layers.batch_norm import FrozenBatchNorm2d
17 | from detectron2.engine import DefaultPredictor, DefaultTrainer, default_setup
18 | from detectron2.engine import default_argument_parser, hooks, HookBase
19 | from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping, build_lr_scheduler
20 | from detectron2.checkpoint import DetectionCheckpointer
21 | from detectron2.data import build_detection_train_loader, build_detection_test_loader, get_detection_dataset_dicts
22 | from detectron2.data.common import DatasetFromList, MapDataset
23 | from detectron2.data.samplers import InferenceSampler
24 | from detectron2.utils.events import get_event_storage
25 |
26 | from detectron2.utils import comm
27 | from detectron2.evaluation import COCOEvaluator, verify_results, inference_on_dataset, print_csv_format
28 |
29 | from detectron2.solver import LRMultiplier
30 | from detectron2.modeling import build_model
31 | from detectron2.structures import ImageList, Instances, pairwise_iou, Boxes
32 |
33 | from fvcore.common.param_scheduler import ParamScheduler
34 | from fvcore.nn.precise_bn import get_bn_modules  # needed by the PreciseBN hook in build_hooks
35 |
36 | from data.datasets import builtin
37 |
38 | from detectron2.evaluation import PascalVOCDetectionEvaluator, COCOEvaluator, inference_on_dataset
39 |
40 | from detectron2.data import build_detection_train_loader, MetadataCatalog
41 | import torch.utils.data as data
42 | from detectron2.data.dataset_mapper import DatasetMapper
43 | import detectron2.data.detection_utils as utils
44 | import detectron2.data.transforms as detT
45 |
46 | import torchvision.transforms as T
47 | import torchvision.transforms.functional as tF
48 |
49 | from modeling import add_stn_config
50 | from modeling import CustomPascalVOCDetectionEvaluator
51 | import ipdb
52 |
53 | logger = logging.getLogger("detectron2")
54 |
55 | def setup(args):
56 | cfg = get_cfg()
57 | add_stn_config(cfg)
58 | #hack to add base yaml
59 | cfg.merge_from_file(args.config_file)
60 | cfg.merge_from_file(model_zoo.get_config_file(cfg.BASE_YAML))
61 | cfg.merge_from_file(args.config_file)
62 | cfg.merge_from_list(args.opts)
63 | default_setup(cfg, args)
64 | return cfg
65 |
66 | class CustomDatasetMapper(DatasetMapper):
67 | def __init__(self,cfg,is_train) -> None:
68 | super().__init__(cfg,is_train)
69 | self.with_crops = cfg.INPUT.CLIP_WITH_IMG
70 | self.with_random_clip_crops = cfg.INPUT.CLIP_RANDOM_CROPS
71 | self.with_jitter = cfg.INPUT.IMAGE_JITTER
72 | self.cropfn = T.RandomCrop#T.RandomCrop([224,224])
73 | self.aug = T.ColorJitter(brightness=.5, hue=.3)
74 | self.crop_size = cfg.INPUT.RANDOM_CROP_SIZE
75 |
76 | def __call__(self,dataset_dict):
77 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below
78 | # USER: Write your own image loading if it's not from a file
79 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
80 | utils.check_image_size(dataset_dict, image)
81 |
82 | # USER: Remove if you don't do semantic/panoptic segmentation.
83 | if "sem_seg_file_name" in dataset_dict:
84 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
85 | else:
86 | sem_seg_gt = None
87 |
88 | aug_input = detT.AugInput(image, sem_seg=sem_seg_gt)
89 | transforms = self.augmentations(aug_input)
90 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg
91 |
92 | image_shape = image.shape[:2] # h, w
93 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
94 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
95 | # Therefore it's important to use torch.Tensor.
96 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
97 | if sem_seg_gt is not None:
98 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
99 |
100 | # USER: Remove if you don't use pre-computed proposals.
101 | # Most users would not need this feature.
102 | if self.proposal_topk is not None:
103 | utils.transform_proposals(
104 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
105 | )
106 |
107 | if not self.is_train:
108 | # USER: Modify this if you want to keep them for some reason.
109 | dataset_dict.pop("annotations", None)
110 | dataset_dict.pop("sem_seg_file_name", None)
111 | return dataset_dict
112 |
113 | if "annotations" in dataset_dict:
114 | self._transform_annotations(dataset_dict, transforms, image_shape)
115 |
116 | if self.with_jitter:
117 | dataset_dict["jitter_image"] = self.aug(dataset_dict["image"])
118 |
119 | if self.with_crops:
120 | bbox = dataset_dict['instances'].gt_boxes.tensor
121 | csx = (bbox[:,0] + bbox[:,2])*0.5
122 | csy = (bbox[:,1] + bbox[:,3])*0.5
123 | maxwh = torch.maximum(bbox[:,2]-bbox[:,0],bbox[:,3]-bbox[:,1])
124 | crops = list()
125 | gt_boxes = list()
126 | mean=[0.48145466, 0.4578275, 0.40821073]
127 | std=[0.26862954, 0.26130258, 0.27577711]
128 | for cx,cy,maxdim,label,box in zip(csx,csy,maxwh,dataset_dict['instances'].gt_classes, bbox):
129 |
130 | if int(maxdim) < 10:
131 | continue
132 | x0 = torch.maximum(cx-maxdim*0.5,torch.tensor(0))
133 | y0 = torch.maximum(cy-maxdim*0.5,torch.tensor(0))
134 | try:
135 | imcrop = T.functional.resized_crop(dataset_dict['image'],top=int(y0),left=int(x0),height=int(maxdim),width=int(maxdim),size=224)
136 | imcrop = imcrop.flip(0)/255 # bgr --> rgb for clip
137 | imcrop = T.functional.normalize(imcrop,mean,std)
138 | # print(x0,y0,x0+maxdim,y0+maxdim,dataset_dict['image'].shape)
139 | # print(imcrop.min(),imcrop.max() )
140 | gt_boxes.append(box.reshape(1,-1))
141 | except Exception as e:
142 | print(e)
143 | print('crops:',x0,y0,maxdim)
144 | exit()
145 | # crops.append((imcrop,label))
146 | crops.append(imcrop.unsqueeze(0))
147 |
148 | if len(crops) == 0:
149 | dataset_dict['crops'] = []
150 | else:
151 | dataset_dict['crops'] = [torch.cat(crops,0),Boxes(torch.cat(gt_boxes,0))]
152 |
153 | if self.with_random_clip_crops:
154 | crops = []
155 | rbboxs = []
156 |
157 | for i in range(15):
158 | p = self.cropfn.get_params(dataset_dict['image'],[self.crop_size,self.crop_size])
159 | c = tF.crop(dataset_dict['image'],*p)
160 | if self.crop_size != 224:
161 | c = tF.resize(img=c,size=224)
162 | crops.append(c)
163 | rbboxs.append(p)
164 |
165 | crops = torch.stack(crops)
166 | dataset_dict['randomcrops'] = crops
167 |
168 | #apply same crop bbox to the jittered image
169 | if self.with_jitter:
170 | jitter_crops = []
171 | for p in rbboxs:
172 | jc = tF.crop(dataset_dict['jitter_image'],*p)
173 | if self.crop_size != 224:
174 | jc = tF.resize(img=jc,size=224)
175 | jitter_crops.append(jc)
176 |
177 | jcrops = torch.stack(jitter_crops)
178 | dataset_dict['jitter_randomcrops'] = jcrops
179 | return dataset_dict
180 |
181 | class CombineLoaders(data.IterableDataset):
182 | def __init__(self,loaders):
183 | self.loaders = loaders
184 |
185 | def __iter__(self,):
186 | dd = iter(self.loaders[1])
187 | for d1 in self.loaders[0]:
188 | try:
189 | d2 = next(dd)
190 | except:
191 | dd=iter(self.loaders[1])
192 | d2 = next(dd)
193 |
194 | list_out_dict=[]
195 | for v1,v2 in zip(d1,d2):
196 | out_dict = {}
197 | for k in v1.keys():
198 | out_dict[k] = (v1[k],v2[k])
199 | list_out_dict.append(out_dict)
200 |
201 | yield list_out_dict
202 |
203 |
204 | class Trainer(DefaultTrainer):
205 |
206 | def __init__(self,cfg) -> None:
207 | super().__init__(cfg)
208 | self.teach_model = None
209 | self.off_opt_interval = np.arange(0,cfg.SOLVER.MAX_ITER,cfg.OFFSET_OPT_INTERVAL[0]).tolist()
210 | self.off_opt_iters = cfg.OFFSET_OPT_ITERS
211 |
212 | @classmethod
213 | def build_model(cls, cfg):
214 | """
215 | Returns:
216 | torch.nn.Module:
217 | It now calls :func:`detectron2.modeling.build_model`.
218 | Overwrite it if you'd like a different model.
219 | """
220 | model = build_model(cfg)
221 |
222 |
223 | logger = logging.getLogger(__name__)
224 | logger.info("Model:\n{}".format(model))
225 |
226 | return model
227 |
228 | @classmethod
229 | def build_train_loader(cls,cfg):
230 | original = cfg.DATASETS.TRAIN
231 | print(original)
232 | cfg.DATASETS.TRAIN=(original[0],)
233 | data_loader1 = build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True))
234 | return data_loader1
235 |
236 | @classmethod
237 | def build_evaluator(cls, cfg, dataset_name, output_folder=None):
238 | if output_folder is None:
239 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
240 | if MetadataCatalog.get(dataset_name).evaluator_type == 'pascal_voc':
241 | return CustomPascalVOCDetectionEvaluator(dataset_name)
242 | else:
243 | return COCOEvaluator(dataset_name, output_dir=output_folder)
244 |
245 | @classmethod
246 | def build_optimizer(cls,cfg,model):
247 |
248 | trainable = {'others':[],'offset':[]}
249 |
250 | for name,val in model.named_parameters():
251 | head = name.split('.')[0]
252 | #previously was setting all params to be true
253 | if val.requires_grad == True:
254 | print(name)
255 | if 'offset' in name:
256 | trainable['offset'].append(val)
257 | else:
258 | trainable['others'].append(val)
259 |
260 | optimizer1 = torch.optim.SGD(
261 | trainable['others'],
262 | cfg.SOLVER.BASE_LR,
263 | momentum=cfg.SOLVER.MOMENTUM,
264 | nesterov=cfg.SOLVER.NESTEROV,
265 | weight_decay=cfg.SOLVER.WEIGHT_DECAY,
266 | )
267 |
268 | optimizer2 = torch.optim.Adam(
269 | trainable['offset'],
270 | 0.01,
271 | )
272 | return (maybe_add_gradient_clipping(cfg, optimizer1),maybe_add_gradient_clipping(cfg, optimizer2))
273 |
274 |
275 | def run_step(self):
276 | """
277 | Implement the standard training logic described above.
278 | """
279 |
280 |
281 | assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
282 | start = time.perf_counter()
283 | """
284 | If you want to do something with the data, you can wrap the dataloader.
285 | """
286 | data = next(self._trainer._data_loader_iter)
287 | data_time = time.perf_counter() - start
288 |
289 | """
290 | If you want to do something with the losses, you can wrap the model.
291 | """
292 | data_s = data
293 |
294 | opt_phase = False
295 | loss_dict_s = self.model(data_s)
296 | loss_dict = {}
297 |
298 | loss = 0
299 | for k,v in loss_dict_s.items():
300 | loss += v
301 |
302 |
303 | """
304 | If you need to accumulate gradients or do something similar, you can
305 | wrap the optimizer with your custom `zero_grad()` method.
306 | """
307 | self.optimizer[0].zero_grad()
308 | self.optimizer[1].zero_grad()
309 |
310 | loss.backward()
311 |
312 | if not opt_phase:
313 | self.optimizer[0].step()
314 | else:
315 | self.optimizer[1].step()
316 |
317 | self.optimizer[0].zero_grad()
318 | self.optimizer[1].zero_grad()
319 |
320 | for k,v in loss_dict_s.items():
321 | loss_dict.update({k:v})
322 |
323 | # print(loss_di ct)
324 | self._trainer._write_metrics(loss_dict, data_time)
325 | """
326 | If you need gradient clipping/scaling or other processing, you can
327 | wrap the optimizer with your custom `step()` method. But it is
328 | suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
329 | """
330 |
331 | def build_hooks(self):
332 | """
333 | Build a list of default hooks, including timing, evaluation,
334 | checkpointing, lr scheduling, precise BN, writing events.
335 | Returns:
336 | list[HookBase]:
337 | """
338 | cfg = self.cfg.clone()
339 | cfg.defrost()
340 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
341 |
342 | ret = [
343 | hooks.IterationTimer(),
344 | LRScheduler(),
345 | hooks.PreciseBN(
346 | # Run at the same freq as (but before) evaluation.
347 | cfg.TEST.EVAL_PERIOD,
348 | self.model,
349 | # Build a new data loader to not affect training
350 | self.build_train_loader(cfg),
351 | cfg.TEST.PRECISE_BN.NUM_ITER,
352 | )
353 | if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
354 | else None,
355 | ]
356 |
357 | # Do PreciseBN before checkpointer, because it updates the model and need to
358 | # be saved by checkpointer.
359 | # This is not always the best: if checkpointing has a different frequency,
360 | # some checkpoints may have more precise statistics than others.
361 | if comm.is_main_process():
362 | ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
363 |
364 | def test_and_save_results():
365 | self._last_eval_results = self.test(self.cfg, self.model)
366 | return self._last_eval_results
367 |
368 | def do_test_st(flag):
369 | if flag == 'st':
370 | model = self.model
371 | else:
372 | print("Error in the flag")
373 |
374 | results = OrderedDict()
375 | for dataset_name in self.cfg.DATASETS.TEST:
376 | data_loader = build_detection_test_loader(self.cfg, dataset_name)
377 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name)
378 | results_i = inference_on_dataset(model, data_loader, evaluator)
379 | results[dataset_name] = results_i
380 | if comm.is_main_process():
381 | logger.info("Evaluation results for {} in csv format:".format(dataset_name))
382 | print_csv_format(results_i)
383 | storage = get_event_storage()
384 | storage.put_scalar(f'{dataset_name}_AP50', results_i['bbox']['AP50'],smoothing_hint=False)
385 | if len(results) == 1:
386 | results = list(results.values())[0]
387 | return results
388 |
389 |
390 | # Do evaluation after checkpointer, because then if it fails,
391 | # we can use the saved checkpoint to debug.
392 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
393 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_SAVE_PERIOD, lambda flag='st': do_test_st(flag)))
394 |
395 | if comm.is_main_process():
396 | # Here the default print/log frequency of each writer is used.
397 | # run writers in the end, so that evaluation metrics are written
398 | ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
399 | return ret
400 |
401 | @classmethod
402 | def build_lr_scheduler(cls, cfg, optimizer):
403 | """
404 | It now calls :func:`detectron2.solver.build_lr_scheduler`.
405 | Overwrite it if you'd like a different scheduler.
406 | """
407 |
408 | return build_lr_scheduler(cfg, optimizer[0])
409 |
410 | def state_dict(self):
411 | ret = super().state_dict()
412 | ret["optimizer1"] = self.optimizer[0].state_dict()
413 | ret["optimizer2"] = self.optimizer[1].state_dict()
414 | return ret
415 |
416 | def load_state_dict(self, state_dict):
417 | super().load_state_dict(state_dict)
418 | self.optimizer[0].load_state_dict(state_dict["optimizer1"])
419 | self.optimizer[1].load_state_dict(state_dict["optimizer2"])
420 |
421 |
422 |
423 | class LRScheduler(HookBase):
424 | """
425 | A hook which executes a torch builtin LR scheduler and summarizes the LR.
426 | It is executed after every iteration.
427 | """
428 |
429 | def __init__(self, optimizer=None, scheduler=None):
430 | """
431 | Args:
432 | optimizer (torch.optim.Optimizer):
433 | scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler):
434 | if a :class:`ParamScheduler` object, it defines the multiplier over the base LR
435 | in the optimizer.
436 | If any argument is not given, will try to obtain it from the trainer.
437 | """
438 | self._optimizer = optimizer
439 | self._scheduler = scheduler
440 |
441 | def before_train(self):
442 | self._optimizer = self._optimizer or self.trainer.optimizer
443 | if isinstance(self.scheduler, ParamScheduler):
444 | self._scheduler = LRMultiplier(
445 | self._optimizer,
446 | self.scheduler,
447 | self.trainer.max_iter,
448 | last_iter=self.trainer.iter - 1,
449 | )
450 | self._best_param_group_id1 = LRScheduler.get_best_param_group_id(self._optimizer[0])
451 | self._best_param_group_id2 = LRScheduler.get_best_param_group_id(self._optimizer[1])
452 |
453 |
454 | @staticmethod
455 | def get_best_param_group_id(optimizer):
456 | # NOTE: some heuristics on what LR to summarize
457 | # summarize the param group with most parameters
458 | largest_group = max(len(g["params"]) for g in optimizer.param_groups)
459 |
460 | if largest_group == 1:
461 | # If all groups have one parameter,
462 | # then find the most common initial LR, and use it for summary
463 | lr_count = Counter([g["lr"] for g in optimizer.param_groups])
464 | lr = lr_count.most_common()[0][0]
465 | for i, g in enumerate(optimizer.param_groups):
466 | if g["lr"] == lr:
467 | return i
468 | else:
469 | for i, g in enumerate(optimizer.param_groups):
470 | if len(g["params"]) == largest_group:
471 | return i
472 |
473 | def after_step(self):
474 | lr1 = self._optimizer[0].param_groups[self._best_param_group_id1]["lr"]
475 | self.trainer.storage.put_scalar("lr1", lr1, smoothing_hint=False)
476 |
477 | lr2 = self._optimizer[1].param_groups[self._best_param_group_id2]["lr"]
478 | self.trainer.storage.put_scalar("lr2", lr2, smoothing_hint=False)
479 |
480 | self.scheduler.step()
481 |
482 | @property
483 | def scheduler(self):
484 | return self._scheduler or self.trainer.scheduler
485 |
486 | def state_dict(self):
487 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler):
488 | return self.scheduler.state_dict()
489 | return {}
490 |
491 | def load_state_dict(self, state_dict):
492 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler):
493 | logger = logging.getLogger(__name__)
494 | logger.info("Loading scheduler from state_dict ...")
495 | self.scheduler.load_state_dict(state_dict)
496 |
497 | def custom_build_detection_test_loader(cfg,dataset_name,mapper=None):
498 |
499 | if isinstance(dataset_name, str):
500 | dataset_name = [dataset_name]
501 |
502 | dataset = get_detection_dataset_dicts(
503 | dataset_name,
504 | filter_empty=False,
505 | proposal_files=[
506 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name
507 | ]
508 | if cfg.MODEL.LOAD_PROPOSALS
509 | else None,
510 | )
511 | if mapper is None:
512 | mapper = DatasetMapper(cfg, False)
513 |
514 | if isinstance(dataset, list):
515 | dataset = DatasetFromList(dataset, copy=False)
516 | if mapper is not None:
517 | dataset = MapDataset(dataset, mapper)
518 |
519 | sampler = None
520 | if isinstance(dataset, torchdata.IterableDataset):
521 | assert sampler is None, "sampler must be None if dataset is IterableDataset"
522 | else:
523 | if sampler is None:
524 | sampler = InferenceSampler(len(dataset))
525 | collate_fn = None
526 |
527 | def trivial_batch_collator(batch):
528 | """
529 | A batch collator that does nothing.
530 | """
531 | return batch
532 |
533 | return torchdata.DataLoader(
534 | dataset,
535 | batch_size=1,
536 | sampler=sampler,
537 | drop_last=False,
538 | num_workers=cfg.DATALOADER.NUM_WORKERS,
539 | collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
540 | )
541 |
542 |
543 | def do_test(cfg, model, model_type=''):
544 | results = OrderedDict()
545 | for dataset_name in cfg.DATASETS.TEST:
546 | data_loader = build_detection_test_loader(cfg, dataset_name)
547 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name)
548 | results_i = inference_on_dataset(model, data_loader, evaluator)
549 | results[dataset_name] = results_i
550 | if comm.is_main_process():
551 | logger.info("Evaluation results for {} in csv format:".format(dataset_name))
552 | print_csv_format(results_i)
553 |
554 | if len(results) == 1:
555 | results = list(results.values())[0]
556 | return results
557 |
558 | def main(args):
559 | cfg = setup(args)
560 | if args.eval_only:
561 | model = Trainer.build_model(cfg)
562 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
563 | cfg.MODEL.WEIGHTS, resume=args.resume
564 | )
565 | return do_test(cfg,model)
566 | trainer = Trainer(cfg)
567 | trainer.resume_or_load(resume=args.resume)
568 | for dataset_name in cfg.DATASETS.TEST:
569 |         if 'daytime_clear_test' in dataset_name:
570 | trainer.register_hooks([
571 | hooks.BestCheckpointer(cfg.TEST.EVAL_SAVE_PERIOD,trainer.checkpointer,f'{dataset_name}_AP50',file_prefix='daytime_clear_model_best'),
572 | ])
573 |
574 | trainer.train()
575 |
576 |
577 | if __name__ == "__main__":
578 | args = default_argument_parser().parse_args()
579 | cfg = setup(args)
580 | print("Command Line Args:", args)
581 | main(args)
582 |
--------------------------------------------------------------------------------
/OD/CLIP/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 | import math
9 | 
10 | from .mask import Mask_s, Mask_c
11 |
12 |
13 | def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False):
14 | if ceil_mode:
15 | return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))
16 | else:
17 | return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))
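    | # e.g. conv2d_out_dim(28, 3, padding=1, stride=2) == 14: a 3x3, stride-2, padding-1
    | # convolution halves a 28x28 map, matching the 14*2 spatial size passed to layer4 below.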
18 |
19 |
20 |
21 | class Bottleneck(nn.Module):
22 | expansion = 4
23 |
24 | def __init__(self, inplanes, planes, h=0, w=0, eta=8, stride=1, base_width=64, dilation=1):
25 |
26 | super().__init__()
27 |
28 |
29 | self.height_1, self.width_1 = h, w
30 |         if self.height_1 > 0:
31 | self.dynamic = True
32 | else:
33 | self.dynamic = False
34 |
35 | if self.dynamic:
36 | width = int(planes * (base_width / 64.)) # * groups
37 | # spatial gating module
38 |
39 |             self.height_2 = conv2d_out_dim(h, kernel_size=3, padding=dilation, stride=2, dilation=dilation)
40 |             self.width_2 = conv2d_out_dim(w, kernel_size=3, padding=dilation, stride=2, dilation=dilation)
41 | self.mask_s = Mask_s(self.height_2, self.width_2, inplanes, eta, eta)
42 | self.upsample_1 = nn.Upsample(size=(self.height_2, self.width_2), mode='nearest')
43 | self.upsample_2 = nn.Upsample(size=(self.height_2, self.width_2), mode='nearest')
44 |
45 | self.mask_c1 = Mask_c(inplanes, width)
46 | self.mask_c2 = Mask_c(width, width)
47 |
48 |
49 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
50 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
51 | self.bn1 = nn.BatchNorm2d(planes)
52 | self.relu1 = nn.ReLU(inplace=True)
53 |
54 | # conv 2
55 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
56 | self.bn2 = nn.BatchNorm2d(planes)
57 | self.relu2 = nn.ReLU(inplace=True)
58 |
59 |
60 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61 |
62 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64 | self.relu3 = nn.ReLU(inplace=True)
65 |
66 | self.downsample = None
67 | self.stride = stride
68 |
69 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
70 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71 | self.downsample = nn.Sequential(OrderedDict([
72 | ("-1", nn.AvgPool2d(stride)),
73 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
74 | ("1", nn.BatchNorm2d(planes * self.expansion))
75 | ]))
76 |
77 | if self.dynamic:
78 | self.inplanes, self.width, self.planes = inplanes, width, planes * self.expansion
79 |
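    |             # dense (ungated) per-image FLOP counts of conv1-3 and the downsample path;
    |             # the gated cost is recomputed on each forward pass in get_flops() from the mask occupancies.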
80 | flops_conv1_full = torch.Tensor([self.height_1 * self.width_1 * width * inplanes])
81 | flops_conv2_full = torch.Tensor([9 * self.height_2 * self.width_2 * width * width])
82 |             flops_conv3_full = torch.Tensor([self.height_2 * self.width_2 * width * planes * self.expansion])
83 |             self.flops_downsample = (torch.Tensor([self.height_2 * self.width_2 * planes * self.expansion * inplanes])
84 |                                      if self.downsample is not None else torch.Tensor([0]))
85 |             self.flops_full = flops_conv1_full + flops_conv2_full + flops_conv3_full + self.flops_downsample
86 | # mask flops
87 | flops_mask_s = self.mask_s.get_flops()
88 | flops_mask_c1 = self.mask_c1.get_flops()
89 | flops_mask_c2 = self.mask_c2.get_flops()
90 | self.flops_mask = torch.Tensor([flops_mask_s + flops_mask_c1 + flops_mask_c2])
91 |
92 | def forward(self, input: torch.Tensor):
93 | if self.dynamic:
94 | x, txt_emb, norm_1, norm_2, flops = input
95 | # spatial mask
96 | mask_s_m, norm_s, norm_s_t = self.mask_s(x) # [N, 1, h, w]
97 | mask_c1, norm_c1, norm_c1_t = self.mask_c1(x, txt_emb)
98 | mask_s1 = self.upsample_1(mask_s_m) # [N, 1, H1, W1]
99 | mask_s = self.upsample_2(mask_s_m) # [N, 1, H2, W2]
100 |
101 | else:
102 | x = input
103 | identity = x
104 |
105 | out = self.relu1(self.bn1(self.conv1(x)))
106 |
107 | if self.dynamic:
108 |             out = out * mask_c1 * mask_s1  # gate the conv1 output along channels (mask_c1) and spatial positions (mask_s1)
109 | mask_c2, norm_c2, norm_c2_t = self.mask_c2(out, txt_emb)
110 |
111 | out = self.relu2(self.bn2(self.conv2(out)))
112 |
113 | if self.dynamic:
114 |             out = out * mask_c2 * mask_s  # gate the conv2 output along channels (mask_c2) and spatial positions (mask_s)
115 |
116 | out = self.avgpool(out)
117 | out = self.bn3(self.conv3(out))
118 |
119 | if self.downsample is not None:
120 | identity = self.downsample(x)
121 |
122 | out += identity
123 | out = self.relu3(out)
124 |
125 | if self.dynamic:
126 | # norm
127 | norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0)))
128 | norm_2 = torch.cat((norm_2, torch.cat((norm_c1, norm_c1_t)).unsqueeze(0)))
129 | norm_2 = torch.cat((norm_2, torch.cat((norm_c2, norm_c2_t)).unsqueeze(0)))
130 | flops_blk = self.get_flops(mask_s, mask_s1, mask_c1, mask_c2)
131 | flops = torch.cat((flops, flops_blk.unsqueeze(0)))
132 |
133 | return (out, txt_emb, norm_1, norm_2, flops)
134 | return out
135 |
136 | def get_flops(self, mask_s, mask_s1, mask_c1, mask_c2):
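    |         # per-sample FLOPs of conv1-3 under the current spatial/channel gates, concatenated
    |         # with the (constant) gating-module overhead and the dense upper bound.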
137 | s_sum = mask_s.sum((1,2,3))
138 | c1_sum, c2_sum = mask_c1.sum((1,2,3)), mask_c2.sum((1,2,3))
139 | # conv
140 | s_sum_1 = mask_s1.sum((1,2,3))
141 | flops_conv1 = s_sum_1 * c1_sum * self.inplanes
142 | flops_conv2 = 9 * s_sum * c2_sum * c1_sum
143 | flops_conv3 = s_sum * self.planes * c2_sum
144 | # total
145 | flops = flops_conv1+flops_conv2+flops_conv3+self.flops_downsample.to(flops_conv1.device)
146 | return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device)))
147 |
148 |
149 | class AttentionPool2d(nn.Module):
150 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
151 | super().__init__()
152 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
153 | self.k_proj = nn.Linear(embed_dim, embed_dim)
154 | self.q_proj = nn.Linear(embed_dim, embed_dim)
155 | self.v_proj = nn.Linear(embed_dim, embed_dim)
156 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
157 | self.num_heads = num_heads
158 |
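    |     # attention pooling: the mean over all positions is used as the single query, so the
    |     # layer returns one output_dim vector per image instead of an HxW feature map.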
159 | def forward(self, x):
160 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
161 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
162 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
163 | x, _ = F.multi_head_attention_forward(
164 | query=x[:1], key=x, value=x,
165 | embed_dim_to_check=x.shape[-1],
166 | num_heads=self.num_heads,
167 | q_proj_weight=self.q_proj.weight,
168 | k_proj_weight=self.k_proj.weight,
169 | v_proj_weight=self.v_proj.weight,
170 | in_proj_weight=None,
171 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
172 | bias_k=None,
173 | bias_v=None,
174 | add_zero_attn=False,
175 | dropout_p=0,
176 | out_proj_weight=self.c_proj.weight,
177 | out_proj_bias=self.c_proj.bias,
178 | use_separate_proj_weight=True,
179 | training=self.training,
180 | need_weights=False
181 | )
182 | return x.squeeze(0)
183 |
184 |
185 | class ModifiedResNet(nn.Module):
186 | """
187 | A ResNet class that is similar to torchvision's but contains the following changes:
188 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
189 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
190 | - The final pooling layer is a QKV attention instead of an average pool
191 | """
192 |
193 | def __init__(self, layers, output_dim, heads, h=75, w=133, input_resolution=224, width=64):
194 | super().__init__()
195 |
196 | self.height, self.width = h, w
197 |
198 | self.output_dim = output_dim
199 | self.input_resolution = input_resolution
200 |
201 | # the 3-layer stem
202 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
203 | self.bn1 = nn.BatchNorm2d(width // 2)
204 | self.relu1 = nn.ReLU(inplace=True)
205 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
206 | self.bn2 = nn.BatchNorm2d(width // 2)
207 | self.relu2 = nn.ReLU(inplace=True)
208 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
209 | self.bn3 = nn.BatchNorm2d(width)
210 | self.relu3 = nn.ReLU(inplace=True)
211 | self.avgpool = nn.AvgPool2d(2)
212 |
213 |
214 |
215 | # residual layers
216 | self._inplanes = width # this is a *mutable* variable used during construction
217 | self.layer1 = self._make_layer(width, layers[0])
218 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
219 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
220 |         self.layer4, h, w = self._make_layer_dynamic(width * 8, layers[3], 14*2, 14*2, stride=2)  # 14*2 = 28: the spatial size the layer-4 gating masks are sized for
221 |
222 | embed_dim = width * 32 # the ResNet feature dimension
223 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
224 |
225 | def _make_layer(self, planes, blocks, stride=1):
226 |
227 |         layers = [Bottleneck(self._inplanes, planes, h=0, w=0, eta=8, stride=stride)]
228 | self._inplanes = planes * Bottleneck.expansion
229 | for _ in range(1, blocks):
230 | layers.append(Bottleneck(self._inplanes, planes))
231 | return nn.Sequential(*layers)
232 |
233 | def _make_layer_dynamic(self, planes, blocks, h, w, stride=1):
234 |
235 | layers = [Bottleneck(self._inplanes, planes, h, w, 8, stride)]
236 |
237 | h = conv2d_out_dim(h, kernel_size=1, stride=stride, padding=0)
238 | w = conv2d_out_dim(w, kernel_size=1, stride=stride, padding=0)
239 | self._inplanes = planes * Bottleneck.expansion
240 | for _ in range(1, blocks):
241 | layers.append(Bottleneck(self._inplanes, planes, h, w))
242 |
243 | return nn.Sequential(*layers), h, w
244 |
245 |
246 |
247 | def forward(self, x):
248 | def stem(x):
249 | x = self.relu1(self.bn1(self.conv1(x)))
250 | x = self.relu2(self.bn2(self.conv2(x)))
251 | x = self.relu3(self.bn3(self.conv3(x)))
252 | x = self.avgpool(x)
253 | return x
254 |
255 | x = x.type(self.conv1.weight.dtype)
256 | x = stem(x)
257 | x = self.layer1(x)
258 | x = self.layer2(x)
259 | x = self.layer3(x)
260 | x = self.layer4(x)
261 | x = self.attnpool(x)
262 |
263 | return x
264 |
265 |
266 | class LayerNorm(nn.LayerNorm):
267 | """Subclass torch's LayerNorm to handle fp16."""
268 |
269 | def forward(self, x: torch.Tensor):
270 | orig_type = x.dtype
271 | ret = super().forward(x.type(torch.float32))
272 | return ret.type(orig_type)
273 |
274 |
275 | class QuickGELU(nn.Module):
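    |     # fast GELU approximation, x * sigmoid(1.702 * x), which the pretrained CLIP weights were trained with.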
276 | def forward(self, x: torch.Tensor):
277 | return x * torch.sigmoid(1.702 * x)
278 |
279 |
280 | class ResidualAttentionBlock(nn.Module):
281 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
282 | super().__init__()
283 |
284 | self.attn = nn.MultiheadAttention(d_model, n_head)
285 | self.ln_1 = LayerNorm(d_model)
286 | self.mlp = nn.Sequential(OrderedDict([
287 | ("c_fc", nn.Linear(d_model, d_model * 4)),
288 | ("gelu", QuickGELU()),
289 | ("c_proj", nn.Linear(d_model * 4, d_model))
290 | ]))
291 | self.ln_2 = LayerNorm(d_model)
292 | self.attn_mask = attn_mask
293 |
294 | def attention(self, x: torch.Tensor):
295 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
296 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
297 |
298 | def forward(self, x: torch.Tensor):
299 | x = x + self.attention(self.ln_1(x))
300 | x = x + self.mlp(self.ln_2(x))
301 | return x
302 |
303 |
304 | class Transformer(nn.Module):
305 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
306 | super().__init__()
307 | self.width = width
308 | self.layers = layers
309 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
310 |
311 | def forward(self, x: torch.Tensor):
312 | return self.resblocks(x)
313 |
314 |
315 | class VisionTransformer(nn.Module):
316 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
317 | super().__init__()
318 | self.input_resolution = input_resolution
319 | self.output_dim = output_dim
320 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
321 |
322 | scale = width ** -0.5
323 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
324 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
325 | self.ln_pre = LayerNorm(width)
326 |
327 | self.transformer = Transformer(width, layers, heads)
328 |
329 | self.ln_post = LayerNorm(width)
330 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
331 |
332 | def forward(self, x: torch.Tensor):
333 | x = self.conv1(x) # shape = [*, width, grid, grid]
334 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
335 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
336 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
337 | x = x + self.positional_embedding.to(x.dtype)
338 | x = self.ln_pre(x)
339 |
340 | x = x.permute(1, 0, 2) # NLD -> LND
341 | x = self.transformer(x)
342 | x = x.permute(1, 0, 2) # LND -> NLD
343 |
344 | x = self.ln_post(x[:, 0, :])
345 |
346 | if self.proj is not None:
347 | x = x @ self.proj
348 |
349 | return x
350 |
351 |
352 | class CLIP(nn.Module):
353 | def __init__(self,
354 | embed_dim: int,
355 | # vision
356 | image_resolution: int,
357 | vision_layers: Union[Tuple[int, int, int, int], int],
358 | vision_width: int,
359 | vision_patch_size: int,
360 | # text
361 | context_length: int,
362 | vocab_size: int,
363 | transformer_width: int,
364 | transformer_heads: int,
365 | transformer_layers: int
366 | ):
367 | super().__init__()
368 |
369 | self.context_length = context_length
370 |
371 | if isinstance(vision_layers, (tuple, list)):
372 | vision_heads = vision_width * 32 // 64
373 | self.visual = ModifiedResNet(
374 | layers=vision_layers,
375 | output_dim=embed_dim,
376 | heads=vision_heads,
377 | input_resolution=image_resolution,
378 | width=vision_width
379 | )
380 | else:
381 | vision_heads = vision_width // 64
382 | self.visual = VisionTransformer(
383 | input_resolution=image_resolution,
384 | patch_size=vision_patch_size,
385 | width=vision_width,
386 | layers=vision_layers,
387 | heads=vision_heads,
388 | output_dim=embed_dim
389 | )
390 |
391 | self.transformer = Transformer(
392 | width=transformer_width,
393 | layers=transformer_layers,
394 | heads=transformer_heads,
395 | attn_mask=self.build_attention_mask()
396 | )
397 |
398 | self.vocab_size = vocab_size
399 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
400 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
401 | self.ln_final = LayerNorm(transformer_width)
402 |
403 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
404 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
405 |
406 | self.initialize_parameters()
407 |
408 | def initialize_parameters(self):
409 | nn.init.normal_(self.token_embedding.weight, std=0.02)
410 | nn.init.normal_(self.positional_embedding, std=0.01)
411 |
412 | if isinstance(self.visual, ModifiedResNet):
413 | if self.visual.attnpool is not None:
414 | std = self.visual.attnpool.c_proj.in_features ** -0.5
415 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
416 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
417 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
418 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
419 |
420 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
421 | for name, param in resnet_block.named_parameters():
422 | if name.endswith("bn3.weight"):
423 | nn.init.zeros_(param)
424 |
425 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
426 | attn_std = self.transformer.width ** -0.5
427 | fc_std = (2 * self.transformer.width) ** -0.5
428 | for block in self.transformer.resblocks:
429 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
430 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
431 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
432 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
433 |
434 | if self.text_projection is not None:
435 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
436 |
437 | def build_attention_mask(self):
438 | # lazily create causal attention mask, with full attention between the vision tokens
439 | # pytorch uses additive attention mask; fill with -inf
440 | mask = torch.empty(self.context_length, self.context_length)
441 | mask.fill_(float("-inf"))
442 |         mask.triu_(1)  # keep -inf strictly above the diagonal; zero the diagonal and below (causal)
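    |         # e.g. for context_length == 3 the additive mask is
    |         #   [[0., -inf, -inf],
    |         #    [0.,   0., -inf],
    |         #    [0.,   0.,   0.]]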
443 | return mask
444 |
445 | @property
446 | def dtype(self):
447 | return self.visual.conv1.weight.dtype
448 |
449 | def encode_image(self, image):
450 | return self.visual(image.type(self.dtype))
451 |
452 | def encode_text(self, text):
453 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
454 |
455 | x = x + self.positional_embedding.type(self.dtype)
456 | x = x.permute(1, 0, 2) # NLD -> LND
457 | x = self.transformer(x)
458 | x = x.permute(1, 0, 2) # LND -> NLD
459 | x = self.ln_final(x).type(self.dtype)
460 |
461 | # take features from the eot embedding (eot_token is the highest number in each sequence)
462 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
463 |
464 | return x
465 |
466 | def forward(self, image, text):
467 | image_features = self.encode_image(image)
468 | text_features = self.encode_text(text)
469 |
470 | # normalized features
471 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
472 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
473 |
474 | # cosine similarity as logits
475 | logit_scale = self.logit_scale.exp()
476 | logits_per_image = logit_scale * image_features @ text_features.t()
477 | logits_per_text = logits_per_image.t()
478 |
479 | # shape = [global_batch_size, global_batch_size]
480 | return logits_per_image, logits_per_text
481 |
482 |
483 | def convert_weights(model: nn.Module):
484 | """Convert applicable model parameters to fp16"""
485 |
486 | def _convert_weights_to_fp16(l):
487 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
488 | l.weight.data = l.weight.data.half()
489 | if l.bias is not None:
490 | l.bias.data = l.bias.data.half()
491 |
492 | if isinstance(l, nn.MultiheadAttention):
493 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
494 | tensor = getattr(l, attr)
495 | if tensor is not None:
496 | tensor.data = tensor.data.half()
497 |
498 | for name in ["text_projection", "proj"]:
499 | if hasattr(l, name):
500 | attr = getattr(l, name)
501 | if attr is not None:
502 | attr.data = attr.data.half()
503 |
504 | model.apply(_convert_weights_to_fp16)
505 |
506 |
507 | def build_model(state_dict: dict):
508 | vit = "visual.proj" in state_dict
509 |
510 | if vit:
511 | vision_width = state_dict["visual.conv1.weight"].shape[0]
512 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
513 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
514 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
515 | image_resolution = vision_patch_size * grid_size
516 | else:
517 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
518 | vision_layers = tuple(counts)
519 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
520 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
521 | vision_patch_size = None
522 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
523 | image_resolution = output_width * 32
524 |
525 | embed_dim = state_dict["text_projection"].shape[1]
526 | context_length = state_dict["positional_embedding"].shape[0]
527 | vocab_size = state_dict["token_embedding.weight"].shape[0]
528 | transformer_width = state_dict["ln_final.weight"].shape[0]
529 | transformer_heads = transformer_width // 64
530 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
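    |     # for reference (assuming the released CLIP RN50 checkpoint): embed_dim=1024,
    |     # vision_layers=(3, 4, 6, 3), vision_width=64, image_resolution=224, context_length=77,
    |     # vocab_size=49408, transformer_width=512, transformer_heads=8, transformer_layers=12.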
531 |
532 | model = CLIP(
533 | embed_dim,
534 | image_resolution, vision_layers, vision_width, vision_patch_size,
535 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
536 | )
537 |
538 | for key in ["input_resolution", "context_length", "vocab_size"]:
539 | if key in state_dict:
540 | del state_dict[key]
541 |
542 | convert_weights(model)
543 | model.load_state_dict(state_dict, strict=False)
544 | return model.eval()
545 |
--------------------------------------------------------------------------------