├── modeling
│   ├── __pycache__
│   │   ├── sbmnist_rcnn.cpython-36.pyc.140601517918224
│   │   ├── clip.cpython-38.pyc
│   │   ├── dis.cpython-36.pyc
│   │   ├── dis.cpython-38.pyc
│   │   ├── rpn.cpython-38.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── config.cpython-38.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── attention.cpython-38.pyc
│   │   ├── backbone.cpython-38.pyc
│   │   ├── meta_arch.cpython-38.pyc
│   │   ├── roi_head.cpython-38.pyc
│   │   ├── stn_arch.cpython-36.pyc
│   │   ├── stn_arch.cpython-38.pyc
│   │   ├── tps_arch.cpython-36.pyc
│   │   ├── tps_arch.cpython-38.pyc
│   │   ├── sbmnist_rpn.cpython-36.pyc
│   │   ├── sbmnist_rpn.cpython-38.pyc
│   │   ├── box_predictor.cpython-38.pyc
│   │   ├── sbmnist_rcnn.cpython-36.pyc
│   │   ├── sbmnist_rcnn.cpython-38.pyc
│   │   └── custom_pascal_evaluation.cpython-38.pyc
│   ├── __init__.py
│   ├── config.py
│   ├── box_predictor.py
│   ├── clip.py
│   ├── backbone.py
│   ├── custom_pascal_evaluation.py
│   ├── rpn.py
│   ├── roi_head.py
│   └── meta_arch.py
├── data
│   └── datasets
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── builtin.cpython-38.pyc
│       │   ├── diverse_weather.cpython-38.pyc
│       │   ├── pascal_voc_adaptation.cpython-38.pyc
│       │   └── comic_water_adaptation.cpython-38.pyc
│       ├── __init__.py
│       ├── builtin.py
│       ├── pascal_voc_adaptation.py
│       ├── comic_water_adaptation.py
│       └── diverse_weather.py
├── requirements.txt
├── prunedprompts2.txt
├── configs
│   ├── diverse_weather.yaml
│   └── comic_watercolor.yaml
├── README.md
├── train.py
└── train_voc.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.21.5
2 | torch==1.10.1
3 | torchvision==0.11.2
4 | detectron2==0.6
5 | kornia==0.6.3
6 | git+https://github.com/openai/CLIP.git
7 | pymage_size
--------------------------------------------------------------------------------
/data/datasets/__init__.py:
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from . import builtin # ensure the builtin datasets are registered 4 | 5 | __all__ = [k for k in globals().keys() if "builtin" not in k and not k.startswith("_")] 6 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .rpn import SBRPN 2 | from .backbone import ClipRN101 3 | from .meta_arch import ClipRCNNWithClipBackbone 4 | from .roi_head import ClipRes5ROIHeads 5 | from .config import add_stn_config 6 | from .custom_pascal_evaluation import CustomPascalVOCDetectionEvaluator -------------------------------------------------------------------------------- /prunedprompts2.txt: -------------------------------------------------------------------------------- 1 | an image taken on a snow night 2 | an image taken on a fog night 3 | an image taken on a cloudy night 4 | an image taken on a rain night 5 | an image taken on a stormy night 6 | an image taken on a snow day 7 | an image taken on a fog day 8 | an image taken on a cloudy day 9 | an image taken on a rain day 10 | an image taken on a stormy day 11 | an image taken on a snow evening 12 | an image taken on a fog evening 13 | an image taken on a cloudy evening 14 | an image taken on a rain evening 15 | an image taken on a stormy evening 16 | -------------------------------------------------------------------------------- /data/datasets/builtin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | from .diverse_weather import register_dataset as register_diverse_weather 4 | from .pascal_voc_adaptation import register_all_pascal_voc as register_pascal_voc 5 | from .comic_water_adaptation import register_dataset as register_comic_water 6 | import os 7 | 8 | _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) 9 | DEFAULT_DATASETS_ROOT = "data/" 10 | 11 | 12 | register_diverse_weather(_root) 13 | register_pascal_voc(_root) 14 | register_comic_water(_root) 15 | -------------------------------------------------------------------------------- /modeling/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | def add_stn_config(cfg): 4 | cfg.OFFSET_DOMAIN = '' 5 | cfg.OFFSET_FROZENBN = False 6 | cfg.OFFSET_DOMAIN_TEXT = '' 7 | cfg.OFFSET_NAME = 0 8 | cfg.OFFSET_OPT_INTERVAL = [10] 9 | cfg.OFFSET_OPT_ITERS = 0 10 | cfg.AUG_PROB = 0.5 11 | cfg.DOMAIN_NAME = "" 12 | cfg.TEST.EVAL_SAVE_PERIOD = 5000 13 | cfg.INPUT.CLIP_WITH_IMG = False 14 | cfg.INPUT.CLIP_RANDOM_CROPS = False 15 | cfg.INPUT.IMAGE_JITTER = False 16 | cfg.INPUT.RANDOM_CROP_SIZE = 224 17 | cfg.MODEL.GLOBAL_GND = False 18 | cfg.BASE_YAML = "COCO-Detection/faster_rcnn_R_50_C4_3x.yaml" 19 | cfg.MODEL.RENAME = list() 20 | cfg.MODEL.CLIP_IMAGE_ENCODER_NAME = 'ViT-B/32' 21 | cfg.MODEL.BACKBONE.UNFREEZE = ['layer3','layer4','attnpool'] 22 | cfg.MODEL.USE_PROJ = True 23 | -------------------------------------------------------------------------------- /modeling/box_predictor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | import torch 3 | 4 | from detectron2.layers import cat, cross_entropy 5 | from 
detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers 6 | from .clip import ClipPredictor 7 | 8 | class ClipFastRCNNOutputLayers(FastRCNNOutputLayers): 9 | 10 | def __init__(self,cfg, input_shape, clsnames) -> None: 11 | super().__init__(cfg, input_shape) 12 | self.cls_score = ClipPredictor(cfg.MODEL.CLIP_IMAGE_ENCODER_NAME, input_shape.channels, cfg.MODEL.DEVICE,clsnames) 13 | # self.proj = torch.nn.Linear(512,2048) 14 | def forward(self,x,gfeat=None): 15 | # if x.dim() > 2: 16 | # x = torch.flatten(x, start_dim=1) 17 | 18 | ## for features from clip model 19 | if isinstance(x,list): 20 | scores = self.cls_score(x[0],gfeat) 21 | proposal_deltas = self.bbox_pred(x[1])#self.bbox_pred(self.proj(x[0]/x[0].norm(dim=-1,keepdim=True))) 22 | else: 23 | scores = self.cls_score(x,gfeat) 24 | proposal_deltas = self.bbox_pred(x) 25 | 26 | return scores, proposal_deltas 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /configs/diverse_weather.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ("daytime_clear_train",) 4 | TEST: ('daytime_clear_test',) 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | INPUT: 8 | MIN_SIZE_TRAIN: (600,) 9 | MIN_SIZE_TEST: 600 10 | CLIP_RANDOM_CROPS: True 11 | RANDOM_CROP_SIZE: 400 12 | 13 | SOLVER: 14 | BASE_LR: 0.001 15 | MAX_ITER: 200000 16 | STEPS: [40000,] 17 | WARMUP_ITERS: 0 18 | IMS_PER_BATCH: 4 19 | CHECKPOINT_PERIOD: 1000000 20 | MODEL: 21 | BACKBONE: 22 | NAME: ClipRN101 23 | WEIGHTS: "" 24 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneWithOffsetGenTrainable' 26 | 27 | PROPOSAL_GENERATOR: 28 | NAME: 'SBRPN' 29 | ROI_HEADS: 30 | NAME: 'ClipRes5ROIHeadsAttn' 31 | # BATCH_SIZE_PER_IMAGE: 128 # faster, and good enough for this toy dataset (default: 512) 32 | NUM_CLASSES: 7 33 | TEST: 34 | EVAL_SAVE_PERIOD: 5000 35 | OUTPUT_DIR: "all_outs/diverse_weather" 36 | VIS_PERIOD: 5000 37 | OFFSET_OPT_INTERVAL: [20000000] 38 | OFFSET_OPT_ITERS: 1000 39 | -------------------------------------------------------------------------------- /configs/comic_watercolor.yaml: -------------------------------------------------------------------------------- 1 | BASE_YAML: "COCO-Detection/faster_rcnn_R_101_C4_3x.yaml" 2 | DATASETS: 3 | TRAIN: ('voc_adapt_2007_trainval',"voc_adapt_2012_train",) 4 | TEST: ('voc_adapt_2012_val','voc_adapt_2007_test','comic_test','watercolor_test') 5 | DATALOADER: 6 | NUM_WORKERS: 16 7 | INPUT: 8 | MIN_SIZE_TRAIN: (600,) 9 | MIN_SIZE_TEST: 600 10 | CLIP_RANDOM_CROPS: True 11 | RANDOM_CROP_SIZE: 400 12 | 13 | SOLVER: 14 | BASE_LR: 0.0001 15 | MAX_ITER: 100000 16 | STEPS: [] #[10000,] 17 | WARMUP_ITERS: 0 18 | IMS_PER_BATCH: 4 19 | CHECKPOINT_PERIOD: 1000000 20 | MODEL: 21 | BACKBONE: 22 | NAME: ClipRN101 23 | WEIGHTS: "" 24 | CLIP_IMAGE_ENCODER_NAME: 'RN101' 25 | META_ARCHITECTURE: 'ClipRCNNWithClipBackboneWithOffsetGenTrainableVOC' 26 | 27 | PROPOSAL_GENERATOR: 28 | NAME: 'SBRPN' 29 | ROI_HEADS: 30 | NAME: 'ClipRes5ROIHeadsAttn' 31 | # BATCH_SIZE_PER_IMAGE: 128 # faster, and good enough for this toy dataset (default: 512) 32 | NUM_CLASSES: 6 33 | TEST: 34 | EVAL_SAVE_PERIOD: 5000 35 | OUTPUT_DIR: "all_outs/comic_watercolor" 36 | VIS_PERIOD: 5000 37 | OFFSET_OPT_INTERVAL: [2000000] 38 | OFFSET_OPT_ITERS: 1000 -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # CLIP the Gap: A Single Domain Generalization Approach for Object Detection 2 | 3 | [ [Paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Vidit_CLIP_the_Gap_A_Single_Domain_Generalization_Approach_for_Object_CVPR_2023_paper.pdf) ] 4 | 5 | ### Installation 6 | Our code is based on [Detectron2](https://github.com/facebookresearch/detectron2) and requires python >= 3.6 7 | 8 | Install the required packages 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ### Datasets 14 | Set the environment variable DETECTRON2_DATASETS to the parent folder of the datasets 15 | 16 | ``` 17 | path-to-parent-dir/ 18 | /diverseWeather 19 | /daytime_clear 20 | /daytime_foggy 21 | ... 22 | /comic 23 | /watercolor 24 | /VOC2007 25 | /VOC2012 26 | 27 | ``` 28 | Download [Diverse Weather](https://github.com/AmingWu/Single-DGOD) and [Cross-Domain](https://naoto0804.github.io/cross_domain_detection/) Datasets and place in the structure as shown. 29 | 30 | ### Training 31 | We train our models on a single A100 GPU. 32 | ``` 33 | python train.py --config-file configs/diverse_weather.yaml 34 | 35 | or 36 | 37 | python train_voc.py --config-file configs/comic_watercolor.yaml 38 | ``` 39 | 40 | ### Weights 41 | [Download](https://drive.google.com/file/d/1qMJfMZkE7cG6wwphQtA4uAxfh0NBVItu/view?usp=drive_link) the trained weights. 42 | 43 | ### Citation 44 | ```bibtex 45 | @InProceedings{Vidit_2023_CVPR, 46 | author = {Vidit, Vidit and Engilberge, Martin and Salzmann, Mathieu}, 47 | title = {CLIP the Gap: A Single Domain Generalization Approach for Object Detection}, 48 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 49 | month = {June}, 50 | year = {2023}, 51 | pages = {3219-3229} 52 | } 53 | 54 | ``` 55 | -------------------------------------------------------------------------------- /modeling/clip.py: -------------------------------------------------------------------------------- 1 | import clip 2 | 3 | import torch 4 | import torch.nn as nn 5 | import time 6 | import numpy as np 7 | import copy 8 | 9 | class ClipPredictor(nn.Module): 10 | def __init__(self, clip_enocder_name,inshape, device, clsnames): 11 | super().__init__() 12 | self.model, self.preprocess = clip.load(clip_enocder_name, device) 13 | self.model.float() 14 | #freeze everything 15 | for name, val in self.model.named_parameters(): 16 | val.requires_grad = False 17 | # this is only used for inference 18 | self.frozen_clip_model = copy.deepcopy(self.model) 19 | 20 | self.visual_enc = self.model.visual 21 | prompt = 'a photo of a {}' 22 | print(clsnames) 23 | with torch.no_grad(): 24 | text_inputs = torch.cat([clip.tokenize(prompt.format(cls)) for cls in clsnames]).to(device) 25 | self.text_features = self.model.encode_text(text_inputs).float() 26 | self.text_features /= self.text_features.norm(dim=-1, keepdim=True) 27 | 28 | 29 | self.projection = nn.Linear(inshape,512) 30 | self.projection_global = nn.Linear(inshape,512) 31 | 32 | 33 | 34 | def forward(self, feat, gfeat=None): 35 | 36 | if feat.shape[-1] > 512: 37 | feat = self.projection(feat) 38 | 39 | feat = feat/feat.norm(dim=-1,keepdim=True) 40 | if gfeat is not None: 41 | 42 | feat = feat-gfeat 43 | feat = feat/feat.norm(dim=-1,keepdim=True) 44 | scores = (100.0 * torch.matmul(feat,self.text_features.detach().T)) 45 | 46 | # print(scores.min(),scores.max()) 47 | # add for bkg class a score 0 48 | scores = 
torch.cat([scores,torch.zeros(scores.shape[0],1,device=scores.device)],1) 49 | return scores 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /modeling/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision.transforms as T 6 | 7 | from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec 8 | 9 | @BACKBONE_REGISTRY.register() 10 | class ClipRN101(Backbone): 11 | def __init__(self, cfg, clip_visual): 12 | super().__init__() 13 | self.enc = None 14 | self.unfreeze = cfg.MODEL.BACKBONE.UNFREEZE 15 | self.proj = nn.Linear(512,512) 16 | self.global_proj = nn.Linear(512,512) 17 | self.use_proj = cfg.MODEL.USE_PROJ 18 | 19 | 20 | def set_backbone_model(self,model): 21 | self.enc = model 22 | for name,val in self.enc.named_parameters(): 23 | head = name.split('.')[0] 24 | if head not in self.unfreeze: 25 | val.requires_grad = False 26 | else: 27 | val.requires_grad = True 28 | 29 | self.backbone_unchanged = nn.Sequential(*self.enc.layer3[:19]) 30 | 31 | def forward(self, image): 32 | x = image 33 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x))) 34 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x))) 35 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x))) 36 | x = self.enc.avgpool(x) 37 | 38 | x = self.enc.layer1(x) 39 | x = self.enc.layer2(x) 40 | x = self.enc.layer3(x) 41 | return {"res4": x} 42 | 43 | 44 | def forward_l12(self, image): 45 | x = image 46 | x = self.enc.relu1(self.enc.bn1(self.enc.conv1(x))) 47 | x = self.enc.relu2(self.enc.bn2(self.enc.conv2(x))) 48 | x = self.enc.relu3(self.enc.bn3(self.enc.conv3(x))) 49 | x = self.enc.avgpool(x) 50 | 51 | x = self.enc.layer1(x) 52 | x = self.enc.layer2(x) 53 | 54 | return x 55 | 56 | def forward_l3(self, x): 57 | 58 | x = self.enc.layer3(x) 59 | return {"res4": x} 60 | 61 | def output_shape(self): 62 | return {"res4": ShapeSpec(channels=1024, stride=16)} 63 | 64 | def forward_res5(self,x): 65 | #detectron used last resnet layer for roi heads 66 | return self.enc.layer4(x) 67 | 68 | def attention_global_pool(self,input): 69 | x = input 70 | x = self.enc.attnpool(x) 71 | return x 72 | 73 | -------------------------------------------------------------------------------- /modeling/custom_pascal_evaluation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import tempfile 4 | from collections import OrderedDict, defaultdict 5 | 6 | from detectron2.utils import comm 7 | from detectron2.evaluation import PascalVOCDetectionEvaluator 8 | from detectron2.evaluation.pascal_voc_evaluation import voc_eval 9 | 10 | class CustomPascalVOCDetectionEvaluator(PascalVOCDetectionEvaluator): 11 | def __init__(self,dataset_name): 12 | super().__init__(dataset_name) 13 | 14 | def evaluate(self): 15 | """ 16 | Returns: 17 | dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75". 18 | """ 19 | all_predictions = comm.gather(self._predictions, dst=0) 20 | if not comm.is_main_process(): 21 | return 22 | predictions = defaultdict(list) 23 | for predictions_per_rank in all_predictions: 24 | for clsid, lines in predictions_per_rank.items(): 25 | predictions[clsid].extend(lines) 26 | del all_predictions 27 | 28 | self._logger.info( 29 | "Evaluating {} using {} metric. 
" 30 | "Note that results do not use the official Matlab API.".format( 31 | self._dataset_name, 2007 if self._is_2007 else 2012 32 | ) 33 | ) 34 | 35 | 36 | with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname: 37 | res_file_template = os.path.join(dirname, "{}.txt") 38 | 39 | aps = defaultdict(list) # iou -> ap per class 40 | for cls_id, cls_name in enumerate(self._class_names): 41 | lines = predictions.get(cls_id, [""]) 42 | 43 | with open(res_file_template.format(cls_name), "w") as f: 44 | f.write("\n".join(lines)) 45 | 46 | for thresh in range(50, 100, 5): 47 | rec, prec, ap = voc_eval( 48 | res_file_template, 49 | self._anno_file_template, 50 | self._image_set_path, 51 | cls_name, 52 | ovthresh=thresh / 100.0, 53 | use_07_metric=self._is_2007, 54 | ) 55 | aps[thresh].append(ap * 100) 56 | 57 | ret = OrderedDict() 58 | mAP = {iou: np.mean(x) for iou, x in aps.items()} 59 | 60 | clsaps = ','.join(['{:.2f}'.format(a) for a in aps[50]]) 61 | self._logger.info("classwise ap {}".format(clsaps)) 62 | 63 | ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]} 64 | return ret 65 | -------------------------------------------------------------------------------- /data/datasets/pascal_voc_adaptation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | import numpy as np 5 | import os 6 | import xml.etree.ElementTree as ET 7 | from typing import List, Tuple, Union 8 | 9 | from detectron2.data import DatasetCatalog, MetadataCatalog 10 | from detectron2.structures import BoxMode 11 | from detectron2.utils.file_io import PathManager 12 | 13 | __all__ = ["load_voc_instances", "register_all_pascal_voc"] 14 | 15 | 16 | # fmt: off 17 | CLASS_NAMES = ( 18 | "bicycle", "bird", "car", "cat", "dog", "person", 19 | ) 20 | # fmt: on 21 | 22 | 23 | def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): 24 | """ 25 | Load Pascal VOC detection annotations to Detectron2 format. 26 | Args: 27 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 28 | split (str): one of "train", "test", "val", "trainval" 29 | class_names: list or tuple of class names 30 | """ 31 | with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f: 32 | fileids = np.loadtxt(f, dtype=np.str) 33 | 34 | # Needs to read many small annotation files. Makes sense at local 35 | annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/")) 36 | dicts = [] 37 | for fileid in fileids: 38 | anno_file = os.path.join(annotation_dirname, fileid + ".xml") 39 | jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg") 40 | 41 | with PathManager.open(anno_file) as f: 42 | tree = ET.parse(f) 43 | 44 | r = { 45 | "file_name": jpeg_file, 46 | "image_id": fileid, 47 | "height": int(tree.findall("./size/height")[0].text), 48 | "width": int(tree.findall("./size/width")[0].text), 49 | } 50 | instances = [] 51 | 52 | for obj in tree.findall("object"): 53 | cls = obj.find("name").text 54 | if cls not in CLASS_NAMES: 55 | continue 56 | # We include "difficult" samples in training. 57 | # Based on limited experiments, they don't hurt accuracy. 
58 | # difficult = int(obj.find("difficult").text) 59 | # if difficult == 1: 60 | # continue 61 | bbox = obj.find("bndbox") 62 | bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]] 63 | # Original annotations are integers in the range [1, W or H] 64 | # Assuming they mean 1-based pixel indices (inclusive), 65 | # a box with annotation (xmin=1, xmax=W) covers the whole image. 66 | # In coordinate space this is represented by (xmin=0, xmax=W) 67 | bbox[0] -= 1.0 68 | bbox[1] -= 1.0 69 | instances.append( 70 | {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS} 71 | ) 72 | r["annotations"] = instances 73 | dicts.append(r) 74 | return dicts 75 | 76 | 77 | def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES): 78 | DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names)) 79 | MetadataCatalog.get(name).set( 80 | thing_classes=list(class_names), dirname=dirname, year=year, split=split 81 | ) 82 | 83 | def register_all_pascal_voc(root): 84 | SPLITS = [ 85 | ("voc_adapt_2007_trainval", "VOC2007", "trainval"), 86 | ("voc_adapt_2007_train", "VOC2007", "train"), 87 | ("voc_adapt_2007_val", "VOC2007", "val"), 88 | ("voc_adapt_2007_test", "VOC2007", "test"), 89 | ("voc_adapt_2012_trainval", "VOC2012", "trainval"), 90 | ("voc_adapt_2012_train", "VOC2012", "train"), 91 | ("voc_adapt_2012_val", "VOC2012", "val"), 92 | ] 93 | for name, dirname, split in SPLITS: 94 | year = 2007 if "2007" in name else 2012 95 | register_pascal_voc(name, os.path.join(root, dirname), split, year) 96 | MetadataCatalog.get(name).evaluator_type = "pascal_voc" -------------------------------------------------------------------------------- /modeling/rpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from detectron2.modeling import PROPOSAL_GENERATOR_REGISTRY, RPN_HEAD_REGISTRY 6 | from detectron2.modeling.proposal_generator.rpn import RPN, StandardRPNHead 7 | from detectron2.structures import ImageList 8 | from typing import List 9 | import time 10 | 11 | 12 | @PROPOSAL_GENERATOR_REGISTRY.register() 13 | class SBRPN(RPN): 14 | 15 | def forward( 16 | self, 17 | images, 18 | features, 19 | gt_instances= None, 20 | ): 21 | """ 22 | Args: 23 | images (ImageList): input images of length `N` 24 | features (dict[str, Tensor]): input data as a mapping from feature 25 | map name to tensor. Axis 0 represents the number of images `N` in 26 | the input data; axes 1-3 are channels, height, and width, which may 27 | vary between feature maps (e.g., if a feature pyramid is used). 28 | gt_instances (list[Instances], optional): a length `N` list of `Instances`s. 29 | Each `Instances` stores ground-truth instances for the corresponding image. 
30 | Returns: 31 | proposals: list[Instances]: contains fields "proposal_boxes", "objectness_logits" 32 | loss: dict[Tensor] or None 33 | """ 34 | features = [features[f] for f in self.in_features] 35 | anchors = self.anchor_generator(features) 36 | val = self.rpn_head(features) 37 | 38 | if len(val) == 2: 39 | pred_objectness_logits, pred_anchor_deltas = val 40 | rep_clip = None 41 | else: 42 | pred_objectness_logits, pred_anchor_deltas, rep_clip = val 43 | 44 | # Transpose the Hi*Wi*A dimension to the middle: 45 | pred_objectness_logits= [ 46 | # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A) 47 | score.permute(0, 2, 3, 1).flatten(1) 48 | for score in pred_objectness_logits 49 | ] 50 | 51 | 52 | pred_anchor_deltas = [ 53 | # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B) 54 | x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1]) 55 | .permute(0, 3, 4, 1, 2) 56 | .flatten(1, -2) 57 | for x in pred_anchor_deltas 58 | ] 59 | 60 | if self.training: 61 | #assert gt_instances is not None, "RPN requires gt_instances in training!" 62 | if gt_instances is not None: 63 | gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances) 64 | losses = self.losses( 65 | anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes 66 | ) 67 | if rep_clip is not None : 68 | if not isinstance(rep_clip, dict): 69 | ll = torch.stack(gt_labels) 70 | ll = ll.reshape(ll.shape[0],-1,15) 71 | valid_mask = ll.ge(0).sum(dim=-1).gt(0) # remove ignored anchors 72 | ll=ll.eq(1).sum(dim=-1) # if an object is present at this location 73 | ll = ll.gt(0).float() 74 | 75 | clip_loss = torch.nn.functional.binary_cross_entropy_with_logits(rep_clip[valid_mask],ll[valid_mask],reduction='sum') 76 | losses.update({'loss_rpn_cls_clip':clip_loss/(self.batch_size_per_image*ll.shape[0])}) 77 | else: 78 | losses.update(rep_clip) 79 | else: 80 | losses = {} 81 | else: 82 | losses = {} 83 | proposals = self.predict_proposals( 84 | anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes 85 | ) 86 | 87 | # if self.training: 88 | 89 | # if gt_instances is None: 90 | out = [ 91 | # (N, Hi*Wi*A) -> (N, Hi, Wi, A) 92 | score.reshape(features[ind].shape[0],features[ind].shape[-2],features[ind].shape[-1],-1) 93 | for ind, score in enumerate(pred_objectness_logits) 94 | ] 95 | # else: 96 | # b,_,h,w = features[0].shape 97 | # out = [1.*(torch.stack(gt_labels)==1).reshape(b,h,w,-1)] 98 | return out, proposals, losses 99 | # else: 100 | # return proposals, losses 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /data/datasets/comic_water_adaptation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | 4 | from tqdm import tqdm 5 | import pickle as pkl 6 | import xml.etree.ElementTree as ET 7 | 8 | import cv2 9 | import numpy as np 10 | from pymage_size import get_image_size 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | 15 | all_class_name =["bicycle", "bird", "car", "cat", "dog", "person"] 16 | 17 | def get_annotation(root, image_id, ind): 18 | annotation_file = os.path.join(root, "Annotations", "%s.xml" % image_id) 19 | 20 | et = ET.parse(annotation_file) 21 | 22 | 23 | objects = et.findall("object") 24 | 25 | record = {} 26 | record["file_name"] = os.path.join(root, "JPEGImages", "%s.jpg" % image_id) 27 | img_format = 
get_image_size(record["file_name"]) 28 | w, h = img_format.get_dimensions() 29 | 30 | record["image_id"] = image_id#ind for pascal evaluation actual image name is needed 31 | record["annotations"] = [] 32 | 33 | for obj in objects: 34 | class_name = obj.find('name').text.lower().strip() 35 | if class_name not in all_class_name: 36 | print(class_name) 37 | continue 38 | if obj.find('pose') is None: 39 | obj.append(ET.Element('pose')) 40 | obj.find('pose').text = '0' 41 | 42 | if obj.find('truncated') is None: 43 | obj.append(ET.Element('truncated')) 44 | obj.find('truncated').text = '0' 45 | 46 | if obj.find('difficult') is None: 47 | obj.append(ET.Element('difficult')) 48 | obj.find('difficult').text = '0' 49 | 50 | bbox = obj.find('bndbox') 51 | # VOC dataset format follows Matlab, in which indexes start from 0 52 | x1 = max(0,float(bbox.find('xmin').text) - 1) # fixing when -1 in anno 53 | y1 = max(0,float(bbox.find('ymin').text) - 1) # fixing when -1 in anno 54 | x2 = float(bbox.find('xmax').text) - 1 55 | y2 = float(bbox.find('ymax').text) - 1 56 | box = [x1, y1, x2, y2] 57 | 58 | #pascal voc evaluator requires int 59 | bbox.find('xmin').text = str(int(x1)) 60 | bbox.find('ymin').text = str(int(y1)) 61 | bbox.find('xmax').text = str(int(x2)) 62 | bbox.find('ymax').text = str(int(y2)) 63 | 64 | 65 | record_obj = { 66 | "bbox": box, 67 | "bbox_mode": BoxMode.XYXY_ABS, 68 | "category_id": all_class_name.index(class_name), 69 | } 70 | record["annotations"].append(record_obj) 71 | 72 | if len(record["annotations"]): 73 | #to convert float to int 74 | et.write(annotation_file) 75 | record["height"] = h 76 | record["width"] = w 77 | return record 78 | 79 | else: 80 | return None 81 | 82 | def files2dict(root,split): 83 | 84 | cache_dir = os.path.join(root, 'cache') 85 | 86 | pkl_filename = os.path.basename(root)+f'_{split}.pkl' 87 | pkl_path = os.path.join(cache_dir,pkl_filename) 88 | 89 | if os.path.exists(pkl_path): 90 | with open(pkl_path,'rb') as f: 91 | return pkl.load(f) 92 | else: 93 | try: 94 | os.makedirs(cache_dir) 95 | except OSError as e: 96 | if e.errno != errno.EEXIST: 97 | print(e) 98 | pass 99 | 100 | dataset_dicts = [] 101 | image_sets_file = os.path.join( root, "ImageSets", "Main", "%s.txt" % split) 102 | 103 | with open(image_sets_file) as f: 104 | count = 0 105 | 106 | for line in tqdm(f): 107 | record = get_annotation(root,line.rstrip(),count) 108 | 109 | if record is not None: 110 | dataset_dicts.append(record) 111 | count +=1 112 | 113 | with open(pkl_path, 'wb') as f: 114 | pkl.dump(dataset_dicts,f) 115 | return dataset_dicts 116 | 117 | 118 | def register_dataset(datasets_root): 119 | dataset_list = ['comic', 120 | 'watercolor' 121 | ] 122 | settype = ['train','test'] 123 | 124 | for name in dataset_list: 125 | for ind, d in enumerate(settype): 126 | 127 | DatasetCatalog.register(name+"_" + d, lambda datasets_root=datasets_root,name=name,d=d \ 128 | : files2dict(os.path.join(datasets_root,name), d)) 129 | MetadataCatalog.get(name+ "_" + d).set(thing_classes=all_class_name,evaluator_type='pascal_voc') 130 | MetadataCatalog.get(name+ "_" + d).set(dirname=datasets_root+f'/{name}') 131 | MetadataCatalog.get(name+ "_" + d).set(split=d) 132 | MetadataCatalog.get(name+ "_" + d).set(year=2007) -------------------------------------------------------------------------------- /data/datasets/diverse_weather.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | 4 | from tqdm import tqdm 5 | import pickle as pkl 6 | 
import xml.etree.ElementTree as ET 7 | 8 | import cv2 9 | import numpy as np 10 | from pymage_size import get_image_size 11 | 12 | from detectron2.data import DatasetCatalog, MetadataCatalog 13 | from detectron2.structures import BoxMode 14 | 15 | all_class_name = ['bus' ,'bike', 'car', 'motor', 'person', 'rider' ,'truck'] 16 | 17 | def get_annotation(root, image_id, ind): 18 | annotation_file = os.path.join(root,'VOC2007', "Annotations", "%s.xml" % image_id) 19 | 20 | et = ET.parse(annotation_file) 21 | 22 | 23 | objects = et.findall("object") 24 | 25 | record = {} 26 | record["file_name"] = os.path.join(root, 'VOC2007', "JPEGImages", "%s.jpg" % image_id) 27 | img_format = get_image_size(record["file_name"]) 28 | w, h = img_format.get_dimensions() 29 | 30 | record["image_id"] = image_id#ind for pascal evaluation actual image name is needed 31 | record["annotations"] = [] 32 | 33 | for obj in objects: 34 | class_name = obj.find('name').text.lower().strip() 35 | if class_name not in all_class_name: 36 | print(class_name) 37 | continue 38 | if obj.find('pose') is None: 39 | obj.append(ET.Element('pose')) 40 | obj.find('pose').text = '0' 41 | 42 | if obj.find('truncated') is None: 43 | obj.append(ET.Element('truncated')) 44 | obj.find('truncated').text = '0' 45 | 46 | if obj.find('difficult') is None: 47 | obj.append(ET.Element('difficult')) 48 | obj.find('difficult').text = '0' 49 | 50 | bbox = obj.find('bndbox') 51 | # VOC dataset format follows Matlab, in which indexes start from 0 52 | x1 = max(0,float(bbox.find('xmin').text) - 1) # fixing when -1 in anno 53 | y1 = max(0,float(bbox.find('ymin').text) - 1) # fixing when -1 in anno 54 | x2 = float(bbox.find('xmax').text) - 1 55 | y2 = float(bbox.find('ymax').text) - 1 56 | box = [x1, y1, x2, y2] 57 | 58 | #pascal voc evaluator requires int 59 | bbox.find('xmin').text = str(int(x1)) 60 | bbox.find('ymin').text = str(int(y1)) 61 | bbox.find('xmax').text = str(int(x2)) 62 | bbox.find('ymax').text = str(int(y2)) 63 | 64 | 65 | record_obj = { 66 | "bbox": box, 67 | "bbox_mode": BoxMode.XYXY_ABS, 68 | "category_id": all_class_name.index(class_name), 69 | } 70 | record["annotations"].append(record_obj) 71 | 72 | if len(record["annotations"]): 73 | #to convert float to int 74 | et.write(annotation_file) 75 | record["height"] = h 76 | record["width"] = w 77 | return record 78 | 79 | else: 80 | return None 81 | 82 | def files2dict(root,split): 83 | 84 | cache_dir = os.path.join(root, 'cache') 85 | 86 | pkl_filename = os.path.basename(root)+f'_{split}.pkl' 87 | pkl_path = os.path.join(cache_dir,pkl_filename) 88 | 89 | if os.path.exists(pkl_path): 90 | with open(pkl_path,'rb') as f: 91 | return pkl.load(f) 92 | else: 93 | try: 94 | os.makedirs(cache_dir) 95 | except OSError as e: 96 | if e.errno != errno.EEXIST: 97 | print(e) 98 | pass 99 | 100 | dataset_dicts = [] 101 | image_sets_file = os.path.join( root,'VOC2007', "ImageSets", "Main", "%s.txt" % split) 102 | 103 | with open(image_sets_file) as f: 104 | count = 0 105 | 106 | for line in tqdm(f): 107 | record = get_annotation(root,line.rstrip(),count) 108 | 109 | if record is not None: 110 | dataset_dicts.append(record) 111 | count +=1 112 | 113 | with open(pkl_path, 'wb') as f: 114 | pkl.dump(dataset_dicts,f) 115 | return dataset_dicts 116 | 117 | 118 | def register_dataset(datasets_root): 119 | datasets_root = os.path.join(datasets_root,'diverseWeather') 120 | dataset_list = ['daytime_clear', 121 | 'daytime_foggy', 122 | 'night_sunny', 123 | 'night_rainy', 124 | 'dusk_rainy', 125 | ] 126 | settype 
= ['train','test'] 127 | 128 | for name in dataset_list: 129 | for ind, d in enumerate(settype): 130 | 131 | DatasetCatalog.register(name+"_" + d, lambda datasets_root=datasets_root,name=name,d=d \ 132 | : files2dict(os.path.join(datasets_root,name), d)) 133 | MetadataCatalog.get(name+ "_" + d).set(thing_classes=all_class_name,evaluator_type='pascal_voc') 134 | MetadataCatalog.get(name+ "_" + d).set(dirname=datasets_root+f'/{name}/VOC2007') 135 | MetadataCatalog.get(name+ "_" + d).set(split=d) 136 | MetadataCatalog.get(name+ "_" + d).set(year=2007) -------------------------------------------------------------------------------- /modeling/roi_head.py: -------------------------------------------------------------------------------- 1 | from pydoc import classname 2 | from typing import Dict, List, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | import torchvision.transforms as T 8 | 9 | 10 | from detectron2.layers import ShapeSpec 11 | from detectron2.data import MetadataCatalog 12 | 13 | from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads 14 | from detectron2.structures import Boxes, ImageList, Instances, pairwise_iou 15 | from .box_predictor import ClipFastRCNNOutputLayers 16 | 17 | def select_foreground_proposals( 18 | proposals: List[Instances], bg_label: int 19 | ) -> Tuple[List[Instances], List[torch.Tensor]]: 20 | """ 21 | Given a list of N Instances (for N images), each containing a `gt_classes` field, 22 | return a list of Instances that contain only instances with `gt_classes != -1 && 23 | gt_classes != bg_label`. 24 | Args: 25 | proposals (list[Instances]): A list of N Instances, where N is the number of 26 | images in the batch. 27 | bg_label: label index of background class. 28 | Returns: 29 | list[Instances]: N Instances, each contains only the selected foreground instances. 30 | list[Tensor]: N boolean vector, correspond to the selection mask of 31 | each Instances object. True for selected instances. 
32 | """ 33 | assert isinstance(proposals, (list, tuple)) 34 | assert isinstance(proposals[0], Instances) 35 | assert proposals[0].has("gt_classes") 36 | fg_proposals = [] 37 | fg_selection_masks = [] 38 | for proposals_per_image in proposals: 39 | gt_classes = proposals_per_image.gt_classes 40 | fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label) 41 | fg_idxs = fg_selection_mask.nonzero().squeeze(1) 42 | fg_proposals.append(proposals_per_image[fg_idxs]) 43 | fg_selection_masks.append(fg_selection_mask) 44 | return fg_proposals, fg_selection_masks 45 | 46 | 47 | @ROI_HEADS_REGISTRY.register() 48 | class ClipRes5ROIHeads(Res5ROIHeads): 49 | def __init__(self, cfg, input_shape) -> None: 50 | super().__init__(cfg, input_shape) 51 | clsnames = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).get("thing_classes").copy() 52 | 53 | # import pdb;pdb.set_trace() 54 | ##change the labels to represent the objects correctly 55 | for name in cfg.MODEL.RENAME: 56 | ind = clsnames.index(name[0]) 57 | clsnames[ind] = name[1] 58 | 59 | out_channels=cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * (2 ** 3) ### copied 60 | self.box_predictor = ClipFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1), clsnames) 61 | self.clip_im_predictor = self.box_predictor.cls_score # should call it properly 62 | self.device = cfg.MODEL.DEVICE 63 | def forward( 64 | self, 65 | images: ImageList, 66 | features: Dict[str, torch.Tensor], 67 | proposals: List[Instances], 68 | targets: Optional[List[Instances]] = None, 69 | crops: Optional[List[Tuple]] = None, 70 | ): 71 | """ 72 | See :meth:`ROIHeads.forward`. 73 | """ 74 | del images 75 | 76 | if self.training: 77 | assert targets 78 | proposals = self.label_and_sample_proposals(proposals, targets) 79 | # import pdb;pdb.set_trace() 80 | loss_crop_im = None 81 | if crops is not None: 82 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224 83 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4 84 | keep = torch.ones(len(crops)).bool() 85 | 86 | for ind,x in enumerate(crops): 87 | if len(x) == 0: 88 | keep[ind] = False 89 | continue 90 | crop_im.append(x[0]) 91 | crop_boxes.append(x[1].to(self.device)) 92 | 93 | c = self._shared_roi_transform( 94 | [features[f][keep] for f in self.in_features], crop_boxes) #(b*crops)x2048x7x7 95 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features.mean(dim=[2, 3])) 96 | 97 | del targets 98 | 99 | proposal_boxes = [x.proposal_boxes for x in proposals] 100 | box_features = self._shared_roi_transform( 101 | [features[f] for f in self.in_features], proposal_boxes 102 | ) 103 | predictions = self.box_predictor(box_features.mean(dim=[2, 3])) 104 | # import pdb;pdb.set_trace() 105 | if self.training: 106 | del features 107 | losses = self.box_predictor.losses(predictions, proposals) 108 | if self.mask_on: 109 | proposals, fg_selection_masks = select_foreground_proposals( 110 | proposals, self.num_classes 111 | ) 112 | # Since the ROI feature transform is shared between boxes and masks, 113 | # we don't need to recompute features. The mask loss is only defined 114 | # on foreground proposals, so we need to select out the foreground 115 | # features. 
116 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 117 | del box_features 118 | losses.update(self.mask_head(mask_features, proposals)) 119 | 120 | if loss_crop_im is not None: 121 | losses.update(loss_crop_im) 122 | return [], losses 123 | else: 124 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 125 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 126 | return pred_instances, {} 127 | 128 | 129 | @ROI_HEADS_REGISTRY.register() 130 | class ClipRes5ROIHeadsAttn(ClipRes5ROIHeads): 131 | def __init__(self, cfg, input_shape) -> None: 132 | super().__init__(cfg, input_shape) 133 | # self.res5 = None 134 | 135 | def _shared_roi_transform(self, features, boxes): 136 | x = self.pooler(features, boxes) 137 | return self.fwdres5(x) 138 | 139 | def forward( 140 | self, 141 | images: ImageList, 142 | features: Dict[str, torch.Tensor], 143 | proposals: List[Instances], 144 | targets: Optional[List[Instances]] = None, 145 | crops: Optional[List[Tuple]] = None, 146 | backbone = None 147 | ): 148 | """ 149 | See :meth:`ROIHeads.forward`. 150 | """ 151 | del images 152 | 153 | self.fwdres5 = backbone.forward_res5 154 | 155 | if self.training: 156 | assert targets 157 | proposals = self.label_and_sample_proposals(proposals, targets) 158 | # import pdb;pdb.set_trace() 159 | loss_crop_im = None 160 | if crops is not None: 161 | crop_im = list()#[x[0] for x in crops] #bxcropx3x224x224 162 | crop_boxes = list()#[x[1].to(self.device) for x in crops] #bxcropsx4 163 | keep = torch.ones(len(crops)).bool() 164 | 165 | for ind,x in enumerate(crops): 166 | if len(x) == 0: 167 | keep[ind] = False 168 | continue 169 | crop_im.append(x[0]) 170 | crop_boxes.append(x[1].to(self.device)) 171 | 172 | crops_features = self._shared_roi_transform( 173 | [features[f][keep] for f in self.in_features], crop_boxes) #(b*crops)x2048x7x7 174 | crops_features = backbone.attention_global_pool(crops_features) 175 | loss_crop_im, _ = self.clip_im_predictor.forward_crops(crop_im,crops_features) 176 | 177 | del targets 178 | 179 | proposal_boxes = [x.proposal_boxes for x in proposals] 180 | box_features = self._shared_roi_transform( 181 | [features[f] for f in self.in_features], proposal_boxes 182 | ) 183 | 184 | attn_feat = backbone.attention_global_pool(box_features) 185 | predictions = self.box_predictor([attn_feat,box_features.mean(dim=(2,3))]) 186 | # import pdb;pdb.set_trace() 187 | if self.training: 188 | del features 189 | 190 | losses = self.box_predictor.losses(predictions, proposals) 191 | # if torch.isnan(losses['loss_cls']): 192 | # import pdb;pdb.set_trace() 193 | 194 | if self.mask_on: 195 | proposals, fg_selection_masks = select_foreground_proposals( 196 | proposals, self.num_classes 197 | ) 198 | # Since the ROI feature transform is shared between boxes and masks, 199 | # we don't need to recompute features. The mask loss is only defined 200 | # on foreground proposals, so we need to select out the foreground 201 | # features. 
202 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 203 | del box_features 204 | losses.update(self.mask_head(mask_features, proposals)) 205 | 206 | if loss_crop_im is not None: 207 | losses.update(loss_crop_im) 208 | return [], losses 209 | else: 210 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 211 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 212 | return pred_instances, {} 213 | 214 | 215 | -------------------------------------------------------------------------------- /modeling/meta_arch.py: -------------------------------------------------------------------------------- 1 | from ast import mod 2 | import math 3 | import numpy as np 4 | import cv2 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torchvision.transforms as T 11 | 12 | from typing import Dict,List,Optional 13 | 14 | from detectron2.modeling import META_ARCH_REGISTRY, GeneralizedRCNN 15 | from detectron2.structures import ImageList, Instances, pairwise_iou 16 | from detectron2.utils.events import get_event_storage 17 | from detectron2.layers import batched_nms 18 | from detectron2.data.detection_utils import convert_image_to_rgb 19 | from detectron2.utils.visualizer import Visualizer 20 | 21 | 22 | @META_ARCH_REGISTRY.register() 23 | class ClipRCNNWithClipBackbone(GeneralizedRCNN): 24 | 25 | def __init__(self,cfg) -> None: 26 | super().__init__(cfg) 27 | self.cfg = cfg 28 | self.colors = self.generate_colors(7) 29 | self.backbone.set_backbone_model(self.roi_heads.box_predictor.cls_score.visual_enc) 30 | 31 | def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]): 32 | """ 33 | Normalize, pad and batch the input images. 
34 | """ 35 | clip_images = [x["image"].to(self.pixel_mean.device) for x in batched_inputs] 36 | mean=[0.48145466, 0.4578275, 0.40821073] 37 | std=[0.26862954, 0.26130258, 0.27577711] 38 | 39 | 40 | clip_images = [ T.functional.normalize(ci.flip(0)/255, mean,std) for ci in clip_images] 41 | clip_images = ImageList.from_tensors( 42 | [i for i in clip_images]) 43 | return clip_images 44 | 45 | 46 | def forward(self, batched_inputs): 47 | 48 | if not self.training: 49 | return self.inference(batched_inputs) 50 | 51 | images = self.preprocess_image(batched_inputs) 52 | b = images.tensor.shape[0]#batchsize 53 | 54 | if "instances" in batched_inputs[0]: 55 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 56 | 57 | features = self.backbone(images.tensor) 58 | 59 | if self.proposal_generator is not None: 60 | if self.training: 61 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 62 | else: 63 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 64 | else: 65 | assert "proposals" in batched_inputs[0] 66 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 67 | proposal_losses = {} 68 | 69 | try: 70 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone) 71 | except Exception as e: 72 | print(e) 73 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None) 74 | 75 | if self.vis_period > 0: 76 | storage = get_event_storage() 77 | if storage.iter % self.vis_period == 0: 78 | self.visualize_training(batched_inputs, proposals) 79 | with torch.no_grad(): 80 | ogimage = batched_inputs[0]['image'] 81 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format) 82 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image() 83 | 84 | vis_img = o_pred.transpose(2, 0, 1) 85 | storage.put_image('og-tfimage', vis_img) 86 | 87 | losses = {} 88 | losses.update(detector_losses) 89 | losses.update(proposal_losses) 90 | return losses 91 | 92 | def generate_colors(self,N): 93 | import colorsys 94 | ''' 95 | Generate random colors. 96 | To get visually distinct colors, generate them in HSV space then 97 | convert to RGB. 98 | ''' 99 | brightness = 0.7 100 | hsv = [(i / N, 1, brightness) for i in range(N)] 101 | colors = list(map(lambda c: tuple(round(i * 255) for i in colorsys.hsv_to_rgb(*c)), hsv)) 102 | perm = np.arange(7) 103 | colors = [colors[idx] for idx in perm] 104 | return colors 105 | 106 | 107 | def inference( 108 | self, 109 | batched_inputs: List[Dict[str, torch.Tensor]], 110 | detected_instances: Optional[List[Instances]] = None, 111 | do_postprocess: bool = True, 112 | ): 113 | """ 114 | Run inference on the given inputs. 115 | Args: 116 | batched_inputs (list[dict]): same as in :meth:`forward` 117 | detected_instances (None or list[Instances]): if not None, it 118 | contains an `Instances` object per image. The `Instances` 119 | object contains "pred_boxes" and "pred_classes" which are 120 | known boxes in the image. 121 | The inference will then skip the detection of bounding boxes, 122 | and only predict other per-ROI outputs. 123 | do_postprocess (bool): whether to apply post-processing on the outputs. 124 | Returns: 125 | When do_postprocess=True, same as in :meth:`forward`. 126 | Otherwise, a list[Instances] containing raw network outputs. 
127 | """ 128 | assert not self.training 129 | 130 | images = self.preprocess_image(batched_inputs) 131 | features = self.backbone(images.tensor) 132 | 133 | if detected_instances is None: 134 | if self.proposal_generator is not None: 135 | logits,proposals, _ = self.proposal_generator(images, features, None) 136 | else: 137 | assert "proposals" in batched_inputs[0] 138 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 139 | 140 | # boxes = batched_inputs[0]['instances'].gt_boxes.to(images.tensor.device) 141 | # logits = 10*torch.ones(len(boxes)).to(images.tensor.device) 142 | # dictp = {'proposal_boxes':boxes,'objectness_logits':logits} 143 | # new_p = Instances(batched_inputs[0]['instances'].image_size,**dictp) 144 | # proposals = [new_p] 145 | 146 | try: 147 | results, _ = self.roi_heads(images, features, proposals, None, None, self.backbone) 148 | except: 149 | results, _ = self.roi_heads(images, features, proposals, None, None) 150 | else: 151 | detected_instances = [x.to(self.device) for x in detected_instances] 152 | results = self.roi_heads.forward_with_given_boxes(features, detected_instances) 153 | 154 | 155 | if do_postprocess: 156 | assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." 157 | 158 | allresults = GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) 159 | 160 | 161 | return allresults 162 | else: 163 | return results 164 | 165 | 166 | @META_ARCH_REGISTRY.register() 167 | class ClipRCNNWithClipBackboneWithOffsetGenTrainable(ClipRCNNWithClipBackbone): 168 | 169 | def __init__(self,cfg) -> None: 170 | super().__init__(cfg) 171 | 172 | domain_text = {'day': 'an image taken during the day'} 173 | with open('prunedprompts2.txt','r') as f: 174 | for ind,l in enumerate(f): 175 | domain_text.update({str(ind):l.strip()}) 176 | # self.offsets = nn.Parameter(offsets) 177 | self.offsets = nn.Parameter(torch.zeros(len(domain_text)-1,1024,14,14)) #skip day 178 | 179 | import clip 180 | self.domain_tk = dict([(k,clip.tokenize(t)) for k,t in domain_text.items()]) 181 | self.apply_aug = cfg.AUG_PROB 182 | 183 | def forward(self, batched_inputs): 184 | 185 | if not self.training: 186 | return self.inference(batched_inputs) 187 | 188 | images = self.preprocess_image(batched_inputs) 189 | b = images.tensor.shape[0]#batchsize 190 | 191 | if "instances" in batched_inputs[0]: 192 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 193 | 194 | features = self.backbone(images.tensor) 195 | 196 | if np.random.rand(1) >self.apply_aug: 197 | oids = np.random.choice(np.arange(len(self.offsets)),b) 198 | change = torch.cat([self.offsets[oid:oid+1].cuda().mean(dim=(2,3),keepdims=True) for oid in oids ],0) 199 | features['res4']=features['res4']+ change 200 | 201 | if self.proposal_generator is not None: 202 | if self.training: 203 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 204 | else: 205 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 206 | else: 207 | assert "proposals" in batched_inputs[0] 208 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 209 | proposal_losses = {} 210 | 211 | try: 212 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone) 213 | except Exception as e: 214 | print(e) 215 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None) 216 | 217 | if self.vis_period > 0: 218 | storage = 
get_event_storage() 219 | if storage.iter % self.vis_period == 0: 220 | self.visualize_training(batched_inputs, proposals) 221 | with torch.no_grad(): 222 | ogimage = batched_inputs[0]['image'] 223 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format) 224 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image() 225 | 226 | vis_img = o_pred.transpose(2, 0, 1) 227 | storage.put_image('og-tfimage', vis_img) 228 | 229 | losses = {} 230 | losses.update(detector_losses) 231 | losses.update(proposal_losses) 232 | return losses 233 | 234 | def opt_offsets(self, batched_inputs): 235 | 236 | crops_clip = None 237 | if 'randomcrops' in batched_inputs[0]: 238 | rcrops = [x['randomcrops'] for x in batched_inputs] 239 | rcrops = torch.cat(rcrops,0) 240 | crops_clip = rcrops.flip(1)/255 241 | mean=[0.48145466, 0.4578275, 0.40821073] 242 | std=[0.26862954, 0.26130258, 0.27577711] 243 | crops_clip = T.functional.normalize(crops_clip,mean,std) 244 | crops_clip = crops_clip.cuda() 245 | 246 | with torch.no_grad(): 247 | features = self.backbone(crops_clip) 248 | 249 | losses = {} 250 | total_dist = 0 251 | total_reg = 0 252 | total_chgn = 0 253 | for i,val in enumerate(self.domain_tk.items()): 254 | name , dtk = val 255 | if name == 'day': 256 | continue 257 | with torch.no_grad(): 258 | 259 | # print(self.backbone.forward_res5(features['res4'])) 260 | wo_aug_im_embed = self.backbone.attention_global_pool(self.backbone.forward_res5(features['res4'])) 261 | wo_aug_im_embed = wo_aug_im_embed/wo_aug_im_embed.norm(dim=-1,keepdim=True) 262 | 263 | day_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(self.domain_tk['day'].cuda()) #day 264 | day_text_embed = day_text_embed/day_text_embed.norm(dim=-1,keepdim=True) 265 | new_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(dtk.cuda() ) #new_d 266 | new_text_embed = new_text_embed/new_text_embed.norm(dim=-1,keepdim=True) 267 | text_off = (new_text_embed - day_text_embed) 268 | text_off = text_off/text_off.norm(dim=-1,keepdim=True) 269 | 270 | wo_aug_im_tsl = wo_aug_im_embed + text_off 271 | wo_aug_im_tsl = wo_aug_im_tsl/wo_aug_im_tsl.norm(dim=-1,keepdim=True) 272 | wo_aug_im_tsl = wo_aug_im_tsl.unsqueeze(1).permute(0,2,1) 273 | 274 | 275 | aug_feat = features['res4'].detach()+self.offsets[i-1:i] 276 | 277 | 278 | x = self.backbone.forward_res5(aug_feat) 279 | im_embed = self.backbone.attention_global_pool(x) 280 | 281 | im_embed = im_embed/im_embed.norm(dim=-1,keepdim=True) 282 | 283 | cos_dist = 1 - im_embed.unsqueeze(1).bmm(wo_aug_im_tsl) 284 | 285 | dist_loss = cos_dist.mean() 286 | 287 | l1loss = torch.nn.functional.l1_loss(im_embed,wo_aug_im_embed) 288 | 289 | 290 | total_dist += dist_loss 291 | total_reg += l1loss 292 | 293 | losses.update({ f'cos_dist_loss_{name}': total_dist/len(self.domain_tk),f'reg_loss_{name}': total_reg/len(self.domain_tk)}) 294 | 295 | return losses 296 | 297 | 298 | @META_ARCH_REGISTRY.register() 299 | class ClipRCNNWithClipBackboneWithOffsetGenTrainableVOC(ClipRCNNWithClipBackbone): 300 | 301 | def __init__(self,cfg) -> None: 302 | super().__init__(cfg) 303 | 304 | domain_text = {'real': 'a realistic image'} 305 | 306 | domain_text.update({str(0):'an image in the comics style'}) 307 | domain_text.update({str(1):'an image in the painting style'}) 308 | domain_text.update({str(2):'an image in the cartoon style'}) 309 | domain_text.update({str(3):'an image in the digital-art style'}) 310 | domain_text.update({str(4):'an image in the sketch style'}) 311 | 
domain_text.update({str(5):'an image in the watercolor painting style'}) 312 | domain_text.update({str(6):'an image in the oil painting style'}) 313 | # self.offsets = nn.Parameter(offsets) 314 | self.offsets = nn.Parameter(torch.zeros(len(domain_text)-1,1024,14,14)) #skip day 315 | 316 | import clip 317 | self.domain_tk = dict([(k,clip.tokenize(t)) for k,t in domain_text.items()]) 318 | self.apply_aug = cfg.AUG_PROB 319 | 320 | def forward(self, batched_inputs): 321 | 322 | if not self.training: 323 | return self.inference(batched_inputs) 324 | 325 | images = self.preprocess_image(batched_inputs) 326 | b = images.tensor.shape[0]#batchsize 327 | 328 | if "instances" in batched_inputs[0]: 329 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 330 | 331 | features = self.backbone(images.tensor) 332 | 333 | if np.random.rand(1) >self.apply_aug: 334 | 335 | oids = np.random.choice(np.arange(len(self.offsets)),b) 336 | change = torch.cat([self.offsets[oid:oid+1].cuda().mean(dim=(2,3),keepdims=True) for oid in oids ],0) 337 | features['res4']=features['res4']+ change 338 | 339 | if self.proposal_generator is not None: 340 | if self.training: 341 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 342 | else: 343 | logits, proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) 344 | else: 345 | assert "proposals" in batched_inputs[0] 346 | proposals = [x["proposals"].to(self.device) for x in batched_inputs] 347 | proposal_losses = {} 348 | 349 | try: 350 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None, self.backbone) 351 | except Exception as e: 352 | print(e) 353 | _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, None) 354 | 355 | if self.vis_period > 0: 356 | storage = get_event_storage() 357 | if storage.iter % self.vis_period == 0: 358 | self.visualize_training(batched_inputs, proposals) 359 | with torch.no_grad(): 360 | ogimage = batched_inputs[0]['image'] 361 | ogimage = convert_image_to_rgb(ogimage.permute(1, 2, 0), self.input_format) 362 | o_pred = Visualizer(ogimage, None).overlay_instances().get_image() 363 | 364 | vis_img = o_pred.transpose(2, 0, 1) 365 | storage.put_image('og-tfimage', vis_img) 366 | 367 | losses = {} 368 | losses.update(detector_losses) 369 | losses.update(proposal_losses) 370 | return losses 371 | 372 | def opt_offsets(self, batched_inputs): 373 | 374 | crops_clip = None 375 | if 'randomcrops' in batched_inputs[0]: 376 | rcrops = [x['randomcrops'] for x in batched_inputs] 377 | rcrops = torch.cat(rcrops,0) 378 | crops_clip = rcrops.flip(1)/255 379 | mean=[0.48145466, 0.4578275, 0.40821073] 380 | std=[0.26862954, 0.26130258, 0.27577711] 381 | crops_clip = T.functional.normalize(crops_clip,mean,std) 382 | crops_clip = crops_clip.cuda() 383 | 384 | with torch.no_grad(): 385 | features = self.backbone(crops_clip) 386 | 387 | losses = {} 388 | total_dist = 0 389 | total_reg = 0 390 | total_chgn = 0 391 | for i,val in enumerate(self.domain_tk.items()): 392 | name , dtk = val 393 | if name == 'real': 394 | continue 395 | with torch.no_grad(): 396 | 397 | # print(self.backbone.forward_res5(features['res4'])) 398 | wo_aug_im_embed = self.backbone.attention_global_pool(self.backbone.forward_res5(features['res4'])) 399 | wo_aug_im_embed = wo_aug_im_embed/wo_aug_im_embed.norm(dim=-1,keepdim=True) 400 | 401 | day_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(self.domain_tk['real'].cuda()) #day 
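# The statements that follow build the CLIP-space target used to train each learned
# offset: the un-augmented image embedding is translated along the normalized text
# direction (target-domain prompt minus base-domain prompt), and the offset added to
# the res4 features is optimized so that the augmented image embedding approaches that
# translated point, while an L1 term keeps it close to the original embedding.
# A minimal self-contained sketch of the same arithmetic on dummy, already-normalized
# embeddings (tensor names and shapes here are illustrative, not the repo's API):

import torch
import torch.nn.functional as F

B, D = 4, 1024                                             # crops per batch, CLIP embedding dim
img_embed = F.normalize(torch.randn(B, D), dim=-1)         # un-augmented image embeddings
base_text = F.normalize(torch.randn(1, D), dim=-1)         # e.g. the 'real' / 'day' prompt embedding
new_text = F.normalize(torch.randn(1, D), dim=-1)          # e.g. a target-style prompt embedding

text_dir = F.normalize(new_text - base_text, dim=-1)       # domain-shift direction in CLIP space
target = F.normalize(img_embed + text_dir, dim=-1)         # translated target embedding

aug_embed = F.normalize(torch.randn(B, D), dim=-1)         # embeddings of the offset-augmented features
cos_dist = 1 - (aug_embed * target).sum(dim=-1)            # cosine distance per crop
loss = cos_dist.mean() + F.l1_loss(aug_embed, img_embed)   # distance term plus regularizer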
402 | day_text_embed = day_text_embed/day_text_embed.norm(dim=-1,keepdim=True) 403 | new_text_embed = self.roi_heads.box_predictor.cls_score.model.encode_text(dtk.cuda() ) #new_d 404 | new_text_embed = new_text_embed/new_text_embed.norm(dim=-1,keepdim=True) 405 | text_off = (new_text_embed - day_text_embed) 406 | text_off = text_off/text_off.norm(dim=-1,keepdim=True) 407 | 408 | wo_aug_im_tsl = wo_aug_im_embed + text_off 409 | wo_aug_im_tsl = wo_aug_im_tsl/wo_aug_im_tsl.norm(dim=-1,keepdim=True) 410 | wo_aug_im_tsl = wo_aug_im_tsl.unsqueeze(1).permute(0,2,1) 411 | 412 | 413 | aug_feat = features['res4'].detach()+self.offsets[i-1:i] 414 | 415 | 416 | x = self.backbone.forward_res5(aug_feat) 417 | im_embed = self.backbone.attention_global_pool(x) 418 | 419 | im_embed = im_embed/im_embed.norm(dim=-1,keepdim=True) 420 | 421 | cos_dist = 1 - im_embed.unsqueeze(1).bmm(wo_aug_im_tsl) 422 | 423 | dist_loss = cos_dist.mean() 424 | 425 | l1loss = torch.nn.functional.l1_loss(im_embed,wo_aug_im_embed) 426 | 427 | 428 | total_dist += dist_loss 429 | total_reg += l1loss 430 | 431 | losses.update({ f'cos_dist_loss_{name}': total_dist/len(self.domain_tk),f'reg_loss_{name}': total_reg/len(self.domain_tk)}) 432 | import pdb;pdb.set_trace() 433 | return losses 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from cgi import parse_multipart 2 | import os 3 | import logging 4 | import time 5 | from collections import OrderedDict, Counter 6 | import copy 7 | 8 | import numpy as np 9 | 10 | import torch 11 | from torch import autograd 12 | import torch.utils.data as torchdata 13 | 14 | from detectron2 import model_zoo 15 | from detectron2.config import get_cfg 16 | from detectron2.layers.batch_norm import FrozenBatchNorm2d 17 | from detectron2.engine import DefaultPredictor, DefaultTrainer, default_setup 18 | from detectron2.engine import default_argument_parser, hooks, HookBase 19 | from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping, build_lr_scheduler 20 | from detectron2.checkpoint import DetectionCheckpointer 21 | from detectron2.data import build_detection_train_loader, build_detection_test_loader, get_detection_dataset_dicts 22 | from detectron2.data.common import DatasetFromList, MapDataset 23 | from detectron2.data.samplers import InferenceSampler 24 | from detectron2.utils.events import get_event_storage 25 | 26 | from detectron2.utils import comm 27 | from detectron2.evaluation import COCOEvaluator, verify_results, inference_on_dataset, print_csv_format 28 | 29 | from detectron2.solver import LRMultiplier 30 | from detectron2.modeling import build_model 31 | from detectron2.structures import ImageList, Instances, pairwise_iou, Boxes 32 | 33 | from fvcore.common.param_scheduler import ParamScheduler 34 | from fvcore.common.checkpoint import Checkpointer 35 | 36 | from data.datasets import builtin 37 | 38 | from detectron2.evaluation import PascalVOCDetectionEvaluator, COCOEvaluator, inference_on_dataset 39 | 40 | from detectron2.data import build_detection_train_loader, MetadataCatalog 41 | import torch.utils.data as data 42 | from detectron2.data.dataset_mapper import DatasetMapper 43 | import detectron2.data.detection_utils as utils 44 | import detectron2.data.transforms as detT 45 | 46 | import torchvision.transforms as T 47 | import torchvision.transforms.functional as tF 48 | 49 | 
from modeling import add_stn_config 50 | from modeling import CustomPascalVOCDetectionEvaluator 51 | 52 | logger = logging.getLogger("detectron2") 53 | 54 | def setup(args): 55 | cfg = get_cfg() 56 | add_stn_config(cfg) 57 | #hack to add base yaml 58 | cfg.merge_from_file(args.config_file) 59 | cfg.merge_from_file(model_zoo.get_config_file(cfg.BASE_YAML)) 60 | cfg.merge_from_file(args.config_file) 61 | cfg.merge_from_list(args.opts) 62 | #cfg.freeze() 63 | default_setup(cfg, args) 64 | return cfg 65 | 66 | class CustomDatasetMapper(DatasetMapper): 67 | def __init__(self,cfg,is_train) -> None: 68 | super().__init__(cfg,is_train) 69 | self.with_crops = cfg.INPUT.CLIP_WITH_IMG 70 | self.with_random_clip_crops = cfg.INPUT.CLIP_RANDOM_CROPS 71 | self.with_jitter = cfg.INPUT.IMAGE_JITTER 72 | self.cropfn = T.RandomCrop#T.RandomCrop([224,224]) 73 | self.aug = T.ColorJitter(brightness=.5, hue=.3) 74 | self.crop_size = cfg.INPUT.RANDOM_CROP_SIZE 75 | 76 | def __call__(self,dataset_dict): 77 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 78 | # USER: Write your own image loading if it's not from a file 79 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format) 80 | utils.check_image_size(dataset_dict, image) 81 | 82 | # USER: Remove if you don't do semantic/panoptic segmentation. 83 | if "sem_seg_file_name" in dataset_dict: 84 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 85 | else: 86 | sem_seg_gt = None 87 | 88 | aug_input = detT.AugInput(image, sem_seg=sem_seg_gt) 89 | transforms = self.augmentations(aug_input) 90 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 91 | 92 | image_shape = image.shape[:2] # h, w 93 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 94 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 95 | # Therefore it's important to use torch.Tensor. 96 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 97 | if sem_seg_gt is not None: 98 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 99 | 100 | # USER: Remove if you don't use pre-computed proposals. 101 | # Most users would not need this feature. 102 | if self.proposal_topk is not None: 103 | utils.transform_proposals( 104 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk 105 | ) 106 | 107 | if not self.is_train: 108 | # USER: Modify this if you want to keep them for some reason. 
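# Further below (training only), this mapper also attaches CLIP-oriented tensors to
# dataset_dict: object-centered square crops under "crops" (already CLIP-normalized),
# fifteen random crops under "randomcrops" (normalized later, inside opt_offsets), and,
# when IMAGE_JITTER is on, color-jittered counterparts. In all cases the CLIP inputs end
# up going through the same transform; a minimal sketch of it on a dummy detectron2-style
# image tensor (CHW, BGR, uint8 in [0, 255]), using the CLIP normalization constants that
# appear in the code below (variable names are illustrative only):

import torch
import torchvision.transforms.functional as tF

clip_mean = [0.48145466, 0.4578275, 0.40821073]
clip_std = [0.26862954, 0.26130258, 0.27577711]

dummy_bgr = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)  # BGR image as detectron2 loads it
rgb = dummy_bgr.flip(0).float() / 255.0                              # BGR -> RGB, scale to [0, 1]
clip_ready = tF.normalize(rgb, clip_mean, clip_std)                  # input expected by the CLIP backbone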
109 | dataset_dict.pop("annotations", None) 110 | dataset_dict.pop("sem_seg_file_name", None) 111 | return dataset_dict 112 | 113 | if "annotations" in dataset_dict: 114 | self._transform_annotations(dataset_dict, transforms, image_shape) 115 | 116 | if self.with_jitter: 117 | dataset_dict["jitter_image"] = self.aug(dataset_dict["image"]) 118 | 119 | if self.with_crops: 120 | bbox = dataset_dict['instances'].gt_boxes.tensor 121 | csx = (bbox[:,0] + bbox[:,2])*0.5 122 | csy = (bbox[:,1] + bbox[:,3])*0.5 123 | maxwh = torch.maximum(bbox[:,2]-bbox[:,0],bbox[:,3]-bbox[:,1]) 124 | crops = list() 125 | gt_boxes = list() 126 | mean=[0.48145466, 0.4578275, 0.40821073] 127 | std=[0.26862954, 0.26130258, 0.27577711] 128 | for cx,cy,maxdim,label,box in zip(csx,csy,maxwh,dataset_dict['instances'].gt_classes, bbox): 129 | 130 | if int(maxdim) < 10: 131 | continue 132 | x0 = torch.maximum(cx-maxdim*0.5,torch.tensor(0)) 133 | y0 = torch.maximum(cy-maxdim*0.5,torch.tensor(0)) 134 | try: 135 | imcrop = T.functional.resized_crop(dataset_dict['image'],top=int(y0),left=int(x0),height=int(maxdim),width=int(maxdim),size=224) 136 | imcrop = imcrop.flip(0)/255 # bgr --> rgb for clip 137 | imcrop = T.functional.normalize(imcrop,mean,std) 138 | # print(x0,y0,x0+maxdim,y0+maxdim,dataset_dict['image'].shape) 139 | # print(imcrop.min(),imcrop.max() ) 140 | gt_boxes.append(box.reshape(1,-1)) 141 | except Exception as e: 142 | print(e) 143 | print('crops:',x0,y0,maxdim) 144 | exit() 145 | # crops.append((imcrop,label)) 146 | crops.append(imcrop.unsqueeze(0)) 147 | 148 | if len(crops) == 0: 149 | dataset_dict['crops'] = [] 150 | else: 151 | dataset_dict['crops'] = [torch.cat(crops,0),Boxes(torch.cat(gt_boxes,0))] 152 | 153 | if self.with_random_clip_crops: 154 | crops = [] 155 | rbboxs = [] 156 | 157 | for i in range(15): 158 | p = self.cropfn.get_params(dataset_dict['image'],[self.crop_size,self.crop_size]) 159 | c = tF.crop(dataset_dict['image'],*p) 160 | if self.crop_size != 224: 161 | c = tF.resize(img=c,size=224) 162 | crops.append(c) 163 | rbboxs.append(p) 164 | 165 | crops = torch.stack(crops) 166 | dataset_dict['randomcrops'] = crops 167 | 168 | #apply same crop bbox to the jittered image 169 | if self.with_jitter: 170 | jitter_crops = [] 171 | for p in rbboxs: 172 | jc = tF.crop(dataset_dict['jitter_image'],*p) 173 | if self.crop_size != 224: 174 | jc = tF.resize(img=jc,size=224) 175 | jitter_crops.append(jc) 176 | 177 | jcrops = torch.stack(jitter_crops) 178 | dataset_dict['jitter_randomcrops'] = jcrops 179 | 180 | 181 | 182 | return dataset_dict 183 | 184 | class CombineLoaders(data.IterableDataset): 185 | def __init__(self,loaders): 186 | self.loaders = loaders 187 | 188 | def __iter__(self,): 189 | dd = iter(self.loaders[1]) 190 | for d1 in self.loaders[0]: 191 | try: 192 | d2 = next(dd) 193 | except: 194 | dd=iter(self.loaders[1]) 195 | d2 = next(dd) 196 | 197 | list_out_dict=[] 198 | for v1,v2 in zip(d1,d2): 199 | out_dict = {} 200 | for k in v1.keys(): 201 | out_dict[k] = (v1[k],v2[k]) 202 | list_out_dict.append(out_dict) 203 | 204 | yield list_out_dict 205 | 206 | 207 | class Trainer(DefaultTrainer): 208 | 209 | def __init__(self,cfg) -> None: 210 | super().__init__(cfg) 211 | self.teach_model = None 212 | self.off_opt_interval = np.arange(0,cfg.SOLVER.MAX_ITER,cfg.OFFSET_OPT_INTERVAL[0]).tolist() 213 | self.off_opt_iters = cfg.OFFSET_OPT_ITERS 214 | 215 | @classmethod 216 | def build_model(cls, cfg): 217 | """ 218 | Returns: 219 | torch.nn.Module: 220 | It now calls 
:func:`detectron2.modeling.build_model`. 221 | Overwrite it if you'd like a different model. 222 | """ 223 | model = build_model(cfg) 224 | 225 | 226 | logger = logging.getLogger(__name__) 227 | logger.info("Model:\n{}".format(model)) 228 | 229 | return model 230 | 231 | @classmethod 232 | def build_train_loader(cls,cfg): 233 | original = cfg.DATASETS.TRAIN 234 | print(original) 235 | cfg.DATASETS.TRAIN=(original[0],) 236 | data_loader1 = build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True)) 237 | return data_loader1 238 | 239 | @classmethod 240 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 241 | if output_folder is None: 242 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 243 | if MetadataCatalog.get(dataset_name).evaluator_type == 'pascal_voc': 244 | return CustomPascalVOCDetectionEvaluator(dataset_name) 245 | else: 246 | return COCOEvaluator(dataset_name, output_dir=output_folder) 247 | 248 | @classmethod 249 | def build_optimizer(cls,cfg,model): 250 | 251 | trainable = {'others':[],'offset':[]} 252 | 253 | for name,val in model.named_parameters(): 254 | head = name.split('.')[0] 255 | #previously was setting all params to be true 256 | if val.requires_grad == True: 257 | print(name) 258 | if 'offset' in name: 259 | trainable['offset'].append(val) 260 | else: 261 | trainable['others'].append(val) 262 | 263 | optimizer1 = torch.optim.SGD( 264 | trainable['others'], 265 | cfg.SOLVER.BASE_LR, 266 | momentum=cfg.SOLVER.MOMENTUM, 267 | nesterov=cfg.SOLVER.NESTEROV, 268 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 269 | ) 270 | 271 | optimizer2 = torch.optim.Adam( 272 | trainable['offset'], 273 | 0.01, 274 | ) 275 | return (maybe_add_gradient_clipping(cfg, optimizer1),maybe_add_gradient_clipping(cfg, optimizer2)) 276 | 277 | 278 | def run_step(self): 279 | """ 280 | Implement the standard training logic described above. 281 | """ 282 | 283 | 284 | assert self.model.training, "[SimpleTrainer] model was changed to eval mode!" 285 | start = time.perf_counter() 286 | """ 287 | If you want to do something with the data, you can wrap the dataloader. 288 | """ 289 | data = next(self._trainer._data_loader_iter) 290 | data_time = time.perf_counter() - start 291 | 292 | """ 293 | If you want to do something with the losses, you can wrap the model. 294 | """ 295 | data_s = data 296 | 297 | opt_phase = False 298 | if len(self.off_opt_interval) and self.iter >= self.off_opt_interval[0] and self.iter < self.off_opt_interval[0]+self.off_opt_iters: 299 | 300 | if self.iter == self.off_opt_interval[0]: 301 | self.model.offsets.data = torch.zeros(self.model.offsets.shape).cuda() 302 | loss_dict_s = self.model.opt_offsets(data_s) 303 | opt_phase = True 304 | if self.iter+1 == self.off_opt_interval[0]+self.off_opt_iters: 305 | self.off_opt_interval.pop(0) 306 | 307 | else: 308 | # for ind, d in enumerate(data_s): 309 | # d['image'] = self.aug(d['image'].cuda()) 310 | loss_dict_s = self.model(data_s) 311 | # print(loss_dict_s) 312 | 313 | # import pdb;pdb.set_trace() 314 | loss_dict = {} 315 | 316 | loss = 0 317 | for k,v in loss_dict_s.items(): 318 | loss += v 319 | 320 | 321 | """ 322 | If you need to accumulate gradients or do something similar, you can 323 | wrap the optimizer with your custom `zero_grad()` method. 
324 | """ 325 | self.optimizer[0].zero_grad() 326 | self.optimizer[1].zero_grad() 327 | 328 | loss.backward() 329 | 330 | if not opt_phase: 331 | self.optimizer[0].step() 332 | else: 333 | self.optimizer[1].step() 334 | 335 | self.optimizer[0].zero_grad() 336 | self.optimizer[1].zero_grad() 337 | 338 | for k,v in loss_dict_s.items(): 339 | loss_dict.update({k:v}) 340 | 341 | # print(loss_di ct) 342 | self._trainer._write_metrics(loss_dict, data_time) 343 | """ 344 | If you need gradient clipping/scaling or other processing, you can 345 | wrap the optimizer with your custom `step()` method. But it is 346 | suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4 347 | """ 348 | 349 | def build_hooks(self): 350 | """ 351 | Build a list of default hooks, including timing, evaluation, 352 | checkpointing, lr scheduling, precise BN, writing events. 353 | Returns: 354 | list[HookBase]: 355 | """ 356 | cfg = self.cfg.clone() 357 | cfg.defrost() 358 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN 359 | 360 | ret = [ 361 | hooks.IterationTimer(), 362 | LRScheduler(), 363 | hooks.PreciseBN( 364 | # Run at the same freq as (but before) evaluation. 365 | cfg.TEST.EVAL_PERIOD, 366 | self.model, 367 | # Build a new data loader to not affect training 368 | self.build_train_loader(cfg), 369 | cfg.TEST.PRECISE_BN.NUM_ITER, 370 | ) 371 | if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) 372 | else None, 373 | ] 374 | 375 | # Do PreciseBN before checkpointer, because it updates the model and need to 376 | # be saved by checkpointer. 377 | # This is not always the best: if checkpointing has a different frequency, 378 | # some checkpoints may have more precise statistics than others. 379 | if comm.is_main_process(): 380 | ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) 381 | 382 | def test_and_save_results(): 383 | self._last_eval_results = self.test(self.cfg, self.model) 384 | return self._last_eval_results 385 | 386 | def do_test_st(flag): 387 | if flag == 'st': 388 | model = self.model 389 | else: 390 | print("Error in the flag") 391 | 392 | results = OrderedDict() 393 | for dataset_name in self.cfg.DATASETS.TEST: 394 | data_loader = build_detection_test_loader(self.cfg, dataset_name) 395 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name) 396 | results_i = inference_on_dataset(model, data_loader, evaluator) 397 | results[dataset_name] = results_i 398 | if comm.is_main_process(): 399 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 400 | print_csv_format(results_i) 401 | storage = get_event_storage() 402 | storage.put_scalar(f'{dataset_name}_AP50', results_i['bbox']['AP50'],smoothing_hint=False) 403 | if len(results) == 1: 404 | results = list(results.values())[0] 405 | return results 406 | 407 | 408 | # Do evaluation after checkpointer, because then if it fails, 409 | # we can use the saved checkpoint to debug. 410 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) 411 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_SAVE_PERIOD, lambda flag='st': do_test_st(flag))) 412 | 413 | if comm.is_main_process(): 414 | # Here the default print/log frequency of each writer is used. 
415 | # run writers in the end, so that evaluation metrics are written 416 | ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) 417 | return ret 418 | 419 | @classmethod 420 | def build_lr_scheduler(cls, cfg, optimizer): 421 | """ 422 | It now calls :func:`detectron2.solver.build_lr_scheduler`. 423 | Overwrite it if you'd like a different scheduler. 424 | """ 425 | 426 | return build_lr_scheduler(cfg, optimizer[0]) 427 | 428 | def state_dict(self): 429 | ret = super().state_dict() 430 | ret["optimizer1"] = self.optimizer[0].state_dict() 431 | ret["optimizer2"] = self.optimizer[1].state_dict() 432 | return ret 433 | 434 | def load_state_dict(self, state_dict): 435 | super().load_state_dict(state_dict) 436 | self.optimizer[0].load_state_dict(state_dict["optimizer1"]) 437 | self.optimizer[1].load_state_dict(state_dict["optimizer2"]) 438 | 439 | 440 | 441 | class LRScheduler(HookBase): 442 | """ 443 | A hook which executes a torch builtin LR scheduler and summarizes the LR. 444 | It is executed after every iteration. 445 | """ 446 | 447 | def __init__(self, optimizer=None, scheduler=None): 448 | """ 449 | Args: 450 | optimizer (torch.optim.Optimizer): 451 | scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler): 452 | if a :class:`ParamScheduler` object, it defines the multiplier over the base LR 453 | in the optimizer. 454 | If any argument is not given, will try to obtain it from the trainer. 455 | """ 456 | self._optimizer = optimizer 457 | self._scheduler = scheduler 458 | 459 | def before_train(self): 460 | self._optimizer = self._optimizer or self.trainer.optimizer 461 | if isinstance(self.scheduler, ParamScheduler): 462 | self._scheduler = LRMultiplier( 463 | self._optimizer, 464 | self.scheduler, 465 | self.trainer.max_iter, 466 | last_iter=self.trainer.iter - 1, 467 | ) 468 | self._best_param_group_id1 = LRScheduler.get_best_param_group_id(self._optimizer[0]) 469 | self._best_param_group_id2 = LRScheduler.get_best_param_group_id(self._optimizer[1]) 470 | 471 | 472 | @staticmethod 473 | def get_best_param_group_id(optimizer): 474 | # NOTE: some heuristics on what LR to summarize 475 | # summarize the param group with most parameters 476 | largest_group = max(len(g["params"]) for g in optimizer.param_groups) 477 | 478 | if largest_group == 1: 479 | # If all groups have one parameter, 480 | # then find the most common initial LR, and use it for summary 481 | lr_count = Counter([g["lr"] for g in optimizer.param_groups]) 482 | lr = lr_count.most_common()[0][0] 483 | for i, g in enumerate(optimizer.param_groups): 484 | if g["lr"] == lr: 485 | return i 486 | else: 487 | for i, g in enumerate(optimizer.param_groups): 488 | if len(g["params"]) == largest_group: 489 | return i 490 | 491 | def after_step(self): 492 | lr1 = self._optimizer[0].param_groups[self._best_param_group_id1]["lr"] 493 | self.trainer.storage.put_scalar("lr1", lr1, smoothing_hint=False) 494 | 495 | lr2 = self._optimizer[1].param_groups[self._best_param_group_id2]["lr"] 496 | self.trainer.storage.put_scalar("lr2", lr2, smoothing_hint=False) 497 | 498 | self.scheduler.step() 499 | 500 | @property 501 | def scheduler(self): 502 | return self._scheduler or self.trainer.scheduler 503 | 504 | def state_dict(self): 505 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler): 506 | return self.scheduler.state_dict() 507 | return {} 508 | 509 | def load_state_dict(self, state_dict): 510 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler): 511 | logger = 
logging.getLogger(__name__) 512 | logger.info("Loading scheduler from state_dict ...") 513 | self.scheduler.load_state_dict(state_dict) 514 | 515 | def custom_build_detection_test_loader(cfg,dataset_name,mapper=None): 516 | 517 | if isinstance(dataset_name, str): 518 | dataset_name = [dataset_name] 519 | 520 | dataset = get_detection_dataset_dicts( 521 | dataset_name, 522 | filter_empty=False, 523 | proposal_files=[ 524 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name 525 | ] 526 | if cfg.MODEL.LOAD_PROPOSALS 527 | else None, 528 | ) 529 | if mapper is None: 530 | mapper = DatasetMapper(cfg, False) 531 | 532 | if isinstance(dataset, list): 533 | dataset = DatasetFromList(dataset, copy=False) 534 | if mapper is not None: 535 | dataset = MapDataset(dataset, mapper) 536 | 537 | sampler = None 538 | if isinstance(dataset, torchdata.IterableDataset): 539 | assert sampler is None, "sampler must be None if dataset is IterableDataset" 540 | else: 541 | if sampler is None: 542 | sampler = InferenceSampler(len(dataset)) 543 | collate_fn = None 544 | 545 | def trivial_batch_collator(batch): 546 | """ 547 | A batch collator that does nothing. 548 | """ 549 | return batch 550 | 551 | return torchdata.DataLoader( 552 | dataset, 553 | batch_size=1, 554 | sampler=sampler, 555 | drop_last=False, 556 | num_workers=cfg.DATALOADER.NUM_WORKERS, 557 | collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, 558 | ) 559 | 560 | 561 | def do_test(cfg, model, model_type=''): 562 | results = OrderedDict() 563 | for dataset_name in cfg.DATASETS.TEST: 564 | data_loader = build_detection_test_loader(cfg, dataset_name) 565 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name) 566 | results_i = inference_on_dataset(model, data_loader, evaluator) 567 | results[dataset_name] = results_i 568 | if comm.is_main_process(): 569 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 570 | print_csv_format(results_i) 571 | 572 | if len(results) == 1: 573 | results = list(results.values())[0] 574 | return results 575 | 576 | def main(args): 577 | cfg = setup(args) 578 | if args.eval_only: 579 | model = Trainer.build_model(cfg) 580 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 581 | cfg.MODEL.WEIGHTS, resume=args.resume 582 | ) 583 | return do_test(cfg,model) 584 | trainer = Trainer(cfg) 585 | trainer.resume_or_load(resume=args.resume) 586 | for dataset_name in cfg.DATASETS.TEST: 587 | if 'daytime_clear_test' in dataset_name : 588 | trainer.register_hooks([ 589 | hooks.BestCheckpointer(cfg.TEST.EVAL_SAVE_PERIOD,trainer.checkpointer,f'{dataset_name}_AP50',file_prefix='model_best'), 590 | ]) 591 | 592 | trainer.train() 593 | 594 | 595 | if __name__ == "__main__": 596 | args = default_argument_parser().parse_args() 597 | cfg = setup(args) 598 | print("Command Line Args:", args) 599 | 600 | main(args) 601 | -------------------------------------------------------------------------------- /train_voc.py: -------------------------------------------------------------------------------- 1 | from cgi import parse_multipart 2 | import os 3 | import logging 4 | import time 5 | from collections import OrderedDict, Counter 6 | import copy 7 | 8 | import numpy as np 9 | 10 | import torch 11 | from torch import autograd 12 | import torch.utils.data as torchdata 13 | 14 | from detectron2 import model_zoo 15 | from detectron2.config import get_cfg 16 | from detectron2.layers.batch_norm import FrozenBatchNorm2d 17 | from 
detectron2.engine import DefaultPredictor, DefaultTrainer, default_setup 18 | from detectron2.engine import default_argument_parser, hooks, HookBase 19 | from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping, build_lr_scheduler 20 | from detectron2.checkpoint import DetectionCheckpointer 21 | from detectron2.data import build_detection_train_loader, build_detection_test_loader, get_detection_dataset_dicts 22 | from detectron2.data.common import DatasetFromList, MapDataset 23 | from detectron2.data.samplers import InferenceSampler 24 | from detectron2.utils.events import get_event_storage 25 | 26 | from detectron2.utils import comm 27 | from detectron2.evaluation import COCOEvaluator, verify_results, inference_on_dataset, print_csv_format 28 | 29 | from detectron2.solver import LRMultiplier 30 | from detectron2.modeling import build_model 31 | from detectron2.structures import ImageList, Instances, pairwise_iou, Boxes 32 | 33 | from fvcore.common.param_scheduler import ParamScheduler 34 | from fvcore.common.checkpoint import Checkpointer 35 | 36 | from data.datasets import builtin 37 | 38 | from detectron2.evaluation import PascalVOCDetectionEvaluator, COCOEvaluator, inference_on_dataset 39 | 40 | from detectron2.data import build_detection_train_loader, MetadataCatalog 41 | import torch.utils.data as data 42 | from detectron2.data.dataset_mapper import DatasetMapper 43 | import detectron2.data.detection_utils as utils 44 | import detectron2.data.transforms as detT 45 | 46 | import torchvision.transforms as T 47 | import torchvision.transforms.functional as tF 48 | 49 | from modeling import add_stn_config 50 | from modeling import CustomPascalVOCDetectionEvaluator 51 | 52 | logger = logging.getLogger("detectron2") 53 | 54 | def setup(args): 55 | cfg = get_cfg() 56 | add_stn_config(cfg) 57 | #hack to add base yaml 58 | cfg.merge_from_file(args.config_file) 59 | cfg.merge_from_file(model_zoo.get_config_file(cfg.BASE_YAML)) 60 | cfg.merge_from_file(args.config_file) 61 | cfg.merge_from_list(args.opts) 62 | #cfg.freeze() 63 | default_setup(cfg, args) 64 | return cfg 65 | 66 | class CustomDatasetMapper(DatasetMapper): 67 | def __init__(self,cfg,is_train) -> None: 68 | super().__init__(cfg,is_train) 69 | self.with_crops = cfg.INPUT.CLIP_WITH_IMG 70 | self.with_random_clip_crops = cfg.INPUT.CLIP_RANDOM_CROPS 71 | self.with_jitter = cfg.INPUT.IMAGE_JITTER 72 | self.cropfn = T.RandomCrop#T.RandomCrop([224,224]) 73 | self.aug = T.ColorJitter(brightness=.5, hue=.3) 74 | self.crop_size = cfg.INPUT.RANDOM_CROP_SIZE 75 | 76 | def __call__(self,dataset_dict): 77 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 78 | # USER: Write your own image loading if it's not from a file 79 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format) 80 | utils.check_image_size(dataset_dict, image) 81 | 82 | # USER: Remove if you don't do semantic/panoptic segmentation. 83 | if "sem_seg_file_name" in dataset_dict: 84 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 85 | else: 86 | sem_seg_gt = None 87 | 88 | aug_input = detT.AugInput(image, sem_seg=sem_seg_gt) 89 | transforms = self.augmentations(aug_input) 90 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 91 | 92 | image_shape = image.shape[:2] # h, w 93 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 94 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 
95 | # Therefore it's important to use torch.Tensor. 96 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 97 | if sem_seg_gt is not None: 98 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 99 | 100 | # USER: Remove if you don't use pre-computed proposals. 101 | # Most users would not need this feature. 102 | if self.proposal_topk is not None: 103 | utils.transform_proposals( 104 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk 105 | ) 106 | 107 | if not self.is_train: 108 | # USER: Modify this if you want to keep them for some reason. 109 | dataset_dict.pop("annotations", None) 110 | dataset_dict.pop("sem_seg_file_name", None) 111 | return dataset_dict 112 | 113 | if "annotations" in dataset_dict: 114 | self._transform_annotations(dataset_dict, transforms, image_shape) 115 | 116 | if self.with_jitter: 117 | dataset_dict["jitter_image"] = self.aug(dataset_dict["image"]) 118 | 119 | if self.with_crops: 120 | bbox = dataset_dict['instances'].gt_boxes.tensor 121 | csx = (bbox[:,0] + bbox[:,2])*0.5 122 | csy = (bbox[:,1] + bbox[:,3])*0.5 123 | maxwh = torch.maximum(bbox[:,2]-bbox[:,0],bbox[:,3]-bbox[:,1]) 124 | crops = list() 125 | gt_boxes = list() 126 | mean=[0.48145466, 0.4578275, 0.40821073] 127 | std=[0.26862954, 0.26130258, 0.27577711] 128 | for cx,cy,maxdim,label,box in zip(csx,csy,maxwh,dataset_dict['instances'].gt_classes, bbox): 129 | 130 | if int(maxdim) < 10: 131 | continue 132 | x0 = torch.maximum(cx-maxdim*0.5,torch.tensor(0)) 133 | y0 = torch.maximum(cy-maxdim*0.5,torch.tensor(0)) 134 | try: 135 | imcrop = T.functional.resized_crop(dataset_dict['image'],top=int(y0),left=int(x0),height=int(maxdim),width=int(maxdim),size=224) 136 | imcrop = imcrop.flip(0)/255 # bgr --> rgb for clip 137 | imcrop = T.functional.normalize(imcrop,mean,std) 138 | # print(x0,y0,x0+maxdim,y0+maxdim,dataset_dict['image'].shape) 139 | # print(imcrop.min(),imcrop.max() ) 140 | gt_boxes.append(box.reshape(1,-1)) 141 | except Exception as e: 142 | print(e) 143 | print('crops:',x0,y0,maxdim) 144 | exit() 145 | # crops.append((imcrop,label)) 146 | crops.append(imcrop.unsqueeze(0)) 147 | 148 | if len(crops) == 0: 149 | dataset_dict['crops'] = [] 150 | else: 151 | dataset_dict['crops'] = [torch.cat(crops,0),Boxes(torch.cat(gt_boxes,0))] 152 | 153 | if self.with_random_clip_crops: 154 | crops = [] 155 | rbboxs = [] 156 | 157 | for i in range(15): 158 | minsize = min(dataset_dict['image'].shape[1],dataset_dict['image'].shape[2]) 159 | p = self.cropfn.get_params(dataset_dict['image'],[min(self.crop_size,minsize),min(self.crop_size,minsize)]) 160 | c = tF.crop(dataset_dict['image'],*p) 161 | if self.crop_size != 224: 162 | c = tF.resize(img=c,size=224) 163 | crops.append(c) 164 | rbboxs.append(p) 165 | 166 | crops = torch.stack(crops) 167 | dataset_dict['randomcrops'] = crops 168 | 169 | #apply same crop bbox to the jittered image 170 | if self.with_jitter: 171 | jitter_crops = [] 172 | for p in rbboxs: 173 | jc = tF.crop(dataset_dict['jitter_image'],*p) 174 | if self.crop_size != 224: 175 | jc = tF.resize(img=jc,size=224) 176 | jitter_crops.append(jc) 177 | 178 | jcrops = torch.stack(jitter_crops) 179 | dataset_dict['jitter_randomcrops'] = jcrops 180 | 181 | return dataset_dict 182 | 183 | class CombineLoaders(data.IterableDataset): 184 | def __init__(self,loaders): 185 | self.loaders = loaders 186 | 187 | def __iter__(self,): 188 | dd = iter(self.loaders[1]) 189 | for d1 in self.loaders[0]: 190 | try: 191 | d2 = next(dd) 192 | 
except: 193 | dd=iter(self.loaders[1]) 194 | d2 = next(dd) 195 | 196 | list_out_dict=[] 197 | for v1,v2 in zip(d1,d2): 198 | out_dict = {} 199 | for k in v1.keys(): 200 | out_dict[k] = (v1[k],v2[k]) 201 | list_out_dict.append(out_dict) 202 | 203 | yield list_out_dict 204 | 205 | 206 | class Trainer(DefaultTrainer): 207 | 208 | def __init__(self,cfg) -> None: 209 | super().__init__(cfg) 210 | self.teach_model = None 211 | self.off_opt_interval = np.arange(0,cfg.SOLVER.MAX_ITER,cfg.OFFSET_OPT_INTERVAL[0]).tolist() 212 | self.off_opt_iters = cfg.OFFSET_OPT_ITERS 213 | 214 | @classmethod 215 | def build_model(cls, cfg): 216 | """ 217 | Returns: 218 | torch.nn.Module: 219 | It now calls :func:`detectron2.modeling.build_model`. 220 | Overwrite it if you'd like a different model. 221 | """ 222 | model = build_model(cfg) 223 | 224 | 225 | logger = logging.getLogger(__name__) 226 | logger.info("Model:\n{}".format(model)) 227 | 228 | return model 229 | 230 | @classmethod 231 | def build_train_loader(cls,cfg): 232 | original = cfg.DATASETS.TRAIN 233 | print(original) 234 | # cfg.DATASETS.TRAIN=(original[0],) 235 | data_loader1 = build_detection_train_loader(cfg, mapper=CustomDatasetMapper(cfg, True)) 236 | return data_loader1 237 | 238 | @classmethod 239 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 240 | if output_folder is None: 241 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 242 | if MetadataCatalog.get(dataset_name).evaluator_type == 'pascal_voc': 243 | return CustomPascalVOCDetectionEvaluator(dataset_name) 244 | else: 245 | return COCOEvaluator(dataset_name, output_dir=output_folder) 246 | 247 | @classmethod 248 | def build_optimizer(cls,cfg,model): 249 | 250 | trainable = {'others':[],'offset':[]} 251 | 252 | for name,val in model.named_parameters(): 253 | head = name.split('.')[0] 254 | #previously was setting all params to be true 255 | if val.requires_grad == True: 256 | print(name) 257 | if 'offset' in name: 258 | trainable['offset'].append(val) 259 | else: 260 | trainable['others'].append(val) 261 | 262 | optimizer1 = torch.optim.SGD( 263 | trainable['others'], 264 | cfg.SOLVER.BASE_LR, 265 | momentum=cfg.SOLVER.MOMENTUM, 266 | nesterov=cfg.SOLVER.NESTEROV, 267 | weight_decay=cfg.SOLVER.WEIGHT_DECAY, 268 | ) 269 | 270 | optimizer2 = torch.optim.Adam( 271 | trainable['offset'], 272 | 0.01, 273 | ) 274 | return (maybe_add_gradient_clipping(cfg, optimizer1),maybe_add_gradient_clipping(cfg, optimizer2)) 275 | 276 | 277 | def run_step(self): 278 | """ 279 | Implement the standard training logic described above. 280 | """ 281 | 282 | 283 | assert self.model.training, "[SimpleTrainer] model was changed to eval mode!" 284 | start = time.perf_counter() 285 | """ 286 | If you want to do something with the data, you can wrap the dataloader. 287 | """ 288 | data = next(self._trainer._data_loader_iter) 289 | data_time = time.perf_counter() - start 290 | 291 | """ 292 | If you want to do something with the losses, you can wrap the model. 
293 | """ 294 | data_s = data 295 | 296 | opt_phase = False 297 | if len(self.off_opt_interval) and self.iter >= self.off_opt_interval[0] and self.iter < self.off_opt_interval[0]+self.off_opt_iters: 298 | 299 | if self.iter == self.off_opt_interval[0]: 300 | self.model.offsets.data = torch.zeros(self.model.offsets.shape).cuda() 301 | loss_dict_s = self.model.opt_offsets(data_s) 302 | opt_phase = True 303 | if self.iter+1 == self.off_opt_interval[0]+self.off_opt_iters: 304 | self.off_opt_interval.pop(0) 305 | 306 | else: 307 | # for ind, d in enumerate(data_s): 308 | # d['image'] = self.aug(d['image'].cuda()) 309 | loss_dict_s = self.model(data_s) 310 | # print(loss_dict_s) 311 | 312 | # import pdb;pdb.set_trace() 313 | loss_dict = {} 314 | 315 | loss = 0 316 | for k,v in loss_dict_s.items(): 317 | loss += v 318 | 319 | 320 | """ 321 | If you need to accumulate gradients or do something similar, you can 322 | wrap the optimizer with your custom `zero_grad()` method. 323 | """ 324 | self.optimizer[0].zero_grad() 325 | self.optimizer[1].zero_grad() 326 | 327 | loss.backward() 328 | 329 | if not opt_phase: 330 | self.optimizer[0].step() 331 | else: 332 | self.optimizer[1].step() 333 | 334 | self.optimizer[0].zero_grad() 335 | self.optimizer[1].zero_grad() 336 | 337 | for k,v in loss_dict_s.items(): 338 | loss_dict.update({k:v}) 339 | 340 | # print(loss_di ct) 341 | self._trainer._write_metrics(loss_dict, data_time) 342 | """ 343 | If you need gradient clipping/scaling or other processing, you can 344 | wrap the optimizer with your custom `step()` method. But it is 345 | suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4 346 | """ 347 | 348 | def build_hooks(self): 349 | """ 350 | Build a list of default hooks, including timing, evaluation, 351 | checkpointing, lr scheduling, precise BN, writing events. 352 | Returns: 353 | list[HookBase]: 354 | """ 355 | cfg = self.cfg.clone() 356 | cfg.defrost() 357 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN 358 | 359 | ret = [ 360 | hooks.IterationTimer(), 361 | LRScheduler(), 362 | hooks.PreciseBN( 363 | # Run at the same freq as (but before) evaluation. 364 | cfg.TEST.EVAL_PERIOD, 365 | self.model, 366 | # Build a new data loader to not affect training 367 | self.build_train_loader(cfg), 368 | cfg.TEST.PRECISE_BN.NUM_ITER, 369 | ) 370 | if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) 371 | else None, 372 | ] 373 | 374 | # Do PreciseBN before checkpointer, because it updates the model and need to 375 | # be saved by checkpointer. 376 | # This is not always the best: if checkpointing has a different frequency, 377 | # some checkpoints may have more precise statistics than others. 
378 | if comm.is_main_process(): 379 | ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) 380 | 381 | def test_and_save_results(): 382 | self._last_eval_results = self.test(self.cfg, self.model) 383 | return self._last_eval_results 384 | 385 | def do_test_st(flag): 386 | if flag == 'st': 387 | model = self.model 388 | else: 389 | print("Error in the flag") 390 | 391 | results = OrderedDict() 392 | for dataset_name in self.cfg.DATASETS.TEST: 393 | data_loader = build_detection_test_loader(self.cfg, dataset_name) 394 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name) 395 | results_i = inference_on_dataset(model, data_loader, evaluator) 396 | results[dataset_name] = results_i 397 | if comm.is_main_process(): 398 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 399 | print_csv_format(results_i) 400 | storage = get_event_storage() 401 | storage.put_scalar(f'{dataset_name}_AP50', results_i['bbox']['AP50'],smoothing_hint=False) 402 | if len(results) == 1: 403 | results = list(results.values())[0] 404 | return results 405 | 406 | 407 | # Do evaluation after checkpointer, because then if it fails, 408 | # we can use the saved checkpoint to debug. 409 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) 410 | ret.append(hooks.EvalHook(cfg.TEST.EVAL_SAVE_PERIOD, lambda flag='st': do_test_st(flag))) 411 | 412 | if comm.is_main_process(): 413 | # Here the default print/log frequency of each writer is used. 414 | # run writers in the end, so that evaluation metrics are written 415 | ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) 416 | return ret 417 | 418 | @classmethod 419 | def build_lr_scheduler(cls, cfg, optimizer): 420 | """ 421 | It now calls :func:`detectron2.solver.build_lr_scheduler`. 422 | Overwrite it if you'd like a different scheduler. 423 | """ 424 | 425 | return build_lr_scheduler(cfg, optimizer[0]) 426 | 427 | def state_dict(self): 428 | ret = super().state_dict() 429 | ret["optimizer1"] = self.optimizer[0].state_dict() 430 | ret["optimizer2"] = self.optimizer[1].state_dict() 431 | return ret 432 | 433 | def load_state_dict(self, state_dict): 434 | super().load_state_dict(state_dict) 435 | self.optimizer[0].load_state_dict(state_dict["optimizer1"]) 436 | self.optimizer[1].load_state_dict(state_dict["optimizer2"]) 437 | 438 | 439 | 440 | class LRScheduler(HookBase): 441 | """ 442 | A hook which executes a torch builtin LR scheduler and summarizes the LR. 443 | It is executed after every iteration. 444 | """ 445 | 446 | def __init__(self, optimizer=None, scheduler=None): 447 | """ 448 | Args: 449 | optimizer (torch.optim.Optimizer): 450 | scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler): 451 | if a :class:`ParamScheduler` object, it defines the multiplier over the base LR 452 | in the optimizer. 453 | If any argument is not given, will try to obtain it from the trainer. 
454 | """ 455 | self._optimizer = optimizer 456 | self._scheduler = scheduler 457 | 458 | def before_train(self): 459 | self._optimizer = self._optimizer or self.trainer.optimizer 460 | if isinstance(self.scheduler, ParamScheduler): 461 | self._scheduler = LRMultiplier( 462 | self._optimizer, 463 | self.scheduler, 464 | self.trainer.max_iter, 465 | last_iter=self.trainer.iter - 1, 466 | ) 467 | self._best_param_group_id1 = LRScheduler.get_best_param_group_id(self._optimizer[0]) 468 | self._best_param_group_id2 = LRScheduler.get_best_param_group_id(self._optimizer[1]) 469 | 470 | 471 | @staticmethod 472 | def get_best_param_group_id(optimizer): 473 | # NOTE: some heuristics on what LR to summarize 474 | # summarize the param group with most parameters 475 | largest_group = max(len(g["params"]) for g in optimizer.param_groups) 476 | 477 | if largest_group == 1: 478 | # If all groups have one parameter, 479 | # then find the most common initial LR, and use it for summary 480 | lr_count = Counter([g["lr"] for g in optimizer.param_groups]) 481 | lr = lr_count.most_common()[0][0] 482 | for i, g in enumerate(optimizer.param_groups): 483 | if g["lr"] == lr: 484 | return i 485 | else: 486 | for i, g in enumerate(optimizer.param_groups): 487 | if len(g["params"]) == largest_group: 488 | return i 489 | 490 | def after_step(self): 491 | lr1 = self._optimizer[0].param_groups[self._best_param_group_id1]["lr"] 492 | self.trainer.storage.put_scalar("lr1", lr1, smoothing_hint=False) 493 | 494 | lr2 = self._optimizer[1].param_groups[self._best_param_group_id2]["lr"] 495 | self.trainer.storage.put_scalar("lr2", lr2, smoothing_hint=False) 496 | 497 | self.scheduler.step() 498 | 499 | @property 500 | def scheduler(self): 501 | return self._scheduler or self.trainer.scheduler 502 | 503 | def state_dict(self): 504 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler): 505 | return self.scheduler.state_dict() 506 | return {} 507 | 508 | def load_state_dict(self, state_dict): 509 | if isinstance(self.scheduler, torch.optim.lr_scheduler._LRScheduler): 510 | logger = logging.getLogger(__name__) 511 | logger.info("Loading scheduler from state_dict ...") 512 | self.scheduler.load_state_dict(state_dict) 513 | 514 | def custom_build_detection_test_loader(cfg,dataset_name,mapper=None): 515 | 516 | if isinstance(dataset_name, str): 517 | dataset_name = [dataset_name] 518 | 519 | dataset = get_detection_dataset_dicts( 520 | dataset_name, 521 | filter_empty=False, 522 | proposal_files=[ 523 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name 524 | ] 525 | if cfg.MODEL.LOAD_PROPOSALS 526 | else None, 527 | ) 528 | if mapper is None: 529 | mapper = DatasetMapper(cfg, False) 530 | 531 | if isinstance(dataset, list): 532 | dataset = DatasetFromList(dataset, copy=False) 533 | if mapper is not None: 534 | dataset = MapDataset(dataset, mapper) 535 | 536 | sampler = None 537 | if isinstance(dataset, torchdata.IterableDataset): 538 | assert sampler is None, "sampler must be None if dataset is IterableDataset" 539 | else: 540 | if sampler is None: 541 | sampler = InferenceSampler(len(dataset)) 542 | collate_fn = None 543 | 544 | def trivial_batch_collator(batch): 545 | """ 546 | A batch collator that does nothing. 
547 | """ 548 | return batch 549 | 550 | return torchdata.DataLoader( 551 | dataset, 552 | batch_size=1, 553 | sampler=sampler, 554 | drop_last=False, 555 | num_workers=cfg.DATALOADER.NUM_WORKERS, 556 | collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, 557 | ) 558 | 559 | 560 | def do_test(cfg, model, model_type=''): 561 | results = OrderedDict() 562 | for dataset_name in cfg.DATASETS.TEST: 563 | data_loader = build_detection_test_loader(cfg, dataset_name)#custom_build_detection_test_loader(cfg, dataset_name,CustomDatasetMapper(cfg,is_train=True)) 564 | evaluator = CustomPascalVOCDetectionEvaluator(dataset_name)#COCOEvaluator(dataset_name, output_dir=os.path.join(cfg.OUTPUT_DIR, "inference")) 565 | results_i = inference_on_dataset(model, data_loader, evaluator) 566 | results[dataset_name] = results_i 567 | if comm.is_main_process(): 568 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 569 | print_csv_format(results_i) 570 | 571 | if len(results) == 1: 572 | results = list(results.values())[0] 573 | return results 574 | 575 | def main(args): 576 | cfg = setup(args) 577 | if args.eval_only: 578 | model = Trainer.build_model(cfg) 579 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 580 | cfg.MODEL.WEIGHTS, resume=args.resume 581 | ) 582 | return do_test(cfg,model) 583 | trainer = Trainer(cfg) 584 | trainer.resume_or_load(resume=args.resume) 585 | for dataset_name in cfg.DATASETS.TEST: 586 | if '_val' in dataset_name : 587 | trainer.register_hooks([ 588 | 589 | hooks.BestCheckpointer(cfg.TEST.EVAL_SAVE_PERIOD,trainer.checkpointer,f'{dataset_name}_AP50',file_prefix='model_best'), 590 | ]) 591 | 592 | trainer.train() 593 | 594 | 595 | if __name__ == "__main__": 596 | args = default_argument_parser().parse_args() 597 | cfg = setup(args) 598 | print("Command Line Args:", args) 599 | 600 | main(args) 601 | --------------------------------------------------------------------------------