├── doraemon ├── built │ ├── __init__.py │ ├── layer_optimizer.py │ ├── class_augmenter.py │ └── attention_based_pooler.py ├── models │ ├── representation │ │ ├── head │ │ │ ├── __init__.py │ │ │ ├── arcface.py │ │ │ ├── circleloss.py │ │ │ ├── mv_softmax.py │ │ │ ├── magface.py │ │ │ └── head_def.py │ │ ├── backbone │ │ │ ├── __init__.py │ │ │ ├── backbone_def.py │ │ │ └── timm_wrapper.py │ │ ├── __init__.py │ │ ├── README_CBIR.md │ │ └── face_model.py │ ├── classifier │ │ ├── __init__.py │ │ ├── classify_model.py │ │ └── README.md │ ├── losses │ │ ├── __init__.py │ │ └── loss.py │ ├── __init__.py │ ├── smartmodel.py │ └── ema.py ├── structure │ ├── __init__.py │ └── sampler.py ├── distills │ ├── __init__.py │ └── distillers.py ├── engine │ ├── procedure │ │ ├── __init__.py │ │ ├── visualizer.py │ │ ├── eval_recog.py │ │ └── train.py │ ├── representation │ │ ├── __init__.py │ │ ├── eval_face.py │ │ └── eval_cbir.py │ ├── __init__.py │ ├── scheduler.py │ └── optimizer.py ├── utils │ ├── __init__.py │ ├── average_meter.py │ ├── logger.py │ ├── plots.py │ ├── cam.py │ └── checks.py ├── dataset │ ├── __init__.py │ └── dataprocessor.py └── __init__.py ├── misc ├── Arial.ttf ├── cbir.jpg ├── gradcam.jpg ├── augments.jpg ├── cbir_val.jpg ├── doraemon.jpg ├── training.jpg ├── eval_class.jpeg └── tensorboard.jpg ├── .gitignore ├── deploy ├── config.json ├── README.md └── doraemon_modeling.py ├── data ├── toy-multi-cls.csv └── split2dataset.py ├── configs ├── representation │ ├── head_conf.yaml │ ├── backbone_conf.yaml │ ├── README.md │ ├── face.yaml │ └── image-retrieval.yaml └── recognition │ └── pet.yaml ├── setup.py ├── tools ├── deduplicate.py ├── single_predict.py ├── onnx_predict.py ├── build_querygallery.py ├── clustering.py ├── dataset_upload_hf.py ├── data_prepare.py ├── video_predict.py └── test_augment.py ├── scripts ├── train.py ├── validate.py └── infer.py └── README.md /doraemon/built/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doraemon/models/representation/head/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doraemon/models/representation/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doraemon/structure/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import OHEMImageSampler -------------------------------------------------------------------------------- /doraemon/distills/__init__.py: -------------------------------------------------------------------------------- 1 | from .distillers import DistillCenterProcessor -------------------------------------------------------------------------------- /misc/Arial.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/Arial.ttf -------------------------------------------------------------------------------- /misc/cbir.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/cbir.jpg -------------------------------------------------------------------------------- 
/doraemon/models/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | from .classify_model import VisionWrapper -------------------------------------------------------------------------------- /misc/gradcam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/gradcam.jpg -------------------------------------------------------------------------------- /misc/augments.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/augments.jpg -------------------------------------------------------------------------------- /misc/cbir_val.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/cbir_val.jpg -------------------------------------------------------------------------------- /misc/doraemon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/doraemon.jpg -------------------------------------------------------------------------------- /misc/training.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/training.jpg -------------------------------------------------------------------------------- /misc/eval_class.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/eval_class.jpeg -------------------------------------------------------------------------------- /misc/tensorboard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wuji3/Doraemon/HEAD/misc/tensorboard.jpg -------------------------------------------------------------------------------- /doraemon/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import bce, ce, focal, create_Lossfn, list_lossfns -------------------------------------------------------------------------------- /doraemon/models/representation/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_model import FaceTrainingModel, FaceTrainingWrapper -------------------------------------------------------------------------------- /doraemon/engine/procedure/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_recog import ConfusedMatrix, valuate 2 | from .visualizer import Visualizer -------------------------------------------------------------------------------- /doraemon/engine/representation/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_cbir import valuate as valuate_cbir 2 | from .eval_face import valuate as valuate_face -------------------------------------------------------------------------------- /doraemon/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import SmartLogger 2 | from .cam import ClassActivationMaper 3 | from .plots import colorstr 4 | from .checks import check -------------------------------------------------------------------------------- /doraemon/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .smartmodel import get_model, VisionWrapper 2 | from .ema import ModelEMA 3 | from .representation.face_model import FaceModelLoader -------------------------------------------------------------------------------- /doraemon/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataprocessor import SmartDataProcessor 2 | from .transforms import create_AugTransforms 3 | from .basedataset import PredictImageDatasets -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /demo.py 2 | *.jpg 3 | *.log 4 | /run/ 5 | /inference 6 | /visualization/ 7 | /configs/catfood.yaml 8 | /dir/ 9 | .idea 10 | /data/ 11 | __pycache__*/ 12 | .DS_Store 13 | .jpg 14 | .png 15 | .npy 16 | doraemon_torch.egg-info/ 17 | 18 | -------------------------------------------------------------------------------- /doraemon/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .engine import * 3 | from .utils import * 4 | from .dataset import * 5 | from .distills import * 6 | from .engine.procedure import valuate as valuate_classifier 7 | from .engine.representation import valuate_cbir, valuate_face 8 | -------------------------------------------------------------------------------- /doraemon/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizer import sgd, adam, sam, BaseSeperateLayer, create_Optimizer, list_optimizers 2 | from .scheduler import linear, cosine, linear_with_warm, cosine_with_warm, create_Scheduler, list_schedulers 3 | from .procedure import ConfusedMatrix, valuate, Visualizer 4 | from .vision_engine import CenterProcessor, yaml_load, increment_path 5 | -------------------------------------------------------------------------------- /deploy/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "doraemon", 3 | "model_path": "/path/to/xxx.pt OR user/repo/filename.pt:v1(revision)", 4 | "auto_map":{ 5 | "AutoConfig": "doraemon_modeling.DoraemonConfig", 6 | "AutoModel": "doraemon_modeling.DoraemonClassifier", 7 | "AutoProcessor": "doraemon_modeling.DoraemonProcessor" 8 | }, 9 | "transformers_version": "4.48.3" 10 | } -------------------------------------------------------------------------------- /doraemon/models/smartmodel.py: -------------------------------------------------------------------------------- 1 | from .representation import FaceTrainingWrapper 2 | from .classifier import VisionWrapper 3 | 4 | 5 | def get_model(model_cfg, logger, rank): 6 | assert 'task' in model_cfg, 'Task is not specified' 7 | 8 | match model_cfg['task']: 9 | case 'face' | 'cbir': return FaceTrainingWrapper(model_cfg, logger) 10 | case 'classification': return VisionWrapper(model_cfg, logger, rank) -------------------------------------------------------------------------------- /data/toy-multi-cls.csv: -------------------------------------------------------------------------------- 1 | image_path,cute,fluffy,small,aggressive,playful,train 2 | /home/xxx/data/toy-multi-cls/cat_1.jpg,1,1,1,0,1,True 3 | /home/xxx/data/toy-multi-cls/dog_1.jpg,1,1,0,0,1,True 4 | /home/xxx/data/toy-multi-cls/bird_1.jpg,1,0,1,0,0,True 5 | 
/home/xxx/data/toy-multi-cls/fish_1.jpg,0,0,1,0,0,True 6 | /home/xxx/data/toy-multi-cls/rabbit_1.jpg,1,1,1,0,1,True 7 | /home/xxx/data/toy-multi-cls/cat_3.jpg,1,1,1,0,1,False 8 | /home/xxx/data/toy-multi-cls/dog_3.jpg,1,1,0,1,1,False 9 | /home/xxx/data/toy-multi-cls/bird_3.jpg,1,0,1,0,0,False 10 | -------------------------------------------------------------------------------- /configs/representation/head_conf.yaml: -------------------------------------------------------------------------------- 1 | arcface: 2 | feat_dim: 512 3 | num_class: 72778 4 | margin_arc: 0.35 5 | margin_am: 0.0 6 | scale: 32 7 | 8 | magface: 9 | feat_dim: 512 10 | num_class: 72778 11 | margin_am: 0.0 12 | scale: 64 13 | l_a: 10 14 | u_a: 110 15 | l_margin: 0.45 16 | u_margin: 0.8 17 | lamda: 20 18 | 19 | circle: 20 | feat_dim: 512 21 | num_class: 72778 22 | margin: 0.25 23 | gamma: 256 24 | 25 | mv-softmax: 26 | feat_dim: 512 27 | num_class: 72778 28 | is_am: 1 29 | margin: 0.35 30 | mv_weight: 1.12 31 | scale: 32 -------------------------------------------------------------------------------- /doraemon/utils/average_meter.py: -------------------------------------------------------------------------------- 1 | # based on: https://github.com/pytorch/examples/blob/master/imagenet/main.py 2 | 3 | class AverageMeter: 4 | """Computes and stores the average and current value""" 5 | def __init__(self): 6 | #self.name = name 7 | #self.fmt = fmt 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /configs/representation/backbone_conf.yaml: -------------------------------------------------------------------------------- 1 | resnet: # 50 100 152 2 | depth: 152 3 | image_size: 112 4 | drop_ratio: 0.4 5 | net_mode: ir_se # ir ir_se 6 | feat_dim: 512 7 | 8 | efficientnet: 9 | image_size: 112 10 | # for [width, depth, image_size, drop_ratio], see line 473 in backbone/EfficientNets.py 11 | width: 1.0 12 | depth: 1.0 13 | drop_ratio: 0.2 14 | # [out_h, out_w] determine the last linear layer, see line 835 in backbone/EfficientNets.py 15 | out_h: 7 16 | out_w: 7 17 | # feat_dim decides the embedding dim 18 | feat_dim: 512 19 | 20 | swintransformer: # tiny small base 21 | model_size: base 22 | image_size: 224 23 | in_chans: 3 24 | feat_dim: 512 25 | 26 | convnext: # tiny small base large 27 | model_size: tiny 28 | feat_dim: 512 29 | image_size: 224 -------------------------------------------------------------------------------- /doraemon/models/representation/backbone/backbone_def.py: -------------------------------------------------------------------------------- 1 | from .timm_wrapper import TimmWrapper 2 | 3 | class BackboneFactory: 4 | """Factory that produces a backbone according to backbone_conf.yaml. 5 | 6 | Attributes: 7 | backbone_type: which backbone to produce. 8 | backbone_param: params describing the model structure.
9 | """ 10 | 11 | def __init__(self, backbone_config): 12 | # self.backbone_type = list(backbone_config['backbone'].keys())[0] 13 | # self.backbone_param = backbone_config['backbone'][self.backbone_type] 14 | for k, v in backbone_config.items(): self.backbone_type, self.backbone_param = k, v 15 | 16 | def get_backbone(self): 17 | 18 | if self.backbone_type.startswith('timm'): 19 | model = self.backbone_type.split('-')[1] 20 | return TimmWrapper( 21 | model_name=model, 22 | **self.backbone_param, 23 | ) 24 | 25 | else: 26 | raise NotImplementedError(f"{self.backbone_type} is not supported yet!") -------------------------------------------------------------------------------- /doraemon/built/layer_optimizer.py: -------------------------------------------------------------------------------- 1 | from ..engine.optimizer import BaseSeperateLayer 2 | from typing import Iterator, List, Dict, Union 3 | from torch.nn import Module 4 | 5 | class SeperateLayerParams(BaseSeperateLayer): 6 | def __init__(self, model: Module): 7 | super().__init__(model) 8 | 9 | def create_ParamSequence(self, layer_wise: bool, lr: float) -> Union[Iterator, List[Dict]]: 10 | """ 11 | Args: 12 | layer_wise: whether to apply layer-wise learning rates 13 | lr: base learning rate 14 | 15 | Returns: 16 | params: the params argument passed to torch.optim.Optimizer 17 | """ 18 | if not layer_wise: return self.model.parameters() 19 | 20 | # params = [ 21 | # {'params': self.model.features.parameters()}, 22 | # {'params': self.model.norm.parameters()}, 23 | # {'params': self.model.head.parameters(), 'lr': lr * 10} 24 | # ] 25 | 26 | params = [ 27 | {'params': self.model.trainingwrapper['backbone'].parameters(), 'lr': lr}, 28 | {'params': self.model.trainingwrapper['head'].parameters(), 'lr': lr * 10} 29 | ] 30 | return params -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", mode="r", encoding="utf-8") as readme_file: 4 | readme = readme_file.read() 5 | 6 | setup( 7 | name='doraemon-torch', 8 | version='0.0.10a', 9 | author='duke', 10 | author_email='dk812821001@163.com', 11 | description='Doraemon', 12 | long_description=readme, 13 | long_description_content_type="text/markdown", 14 | url='https://github.com/wuji3/Doraemon', 15 | packages=find_packages(include=['doraemon', 'doraemon.*']), 16 | python_requires='>=3.10', 17 | install_requires=[ 18 | 'torchmetrics>=0.11.4', 19 | 'opencv-python>=4.7.0.72', 20 | 'numpy>=1.24.3', 21 | 'tqdm>=4.66.4', 22 | 'Pillow>=9.4.0', 23 | 'grad-cam>=1.4.8', 24 | 'timm>=0.9.16', 25 | 'tensorboard>=2.16.2', 26 | 'prettytable>=3.10.0', 27 | 'datasets>=2.20.0', 28 | 'imagehash>=4.3.1', 29 | 'transformers>=4.48.3', 30 | 'torch>=2.5.1', 31 | 'torchvision>=0.20.1', 32 | 'torchaudio>=2.5.1', 33 | 'faiss-cpu>=1.7.2', 34 | ], 35 | ) 36 | -------------------------------------------------------------------------------- /doraemon/structure/sampler.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class OHEMImageSampler: 6 | def __init__(self, min_kept: int, thresh: float, ignore_index: int = 255): 7 | self.min_kept = min_kept 8 | self.thresh = thresh 9 | self.ignore_index = ignore_index 10 | 11 | def sample(self, logits: Tensor, labels: Tensor): 12 | with torch.no_grad(): 13 | prob = F.softmax(logits, dim=1) 14 | # ignore ignore_index here 15 | valid1 = labels
!= self.ignore_index 16 | prob = prob[valid1] 17 | 18 | # extract the score for correct predictions 19 | tmp_prob = prob.gather(1, labels[valid1].unsqueeze(1)).squeeze(1) 20 | sort_prob, sort_indices = tmp_prob.sort() 21 | 22 | min_thresh = sort_prob[min(self.min_kept, sort_prob.numel()-1)] 23 | threshold = max(min_thresh, self.thresh) 24 | 25 | temp_valid = sort_prob < threshold 26 | valid_indices = sort_indices[temp_valid] 27 | 28 | valid2 = torch.zeros_like(labels, dtype=torch.bool) 29 | valid2[valid_indices] = True 30 | 31 | return valid1 & valid2 32 | -------------------------------------------------------------------------------- /tools/deduplicate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imagehash 3 | from PIL import Image 4 | from tqdm import tqdm 5 | 6 | def find_similar_images(userpaths, hashfunc=imagehash.average_hash): 7 | def is_image(filename): 8 | f = filename.lower() 9 | return f.endswith('.png') or f.endswith('.jpg') or \ 10 | f.endswith('.jpeg') or f.endswith('.bmp') or \ 11 | f.endswith('.gif') or '.jpg' in f or f.endswith('.svg') 12 | 13 | image_filenames = [] 14 | for userpath in userpaths: 15 | image_filenames += [os.path.join(userpath, path) for path in os.listdir(userpath) if is_image(path)] 16 | images = {} 17 | for img in tqdm(sorted(image_filenames), total=len(image_filenames)): 18 | try: 19 | hash = hashfunc(Image.open(img)) 20 | except Exception as e: 21 | print('Problem:', e, 'with', img) 22 | continue 23 | if hash in images: 24 | # print(img, ' already exists as', ' '.join(images[hash])) 25 | os.remove(img) # delete the duplicate image 26 | feat = f"/home/duke/data/favie/v2-embedding/features/{os.path.basename(img).replace('.jpg', '.npy')}" 27 | if os.path.isfile(feat): 28 | os.remove(feat) 29 | else: 30 | images[hash] = images.get(hash, []) + [img] 31 | 32 | return images 33 | 34 | if __name__ == '__main__': 35 | userpaths = ['/home/duke/data/favie/v2-embedding/images'] 36 | find_similar_images(userpaths, imagehash.dhash) -------------------------------------------------------------------------------- /doraemon/built/class_augmenter.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Dict, Optional 2 | from ..dataset.transforms import BaseClassWiseAugmenter 3 | 4 | class ClassWiseAugmenter(BaseClassWiseAugmenter): 5 | def __init__(self, base_transforms: Dict, class_transforms_mapping: Optional[Dict[str, List[int]]], base: List[int]): 6 | if base is not None: 7 | assert isinstance(base, list), f'{base} is not a list of indices' 8 | base_transforms = [t for i, t in enumerate(base_transforms) if i in base] 9 | 10 | super().__init__(base_transforms=base_transforms, class_transforms_mapping=class_transforms_mapping) 11 | 12 | def __call__(self, image, label: Union[List, int], class_indices: List[int]): 13 | if self.class_transforms is None: 14 | return super().__call__(image=image, label=label, class_indices=class_indices) 15 | 16 | # softmax 17 | if isinstance(label, int): 18 | if class_indices[label] in self.class_transforms: 19 | return self.class_transforms[class_indices[label]](image) 20 | else: return self.base_transforms(image) 21 | # sigmoid 22 | elif isinstance(label, list): # multi-label 23 | # multi-label 24 | if len(label) == 1: # Customized specific class 25 | c = label[0] 26 | if class_indices[c] in self.class_transforms: 27 | return self.class_transforms[class_indices[c]](image) 28 | 29 | # Generally common class 
30 | return self.base_transforms(image) -------------------------------------------------------------------------------- /tools/single_predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import get_model 3 | from PIL import Image 4 | import torchvision.transforms as T 5 | import argparse 6 | from dataset.transforms import create_AugTransforms 7 | 8 | def parse_opt(): 9 | parsers = argparse.ArgumentParser() 10 | 11 | parsers.add_argument('--img', default='./edd4dc86b50997f29b81ba0b2bab1906.jpg', type=str) 12 | parsers.add_argument('--pt', default='./best.pt', type=str) 13 | parsers.add_argument('--transforms', default={'resize': [640, 640], 'to_tensor': 'no_params', 'normalize': 'no_params'}) 14 | 15 | args = parsers.parse_args() 16 | return args 17 | 18 | def image_process(path: str, transforms: T.Compose): 19 | img = Image.open(path).convert('RGB') 20 | return transforms(img).unsqueeze(0) 21 | 22 | def main(opt): 23 | 24 | # variable 25 | img_path = opt.img 26 | weight_path = opt.pt 27 | transforms = opt.transforms 28 | 29 | # image 30 | image = image_process(img_path, create_AugTransforms(eval(transforms) if isinstance(transforms, str) else transforms)) 31 | 32 | # model 33 | model = get_model('mobilenet_v2', width_mult = 0.25) 34 | model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, 7) 35 | weight = torch.load(weight_path, map_location='cpu')['model'] 36 | model.load_state_dict(weight) 37 | # eval 38 | model.eval() 39 | 40 | out = model(image) 41 | print(torch.nn.functional.softmax(out, dim=-1)) 42 | 43 | 44 | if __name__ == '__main__': 45 | opt = parse_opt() 46 | main(opt) -------------------------------------------------------------------------------- /tools/onnx_predict.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | from PIL import Image 3 | import torchvision.transforms as T 4 | import argparse 5 | from dataset.transforms import create_AugTransforms 6 | 7 | def parse_opt(): 8 | parsers = argparse.ArgumentParser() 9 | 10 | parsers.add_argument('--img', default='./img.png', type=str) 11 | parsers.add_argument('--onnx', default='./shufflev2_0.5.onnx', type=str) 12 | parsers.add_argument('--transforms', default='centercrop_resize to_tensor normalize') 13 | parsers.add_argument('--imgsz', default='[[720, 720], [360, 360]]', type=str) 14 | parsers.add_argument('--input_onnx', default='input', type=str, help = 'input_name of onnx ') 15 | parsers.add_argument('--output_onnx', default='prob', type=str, help = 'output_name of onnx ') 16 | 17 | args = parsers.parse_args() 18 | return args 19 | 20 | def image_process(path: str, transforms: T.Compose): 21 | img = Image.open(path).convert('RGB') 22 | return transforms(img).unsqueeze(0).numpy() 23 | 24 | def main(opt): 25 | 26 | # variable 27 | img_path = opt.img 28 | onnx_path = opt.onnx 29 | transforms = opt.transforms 30 | imgsz = opt.imgsz 31 | intput_name = opt.input_onnx 32 | output_name = opt.output_onnx 33 | 34 | # image 35 | image = image_process(img_path, create_AugTransforms(transforms, eval(imgsz))) 36 | 37 | # model 38 | session = onnxruntime.InferenceSession(onnx_path) 39 | output = session.run([f'{output_name}'], {f'{intput_name}': image})[0] 40 | 41 | print(output) 42 | 43 | if __name__ == '__main__': 44 | opt = parse_opt() 45 | main(opt) 46 | 47 | -------------------------------------------------------------------------------- /data/split2dataset.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from os.path import join as opj 4 | 5 | def splitImg2Category(dataDir="oxford-iiit-pet/images/", resDir="pet/"): 6 | for one_pic in os.listdir(dataDir): 7 | one_path = opj(dataDir, one_pic) 8 | oneDir = opj(resDir, one_pic.split('_')[0].strip()) 9 | os.makedirs(oneDir, exist_ok=True) 10 | shutil.copy(one_path, opj(oneDir, one_pic)) 11 | 12 | if __name__ == '__main__': 13 | # tag class 14 | splitImg2Category() 15 | 16 | # split 17 | annos = ['oxford-iiit-pet/annotations/trainval.txt', 'oxford-iiit-pet/annotations/test.txt'] 18 | 19 | for i, anno in enumerate(annos): 20 | mode = 'train' if i == 0 else 'val' 21 | 22 | with open(anno, 'r') as f: 23 | for img_name in f.readlines(): 24 | img_name = img_name.split()[0].strip() 25 | img_category = img_name.split('_')[0].strip() 26 | 27 | dstDir = opj('pet', mode, img_category) 28 | os.makedirs(dstDir, exist_ok=True) 29 | 30 | cur_category = opj('pet', img_category) 31 | if not os.path.exists(opj(dstDir, img_name + '.jpg')): 32 | shutil.move(opj('pet', img_category, img_name + '.jpg'), dstDir) 33 | 34 | # remove temporary directories 35 | for dir_ in os.listdir('pet'): 36 | if dir_ not in ('train', 'val'): 37 | shutil.rmtree(opj('pet', dir_)) 38 | 39 | # remove original oxford-iiit-pet directory 40 | shutil.rmtree('oxford-iiit-pet') 41 | 42 | print("Data preparation completed. The 'pet' directory is ready for use.") 43 | -------------------------------------------------------------------------------- /doraemon/distills/distillers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from typing import Callable 4 | from ..engine.vision_engine import CenterProcessor 5 | 6 | class Distiller: 7 | def __init__(self, 8 | model_teacher: nn.Module, 9 | model_student: nn.Module, 10 | criterion_cls: Callable, 11 | criterion_kl: Callable, 12 | cls_weight: float = 0.5, 13 | kl_weight: float = 0.5): 14 | 15 | self.model_teacher = model_teacher 16 | self.model_student = model_student 17 | self.criterion_cls = criterion_cls 18 | self.criterion_kl = criterion_kl 19 | self.cls_weight = cls_weight 20 | self.kl_weight = kl_weight 21 | 22 | def __call__(self, inputs: Tensor, label: Tensor) -> Tensor: 23 | # forward 24 | logit_s = self.model_student(inputs) 25 | logit_t = self.model_teacher(inputs) 26 | # compute loss 27 | loss_cls = self.criterion_cls(logit_s, label) 28 | loss_kl = self.criterion_kl(logit_s, logit_t) 29 | loss = self.cls_weight * loss_cls + self.kl_weight * loss_kl 30 | 31 | return loss 32 | 33 | class DistillCenterProcessor(CenterProcessor): 34 | def __init__(self, cfgs: dict, rank: int, project: str, logger = None, opt = None): 35 | super().__init__(cfgs=cfgs['student'], rank=rank, project= project) 36 | 37 | # init teacher model 38 | # self.teacher = TorchVisionWrapper(cfgs['teacher'], logger = logger) 39 | # self.opt = opt 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /doraemon/models/ema.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch.nn as nn 3 | import math 4 | 5 | def is_parallel(model): 6 | # Returns True if model is of type DP or DDP 7 | return type(model) in 
(nn.parallel.DataParallel, nn.parallel.DistributedDataParallel) 8 | 9 | 10 | def de_parallel(model): 11 | # De-parallelize a model: returns single-GPU model if model is of type DP or DDP 12 | return model.module if is_parallel(model) else model 13 | 14 | class ModelEMA: 15 | """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models 16 | Keeps a moving average of everything in the model state_dict (parameters and buffers) 17 | For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage 18 | """ 19 | 20 | def __init__(self, model, decay=0.9999, tau=2000, updates=0): 21 | # Create EMA 22 | self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA 23 | self.updates = updates # number of EMA updates 24 | self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) 25 | for p in self.ema.parameters(): 26 | p.requires_grad_(False) 27 | 28 | def update(self, model): 29 | # Update EMA parameters 30 | self.updates += 1 31 | d = self.decay(self.updates) 32 | 33 | msd = de_parallel(model).state_dict() # model state_dict 34 | for k, v in self.ema.state_dict().items(): 35 | if v.dtype.is_floating_point: # true for FP16 and FP32 36 | v *= d 37 | v += (1 - d) * msd[k].detach() 38 | # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32' -------------------------------------------------------------------------------- /doraemon/models/representation/head/arcface.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Module, Parameter 5 | 6 | class ArcFace(Module): 7 | """Implementation for "ArcFace: Additive Angular Margin Loss for Deep Face Recognition" 8 | """ 9 | def __init__(self, feat_dim, num_class, margin_arc=0.35, margin_am=0.0, scale=32): 10 | super(ArcFace, self).__init__() 11 | self.weight = Parameter(torch.Tensor(feat_dim, num_class)) 12 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 13 | self.margin_arc = margin_arc 14 | self.margin_am = margin_am 15 | self.scale = scale 16 | self.cos_margin = math.cos(margin_arc) 17 | self.sin_margin = math.sin(margin_arc) 18 | self.min_cos_theta = math.cos(math.pi - margin_arc) 19 | 20 | def forward(self, feats, labels): 21 | kernel_norm = F.normalize(self.weight, dim=0) 22 | feats = F.normalize(feats) 23 | cos_theta = torch.mm(feats, kernel_norm) 24 | cos_theta = cos_theta.clamp(-1, 1) 25 | sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2)) 26 | cos_theta_m = cos_theta * self.cos_margin - sin_theta * self.sin_margin 27 | # 0 <= theta + m <= pi, ==> -m <= theta <= pi-m 28 | # because 0<=theta<=pi, so, we just have to keep theta <= pi-m, that is cos_theta >= cos(pi-m) 29 | cos_theta_m = torch.where(cos_theta > self.min_cos_theta, cos_theta_m, cos_theta-self.margin_am) 30 | index = torch.zeros_like(cos_theta) 31 | index.scatter_(1, labels.data.view(-1, 1), 1) 32 | index = index.byte().bool() 33 | output = cos_theta * 1.0 34 | output[index] = cos_theta_m[index] 35 | output *= self.scale 36 | return output 37 | -------------------------------------------------------------------------------- /doraemon/models/representation/head/circleloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Module, Parameter 4 | 5 | class 
CircleLoss(Module): 6 | """Implementation for "Circle Loss: A Unified Perspective of Pair Similarity Optimization" 7 | Note: this is the classification based implementation of circle loss. 8 | """ 9 | def __init__(self, feat_dim, num_class, margin=0.25, gamma=256): 10 | super(CircleLoss, self).__init__() 11 | self.weight = Parameter(torch.Tensor(feat_dim, num_class)) 12 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 13 | self.margin = margin 14 | self.gamma = gamma 15 | 16 | self.O_p = 1 + margin 17 | self.O_n = -margin 18 | self.delta_p = 1-margin 19 | self.delta_n = margin 20 | 21 | def forward(self, feats, labels): 22 | kernel_norm = F.normalize(self.weight, dim=0) 23 | feats = F.normalize(feats) 24 | cos_theta = torch.mm(feats, kernel_norm) 25 | cos_theta = cos_theta.clamp(-1, 1) 26 | index_pos = torch.zeros_like(cos_theta) 27 | index_pos.scatter_(1, labels.data.view(-1, 1), 1) 28 | index_pos = index_pos.byte().bool() 29 | index_neg = torch.ones_like(cos_theta) 30 | index_neg.scatter_(1, labels.data.view(-1, 1), 0) 31 | index_neg = index_neg.byte().bool() 32 | 33 | alpha_p = torch.clamp_min(self.O_p - cos_theta.detach(), min=0.) 34 | alpha_n = torch.clamp_min(cos_theta.detach() - self.O_n, min=0.) 35 | 36 | logit_p = alpha_p * (cos_theta - self.delta_p) 37 | logit_n = alpha_n * (cos_theta - self.delta_n) 38 | 39 | output = cos_theta * 1.0 40 | output[index_pos] = logit_p[index_pos] 41 | output[index_neg] = logit_n[index_neg] 42 | output *= self.gamma 43 | return output 44 | -------------------------------------------------------------------------------- /doraemon/models/representation/head/mv_softmax.py: -------------------------------------------------------------------------------- 1 | # based on: 2 | # https://github.com/xiaoboCASIA/SV-X-Softmax/blob/master/fc_layers.py 3 | 4 | import math 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.nn import Module, Parameter 8 | 9 | class MV_Softmax(Module): 10 | """Implementation for "Mis-classified Vector Guided Softmax Loss for Face Recognition" 11 | """ 12 | def __init__(self, feat_dim, num_class, is_am, margin=0.35, mv_weight=1.12, scale=32): 13 | super(MV_Softmax, self).__init__() 14 | self.weight = Parameter(torch.Tensor(feat_dim, num_class)) 15 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 16 | self.margin = margin 17 | self.mv_weight = mv_weight 18 | self.scale = scale 19 | self.is_am = is_am 20 | self.cos_m = math.cos(margin) 21 | self.sin_m = math.sin(margin) 22 | self.threshold = math.cos(math.pi - margin) 23 | self.mm = self.sin_m * margin 24 | 25 | def forward(self, x, label): 26 | kernel_norm = F.normalize(self.weight, dim=0) 27 | x = F.normalize(x) 28 | cos_theta = torch.mm(x, kernel_norm) 29 | batch_size = label.size(0) 30 | gt = cos_theta[torch.arange(0, batch_size), label].view(-1, 1) 31 | if self.is_am: # AM 32 | mask = cos_theta > gt - self.margin 33 | final_gt = torch.where(gt > self.margin, gt - self.margin, gt) 34 | else: # arcface 35 | sin_theta = torch.sqrt(1.0 - torch.pow(gt, 2)) 36 | cos_theta_m = gt * self.cos_m - sin_theta * self.sin_m 37 | mask = cos_theta > cos_theta_m 38 | final_gt = torch.where(gt > 0.0, cos_theta_m, gt) 39 | # process hard example. 
40 | hard_example = cos_theta[mask] 41 | cos_theta[mask] = self.mv_weight * hard_example + self.mv_weight - 1.0 42 | cos_theta.scatter_(1, label.data.view(-1, 1), final_gt) 43 | cos_theta *= self.scale 44 | return cos_theta 45 | -------------------------------------------------------------------------------- /doraemon/built/attention_based_pooler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from torchvision.models import MobileNetV3, ResNet, ConvNeXt, EfficientNet, SwinTransformer 4 | 5 | 6 | class AttentionPooling(nn.Module): 7 | """ 8 | Augmenting Convolutional networks with attention-based aggregation: https://arxiv.org/abs/2112.13692 9 | """ 10 | def __init__(self, in_dim): 11 | super().__init__() 12 | self.cls_vec = nn.Parameter(torch.randn(in_dim)) 13 | self.fc = nn.Linear(in_dim, in_dim) 14 | self.softmax = nn.Softmax(-1) 15 | 16 | def forward(self, x): 17 | # x.view: B,C,H,W -> BxHxW,C 18 | weights = torch.matmul(x.reshape(-1, x.shape[1]), self.cls_vec) 19 | # weights.view: BxHxW,C -> B, HxW 20 | weights = self.softmax(weights.reshape(x.shape[0], -1)) 21 | # x.view: B,C,H,W -> B,C,HxW 22 | # (B,C,HxW) @ (B,HxW,1) -> B,C,1 23 | x = torch.bmm(x.reshape(x.shape[0], x.shape[1], -1), weights.unsqueeze(-1)).squeeze(-1) 24 | x = x + self.cls_vec 25 | x = self.fc(x) 26 | x = x + self.cls_vec 27 | return x 28 | 29 | def atten_pool_replace(model: nn.Module): 30 | #--------------------------------Custom Pooling-----------------------------------# 31 | # if type(model) is MobileNetV3: 32 | # model.avgpool = AttentionPooling(in_dim=model.classifier[0].in_features) 33 | # elif type(model) is ResNet: 34 | # model.avgpool = AttentionPooling(in_dim=model.fc.in_features) 35 | # elif type(model) is ConvNeXt: 36 | # model.avgpool = AttentionPooling(in_dim=model.classifier[-1].in_features) 37 | # elif type(model) is EfficientNet: 38 | # model.avgpool = AttentionPooling(in_dim=model.classifier[-1].in_features) 39 | # elif type(model) is SwinTransformer: 40 | # model.avgpool = AttentionPooling(in_dim=model.head.in_features) 41 | # else: 42 | # raise KeyError(f'{type(model)} not support attention-based pool') 43 | 44 | return model 45 | #-----------------------------------------------------------------------------# -------------------------------------------------------------------------------- /tools/build_querygallery.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import shutil 3 | import argparse 4 | import os 5 | from os.path import join as opj 6 | 7 | """ 8 | Before 9 | --data 10 | --ID1 11 | --xxx1.jpg 12 | --xxx2.jpg 13 | --ID2 14 | --xxx3.jpg 15 | --xxx4.jpg 16 | 17 | After 18 | --data 19 | --data-query 20 | --ID1 21 | --xxx1.jpg 22 | --ID2 23 | --xxx3.jpg 24 | --data-gallery 25 | --ID1 26 | --xxx2.jpg 27 | --ID2 28 | --xxx4.jpg 29 | """ 30 | 31 | def parse_opt(): 32 | parsers = argparse.ArgumentParser() 33 | parsers.add_argument('--src', default='data', help='Image dir') 34 | parsers.add_argument('--frac', type=float, help='Fraction of query/gallery') 35 | parsers.add_argument('--drop', action='store_true', help="Cleaning up the source directory") 36 | 37 | return parsers.parse_args() 38 | 39 | 40 | def main(opt): 41 | src = opt.src 42 | frac = opt.frac 43 | drop = opt.drop 44 | 45 | src = os.path.realpath(src) 46 | root = os.path.dirname(src) 47 | basename = os.path.basename(src) 48 | 49 | all_classes = [x for x in os.listdir(src) if not 
x.startswith('.')] 50 | all_classes.sort() 51 | 52 | for c in all_classes: 53 | os.makedirs(opj(root, f'{basename}-query', c), exist_ok=True) 54 | os.makedirs(opj(root, f'{basename}-gallery', c), exist_ok=True) 55 | 56 | all_files = glob.glob(opj(src, c, '*')) 57 | all_files.sort() 58 | 59 | n = len(all_files) 60 | if n == 1: continue 61 | else: 62 | n_query = int(n * frac) if int(n * frac) != 0 else 1 63 | 64 | query_files = all_files[:n_query] 65 | gallery_files = all_files[n_query:] 66 | 67 | for f in query_files: 68 | shutil.copy(f, opj(root, f'{src}-query', c, os.path.basename(f))) 69 | 70 | for f in gallery_files: 71 | shutil.copy(f, opj(root, f'{src}-gallery', c, os.path.basename(f))) 72 | 73 | if drop: 74 | shutil.rmtree(src) 75 | 76 | if __name__ == "__main__": 77 | main(parse_opt()) -------------------------------------------------------------------------------- /doraemon/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union 3 | import yaml 4 | 5 | class SmartLogger: 6 | 7 | _Instance = None 8 | _Flag = False 9 | 10 | def __new__(cls, *args, **kwargs): 11 | if cls._Instance is None: 12 | cls._Instance = super().__new__(cls) 13 | return cls._Instance 14 | 15 | def __init__(self, filename = None, level: int = 1): 16 | if not self.__class__._Flag: 17 | self.__class__._Flag = True 18 | # logger -> Singleton 19 | self.file_logger = self.create_logger('file', 'file', filename=filename) if filename is not None else None 20 | self.console_logger = self.create_logger('console', 'console') 21 | self.level = level 22 | 23 | def log(self, msg: Union[str, dict]): 24 | if isinstance(msg, dict): 25 | self.file_logger.info(yaml.dump(msg, sort_keys=False, default_flow_style=False)) 26 | 27 | else: self.file_logger.info(msg) 28 | 29 | def console(self, msg: Union[str, dict]): 30 | if isinstance(msg, dict): 31 | self.console_logger.info('\n'+str(yaml.dump(msg, sort_keys=False, default_flow_style=False))) 32 | 33 | else: 34 | self.console_logger.info(msg) 35 | 36 | def both(self, msg: Union[str, dict]): 37 | self.log(msg) 38 | self.console(msg) 39 | 40 | def create_logger(self, name: str, kind: str, filename = None): 41 | assert kind in {'file', 'console'} 42 | logger = logging.getLogger(name) 43 | logger.setLevel(logging.DEBUG) 44 | # format 45 | file_format = logging.Formatter(fmt='%(asctime)-20s%(message)s', datefmt='%Y-%m-%d %H:%M:%S') 46 | if kind == 'file': 47 | handler = logging.FileHandler(filename=filename) 48 | handler.setFormatter(file_format) 49 | logger.addHandler(handler) 50 | elif kind == 'console': 51 | handler = logging.StreamHandler() 52 | handler.setFormatter(file_format) 53 | logger.addHandler(handler) 54 | return logger 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /doraemon/models/representation/backbone/timm_wrapper.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | 5 | class TimmWrapper(nn.Module): 6 | """A wrapper for timm models that handles different model architectures uniformly.""" 7 | 8 | def __init__(self, 9 | model_name: str, 10 | feat_dim: int, 11 | image_size: int, 12 | pretrained: bool = True, 13 | **kwargs): 14 | super().__init__() 15 | 16 | self.model = timm.create_model( 17 | model_name, 18 | pretrained=pretrained, 19 | num_classes=0, # classification head -> nn.Identity() 20 | global_pool='', 
# global pooling -> nn.Identity() 21 | ) 22 | 23 | with torch.no_grad(): 24 | dummy_input = torch.zeros(1, 3, image_size, image_size) 25 | output = self.model(dummy_input) 26 | 27 | if isinstance(output, tuple): 28 | output = output[0] 29 | 30 | if len(output.shape) == 4: # CNN output: [B, C, H, W] 31 | _, channels, h, w = output.shape 32 | flatten_dim = channels * h * w 33 | self.output_layer = nn.Sequential( 34 | nn.BatchNorm2d(channels), 35 | nn.Flatten(1), 36 | nn.Linear(flatten_dim, feat_dim), 37 | nn.BatchNorm1d(feat_dim) 38 | ) 39 | elif len(output.shape) == 3: # Transformer output: [B, N, C] 40 | _, tokens, channels = output.shape 41 | flatten_dim = tokens * channels 42 | self.output_layer = nn.Sequential( 43 | nn.LayerNorm(channels), 44 | nn.Flatten(1), 45 | nn.Linear(flatten_dim, feat_dim), 46 | nn.BatchNorm1d(feat_dim) 47 | ) 48 | else: 49 | raise ValueError(f"Unexpected output shape: {output.shape}") 50 | 51 | def forward(self, x): 52 | x = self.model(x) 53 | x = self.output_layer(x) 54 | return x -------------------------------------------------------------------------------- /tools/clustering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import matplotlib.pyplot as plt 4 | from sklearn.cluster import DBSCAN, HDBSCAN 5 | from sklearn.manifold import TSNE 6 | from PIL import Image 7 | import os 8 | import shutil 9 | import time 10 | from tqdm import tqdm 11 | 12 | # Load image paths and embeddings 13 | embeddings_path = [] 14 | X = [] 15 | 16 | for npy in tqdm(glob.glob("/home/duke/data/favie/v4-embedding/features/*.npy")[:1000]): 17 | basename = os.path.basename(npy).replace('.npy', '.jpg') 18 | if os.path.isfile(f"/home/duke/data/favie/v4-embedding/images/{basename}"): 19 | x = np.load(npy) 20 | X.append(x) 21 | embeddings_path.append(npy) 22 | X = np.stack(X) 23 | embeddings_path = np.array(embeddings_path) 24 | 25 | # Perform DBSCAN clustering 26 | db = DBSCAN(eps=0.4, 27 | min_samples=5, 28 | metric="cosine", 29 | n_jobs=16).fit(X) 30 | # db = HDBSCAN(min_cluster_size = 10, 31 | # min_samples = 5, 32 | # cluster_selection_epsilon = 0.2, 33 | # metric = "cosine", 34 | # n_jobs = 16).fit(X) 35 | labels = db.labels_ 36 | 37 | # Number of clusters in labels, ignoring noise (-1 is noise) 38 | n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) 39 | n_noise_ = list(labels).count(-1) 40 | 41 | print("Estimated number of clusters: %d" % n_clusters_) 42 | print("Estimated number of noise points: %d" % n_noise_) 43 | 44 | image_src = "/home/duke/data/favie/v4-embedding/images" 45 | cluster = "/home/duke/data/favie/v4-embedding/cluster" 46 | os.makedirs(cluster, exist_ok=True) 47 | for vis_label in range(n_clusters_): 48 | vis_idx = np.where(labels == vis_label)[0] 49 | vis_path = embeddings_path[vis_idx] 50 | vis_path = list(map(lambda x: os.path.basename(x).split(".")[0], vis_path)) 51 | 52 | for path in vis_path: 53 | target_dir = os.path.join(cluster, str(vis_label)) 54 | os.makedirs(target_dir, exist_ok=True) 55 | shutil.copy(os.path.join(image_src, f"{path}.jpg"), target_dir) 56 | 57 | # for path in vis_path: 58 | # img = Image.open(f"./images-test/{path}.jpg") 59 | # img = pad_to_square(img) 60 | # img = resize_image(img) 61 | # plt.imshow(img) 62 | # plt.show() 63 | -------------------------------------------------------------------------------- /doraemon/models/representation/head/magface.py: -------------------------------------------------------------------------------- 1 | import 
math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.nn import Module, Parameter 5 | 6 | class MagFace(Module): 7 | """Implementation for "MagFace: A Universal Representation for Face Recognition and Quality Assessment" 8 | """ 9 | def __init__(self, feat_dim, num_class, margin_am=0.0, scale=32, l_a=10, u_a=110, l_margin=0.45, u_margin=0.8, lamda=20): 10 | super(MagFace, self).__init__() 11 | self.weight = Parameter(torch.Tensor(feat_dim, num_class)) 12 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 13 | self.margin_am = margin_am 14 | self.scale = scale 15 | self.l_a = l_a 16 | self.u_a = u_a 17 | self.l_margin = l_margin 18 | self.u_margin = u_margin 19 | self.lamda = lamda 20 | 21 | def calc_margin(self, x): 22 | margin = (self.u_margin-self.l_margin) / \ 23 | (self.u_a-self.l_a)*(x-self.l_a) + self.l_margin 24 | return margin 25 | 26 | def forward(self, feats, labels): 27 | x_norm = torch.norm(feats, dim=1, keepdim=True).clamp(self.l_a, self.u_a) # l2 norm 28 | ada_margin = self.calc_margin(x_norm) 29 | cos_m, sin_m = torch.cos(ada_margin), torch.sin(ada_margin) 30 | loss_g = 1/(self.u_a**2) * x_norm + 1/(x_norm) 31 | kernel_norm = F.normalize(self.weight, dim=0) 32 | feats = F.normalize(feats) 33 | cos_theta = torch.mm(feats, kernel_norm) 34 | cos_theta = cos_theta.clamp(-1, 1) 35 | sin_theta = torch.sqrt(1.0 - torch.pow(cos_theta, 2)) 36 | cos_theta_m = cos_theta * cos_m - sin_theta * sin_m 37 | # 0 <= theta + m <= pi, ==> -m <= theta <= pi-m 38 | # because 0<=theta<=pi, so, we just have to keep theta <= pi-m, that is cos_theta >= cos(pi-m) 39 | min_cos_theta = torch.cos(math.pi - ada_margin) 40 | cos_theta_m = torch.where(cos_theta > min_cos_theta, cos_theta_m, cos_theta-self.margin_am) 41 | index = torch.zeros_like(cos_theta) 42 | index.scatter_(1, labels.data.view(-1, 1), 1) 43 | index = index.byte().bool() 44 | output = cos_theta * 1.0 45 | output[index] = cos_theta_m[index] 46 | output *= self.scale 47 | return output, self.lamda*loss_g 48 | -------------------------------------------------------------------------------- /doraemon/engine/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | from functools import wraps 3 | from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR 4 | 5 | __all__ = ['linear', # linear decay 6 | 'cosine', # cosine decay 7 | 'linear_with_warm', # linear decay with warmup 8 | 'cosine_with_warm', # cosine decay with warmup 9 | 'create_Scheduler', 10 | 'list_schedulers', ] 11 | 12 | SCHEDULER = {} 13 | def register_scheduler(fn: Callable): 14 | key = fn.__name__ 15 | if key in SCHEDULER: 16 | raise ValueError(f"An entry is already registered under the name '{key}'.") 17 | SCHEDULER[key] = fn 18 | @wraps(fn) 19 | def wrapper(*args, **kwargs): 20 | return fn(*args, **kwargs) 21 | 22 | return wrapper 23 | 24 | def de_lrf_ratio(lrf_ratio): 25 | return 0.1 if lrf_ratio is None else lrf_ratio 26 | 27 | @register_scheduler 28 | def linear(optimizer, warm_ep, epochs, lr0, lrf_ratio): 29 | return LinearLR(optimizer, start_factor=1, end_factor=de_lrf_ratio(lrf_ratio), total_iters=epochs) 30 | 31 | @register_scheduler 32 | def cosine(optimizer, warm_ep, epochs, lr0, lrf_ratio): 33 | return CosineAnnealingLR(optimizer, T_max=epochs, eta_min=de_lrf_ratio(lrf_ratio) * lr0, ) 34 | 35 | @register_scheduler 36 | def linear_with_warm(optimizer, warm_ep, epochs, lr0, lrf_ratio): 37 | scheduler = SequentialLR( 38 | optimizer = optimizer, 39 | schedulers=[ 40 | LinearLR(optimizer, start_factor=0.1,
end_factor=1, total_iters=warm_ep), 41 | LinearLR(optimizer, start_factor=1, end_factor=de_lrf_ratio(lrf_ratio), total_iters=epochs-warm_ep), 42 | ], 43 | milestones=[warm_ep,] 44 | ) 45 | return scheduler 46 | 47 | @register_scheduler 48 | def cosine_with_warm(optimizer, warm_ep, epochs, lr0, lrf_ratio): 49 | scheduler = SequentialLR( 50 | optimizer=optimizer, 51 | schedulers=[ 52 | LinearLR(optimizer, start_factor=0.1, end_factor=1, total_iters=warm_ep), 53 | CosineAnnealingLR(optimizer, T_max=epochs-warm_ep, eta_min=de_lrf_ratio(lrf_ratio) * lr0, ) 54 | ], 55 | milestones=[warm_ep, ] 56 | ) 57 | return scheduler 58 | 59 | def create_Scheduler(scheduler, optimizer, warm_ep, epochs, lr0, lrf_ratio): 60 | return SCHEDULER[scheduler](optimizer, warm_ep, epochs, lr0, lrf_ratio) 61 | 62 | def list_schedulers(): 63 | lossfns = [k for k, v in SCHEDULER.items()] 64 | return sorted(lossfns) -------------------------------------------------------------------------------- /doraemon/dataset/dataprocessor.py: -------------------------------------------------------------------------------- 1 | from .basedataset import ImageDatasets 2 | from ..built.class_augmenter import ClassWiseAugmenter 3 | import torch 4 | import os 5 | from torch.utils.data import DataLoader 6 | from typing import Optional 7 | 8 | class SmartDataProcessor: 9 | def __init__(self, data_cfgs: dict, rank, project, training: bool = True): 10 | self.data_cfgs = data_cfgs # root, nw, imgsz, train, val 11 | self.rank = rank 12 | self.project = project 13 | self.label_transforms = None # used in CenterProcessor.__init__ 14 | 15 | if training: 16 | self.train_dataset = self.create_dataset('train') 17 | 18 | def create_dataset(self, mode: str, training: bool = True, id2label: Optional[dict] = None): 19 | assert mode in {'train', 'val'} 20 | 21 | cfg = self.data_cfgs.get(mode, -1) 22 | if isinstance(cfg, dict): 23 | dataset = ImageDatasets(root_or_dataset=self.data_cfgs['root'], mode=mode, 24 | transforms=ClassWiseAugmenter(cfg['augment'], None, None) if mode == 'val' else \ 25 | ClassWiseAugmenter(cfg['augment'], cfg['class_aug'], cfg['base_aug']), 26 | project=self.project, rank=self.rank, training = training, id2label=id2label) 27 | else: 28 | dataset = None 29 | return dataset 30 | 31 | def set_augment(self, mode: str, transforms = None): # sequence -> T.Compose([...]) 32 | if transforms is None: 33 | transforms = self.val_dataset.transforms.base_transforms 34 | dataset = getattr(self, f'{mode}_dataset') 35 | dataset.transforms.base_transforms = transforms 36 | 37 | def auto_aug_weaken(self, epoch: int, milestone: int, sequence: Optional[torch.nn.Module] = None): 38 | if epoch == milestone: 39 | # sequence = create_AugTransforms('random_horizonflip to_tensor normalize') 40 | self.set_augment('train', transforms = sequence) 41 | 42 | @staticmethod 43 | def set_dataloader(dataset, bs: int = 256, nw: int = 0, pin_memory: bool = True, shuffle: bool = True, sampler = None, collate_fn= None, *args, **kwargs): 44 | assert not (shuffle and sampler is not None) 45 | nd = torch.cuda.device_count() 46 | nw = min([os.cpu_count() // max(nd, 1), nw]) 47 | return DataLoader(dataset=dataset, batch_size=bs, num_workers=nw, pin_memory=pin_memory, sampler=sampler, shuffle=shuffle, collate_fn=collate_fn, *args, **kwargs) -------------------------------------------------------------------------------- /tools/dataset_upload_hf.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import os 3 | from PIL 
import Image 4 | from typing import Dict, List 5 | from datasets import Dataset, Features, Image as ImageFeature, Value 6 | 7 | class FaceDataset: 8 | def __init__(self, data_dir): 9 | self.data_dir = data_dir 10 | 11 | def generate_examples(self): 12 | examples = [] 13 | identity_dirs = sorted([d for d in os.listdir(self.data_dir) 14 | if os.path.isdir(os.path.join(self.data_dir, d))]) 15 | id_to_label = {id_name: idx for idx, id_name in enumerate(identity_dirs)} 16 | 17 | for identity in identity_dirs: 18 | identity_dir = os.path.join(self.data_dir, identity) 19 | label = id_to_label[identity] 20 | 21 | for file_name in os.listdir(identity_dir): 22 | if file_name.lower().endswith(('.png', '.jpg', '.jpeg')): 23 | image_path = os.path.join(identity_dir, file_name) 24 | try: 25 | examples.append({ 26 | "image": image_path, 27 | "label": label, 28 | "class_name": identity, 29 | "file_name": file_name 30 | }) 31 | except Exception as e: 32 | print(f"Error processing {image_path}: {e}") 33 | continue 34 | return examples 35 | 36 | def create_and_upload_dataset(data_dir): 37 | # Create dataset instance 38 | face_dataset = FaceDataset(data_dir) 39 | 40 | # Generate examples 41 | examples = face_dataset.generate_examples() 42 | 43 | # Create Dataset object 44 | dataset = Dataset.from_list( 45 | examples, 46 | features=Features({ 47 | "image": ImageFeature(), 48 | "label": Value("int64"), 49 | "class_name": Value("string"), 50 | "file_name": Value("string") 51 | }) 52 | ) 53 | 54 | # Upload to Hub 55 | dataset.push_to_hub( 56 | "User/DatasetName", 57 | private=False, 58 | max_shard_size="500MB" 59 | ) 60 | 61 | if __name__ == "__main__": 62 | import argparse 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--data_dir", type=str, required=True, help="Path to the face data directory") 65 | args = parser.parse_args() 66 | 67 | create_and_upload_dataset(args.data_dir) -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import torch 3 | from torch.distributed import init_process_group 4 | from doraemon import ( 5 | CenterProcessor, 6 | yaml_load, 7 | increment_path, 8 | DistillCenterProcessor, 9 | colorstr, 10 | check 11 | ) 12 | import os 13 | import argparse 14 | from pathlib import Path 15 | import warnings 16 | warnings.filterwarnings('ignore') 17 | 18 | ROOT = Path(os.path.dirname(__file__)) 19 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) 20 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) 21 | 22 | def parse_opt(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('cfgs', default = ROOT / 'configs/classification/pet.yaml', help='configs for models, data, hyps') 25 | parser.add_argument('--resume', default = '', help='if no resume, not write') 26 | parser.add_argument('--sync_bn', action='store_true', help='turn on syncBN, if on, speed will be slower') 27 | parser.add_argument('--project', default=ROOT / 'run', help='save to project/name') 28 | parser.add_argument('--name', default='exp', help='save to project/name') 29 | parser.add_argument('--distill', action='store_true') 30 | parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') 31 | 32 | # face/cbir 33 | parser.add_argument('--print_freq', type=int, default=50, help='The print frequency for training state') 34 | parser.add_argument('--save_freq', type=int, default=5, help='The checkpoint frequency for 
saving state_dict epoch-wise, not contains warm epochs') 35 | return parser.parse_args() 36 | 37 | def main(opt): 38 | save_dir = increment_path(Path(opt.project) / opt.name) 39 | opt.save_dir = save_dir 40 | 41 | assert torch.cuda.device_count() > LOCAL_RANK 42 | # init process groups 43 | if LOCAL_RANK != -1: 44 | init_process_group(backend='nccl', world_size = WORLD_SIZE, rank = LOCAL_RANK) 45 | 46 | # configs 47 | cfgs = yaml_load(opt.cfgs) 48 | task: str= cfgs['model'].get('task', None) 49 | 50 | # check configs 51 | check(task, cfgs) 52 | 53 | # init cpu 54 | cpu = CenterProcessor(cfgs, LOCAL_RANK, project=save_dir, opt=opt) \ 55 | if not opt.distill else DistillCenterProcessor(cfgs, LOCAL_RANK, project=save_dir, opt=opt) 56 | 57 | # record config 58 | shutil.copy(opt.cfgs, save_dir) 59 | 60 | # syncBN 61 | if LOCAL_RANK != -1 and opt.sync_bn: 62 | cpu.set_sync_bn() 63 | if LOCAL_RANK == 0: 64 | cpu.logger.both(f'{colorstr("yellow", "Attention")}: sync_bn is on') 65 | # run 66 | cpu.run_classifier(resume=opt.resume if opt.resume else None) \ 67 | if task == 'classification' else cpu.run_embedding(resume=opt.resume if opt.resume else None) 68 | 69 | if __name__ == '__main__': 70 | opts = parse_opt() 71 | main(opts) -------------------------------------------------------------------------------- /configs/representation/README.md: -------------------------------------------------------------------------------- 1 | # Representation Learning Configuration instructions 2 | 3 | * model 4 | * [Model Config](backbone_conf.yaml) 5 | * [Head Config](head_conf.yaml) 6 | 7 | ```markdown 8 | model: 9 | task: cbir 10 | image_size: &imgsz 224 11 | backbone: 12 | swintransformer: # tiny small base 13 | model_size: base 14 | image_size: *imgsz 15 | feat_dim: &featd 128 16 | head: 17 | arcface: 18 | feat_dim: *featd 19 | num_class: 58671 20 | margin_arc: 0.35 21 | margin_am: 0.0 22 | scale: 32 23 | ``` 24 | 25 | * data 26 | ```markdown 27 | data: 28 | root: 29 | nw: 64 # if not multi-nw, set to 0 30 | train: 31 | bs: 80 # per gpu 32 | base_aug: null 33 | class_aug: null 34 | augment: # refer to utils/augment.py 35 | - random_choice: 36 | transforms: 37 | - random_color_jitter: 38 | brightness: 0.1 39 | contrast: 0.1 40 | saturation: 0.1 41 | hue: 0.1 42 | - random_cutout: 43 | n_holes: 3 44 | length: 12 45 | prob: 0.1 46 | color: [0, 255] 47 | - random_gaussianblur: 48 | kernel_size: 5 49 | - random_rotate: 50 | degrees: 10 51 | - random_adjustsharpness: 52 | p: 0.5 53 | - random_horizonflip: 54 | p: 0.5 55 | - random_choice: 56 | transforms: 57 | - resize_and_padding: 58 | size: *imgsz 59 | training: True 60 | - random_crop_and_resize: 61 | size: *imgsz 62 | scale: [0.7, 1] 63 | p: [0.9, 0.1] 64 | - to_tensor: no_params 65 | - normalize: 66 | mean: [0.485, 0.456, 0.406] 67 | std: [0.229, 0.224, 0.225] 68 | aug_epoch: 24 69 | val: 70 | bs: 128 71 | metrics: 72 | metrics: [mrr, recall, precision, auc, ndcg] 73 | cutoffs: [1, 3, 5] 74 | augment: 75 | - resize_and_padding: 76 | size: *imgsz 77 | training: False 78 | - to_tensor: no_params 79 | - normalize: 80 | mean: [0.485, 0.456, 0.406] 81 | std: [0.229, 0.224, 0.225] 82 | ``` 83 | * hyp 84 | ```markdown 85 | hyp: 86 | epochs: 25 87 | lr0: 0.006 88 | lrf_ratio: null # decay to lrf_ratio * lr0, if None, 0.1 89 | momentum: 0.937 90 | weight_decay: 0.0005 91 | warmup_momentum: 0.8 92 | warm_ep: 1 93 | loss: 94 | ce: True 95 | label_smooth: 0.0 96 | optimizer: 97 | - sgd # sgd, adam or sam 98 | - True # Different layers in the model set different learning 
rates, in built/layer_optimizer 99 | scheduler: cosine_with_warm # linear or cosine 100 | ``` -------------------------------------------------------------------------------- /tools/data_prepare.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import argparse 3 | import os 4 | from os.path import join as opj 5 | import shutil 6 | import pandas as pd 7 | from typing import Union 8 | 9 | """ 10 | project 11 | │ 12 | ├── data 13 | │ ├── clsXXX - 1 14 | │ ├── clsXXX - 2 15 | │ ├── clsXXX - ... 16 | ├── tools 17 | │ ├── data_prepare.py 18 | 19 | | 20 | | 21 | / 22 | 23 | project 24 | │ 25 | ├── data 26 | │ ├── train 27 | │ ├── clsXXX 28 | │ ├── XXX.jpg / png 29 | │ ├── val 30 | │ ├── clsXXX 31 | │ ├── XXX.jpg / png 32 | ├── tools 33 | │ ├── data_prepare.py 34 | """ 35 | 36 | def parse_opt(): 37 | parsers = argparse.ArgumentParser() 38 | parsers.add_argument('--postfix', default='jpg', help='postfix of image files') 39 | parsers.add_argument('--root', default='data', help='image dir') 40 | parsers.add_argument('--frac', type=float, nargs='+', help='fraction of train/val') 41 | parsers.add_argument('--drop', action='store_true') 42 | 43 | return parsers.parse_args() 44 | 45 | def data_split(postfix: str, root: str, frac: Union[float, list], drop: bool): 46 | 47 | all_classes = [x for x in os.listdir(root) if not x.startswith('.')] 48 | all_classes.sort() 49 | 50 | if len(frac) > 1: assert len(frac) == len(all_classes), 'if more frac, make sure every class should have a frac, len(frac) == len(all_classes)' 51 | else: 52 | a = frac[0] 53 | frac = [a for _ in all_classes] 54 | 55 | modes = ['train', 'val'] 56 | for m in modes: 57 | os.makedirs(opj(root, m), exist_ok=True) 58 | 59 | for i, cls in enumerate(all_classes): 60 | for m in modes: 61 | os.makedirs(opj(root, m, cls), exist_ok=True) 62 | 63 | s = pd.Series(glob.glob(opj(root, cls, f'*.{postfix}'))) 64 | train = s.sample(frac=frac[i]) 65 | val = s[~s.isin(train)] 66 | 67 | train.apply(lambda x: shutil.copy(x, opj(root, 'train', cls))) 68 | val.apply(lambda x: shutil.copy(x, opj(root, 'val', cls))) 69 | 70 | if drop: 71 | shutil.rmtree(opj(root, cls)) 72 | 73 | print(opj(root, cls), ' completed') 74 | 75 | if __name__ == '__main__': 76 | opt = parse_opt() 77 | data_split(opt.postfix, opt.root, opt.frac, opt.drop) -------------------------------------------------------------------------------- /doraemon/models/losses/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Callable 4 | from functools import wraps 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | __all__ = ['bce', 9 | 'ce', 10 | 'focal', 11 | 'create_Lossfn', 12 | 'list_lossfns',] 13 | 14 | LOSS = {} 15 | 16 | def register_loss(fn: Callable): 17 | key = fn.__name__ 18 | if key in LOSS: 19 | raise ValueError(f"An entry is already registered under the name '{key}'.") 20 | LOSS[key] = fn 21 | @wraps(fn) 22 | def wrapper(*args, **kwargs): 23 | return fn(*args, **kwargs) 24 | 25 | return wrapper 26 | 27 | class FocalLoss(nn.Module): 28 | # Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) 29 | def __init__(self, loss_fcn, gamma=1.5, alpha= 0.25): 30 | super().__init__() 31 | self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() 32 | self.gamma = gamma 33 | self.alpha = alpha 34 | self.reduction = loss_fcn.reduction 35 | self.loss_fcn.reduction = 'none' # required to apply FL to each element 36 | 37 | def forward(self, pred, true): 38 | loss = self.loss_fcn(pred, true) 39 | # p_t = torch.exp(-loss) 40 | # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability 41 | 42 | # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py 43 | pred_prob = torch.sigmoid(pred) # prob from logits 44 | p_t = true * pred_prob + (1 - true) * (1 - pred_prob) 45 | alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) 46 | modulating_factor = (1.0 - p_t) ** self.gamma 47 | loss *= alpha_factor * modulating_factor 48 | 49 | if self.reduction == 'mean': 50 | return loss.mean() 51 | elif self.reduction == 'sum': 52 | return loss.sum() 53 | else: # 'none' 54 | return loss 55 | 56 | class DistillKL(nn.Module): 57 | """Distilling the Knowledge in a Neural Network""" 58 | def __init__(self, T: float): 59 | super(DistillKL, self).__init__() 60 | self.T = T 61 | 62 | def forward(self, y_s: Tensor, y_t: Tensor): 63 | p_s = F.log_softmax(y_s/self.T, dim=1) 64 | p_t = F.softmax(y_t/self.T, dim=1) 65 | loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0] 66 | return loss 67 | 68 | @register_loss 69 | def bce(): 70 | return nn.BCEWithLogitsLoss() 71 | @register_loss 72 | def ce(label_smooth: float = 0.): 73 | return nn.CrossEntropyLoss(label_smoothing=label_smooth) 74 | @register_loss 75 | def focal(gamma=1.5, alpha= 0.25): 76 | return FocalLoss(loss_fcn = nn.BCEWithLogitsLoss(), alpha=alpha, gamma=gamma) 77 | def create_Lossfn(lossfn: str): 78 | lossfn = lossfn.strip() 79 | return LOSS[lossfn] 80 | 81 | def list_lossfns(): 82 | lossfns = [k for k, v in LOSS.items()] 83 | return sorted(lossfns) 84 | -------------------------------------------------------------------------------- /deploy/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Doraemon Classifier Deployment API 3 | 4 | ## File Structure 5 | The Doraemon model requires the following three files: 6 | - `config.json` - Model configuration file 7 | - `doraemon_modeling.py` - Model implementation code 8 | - `xxx.pt` - Model weights file 9 | 10 | ## Loading the Model from Hugging Face 11 | Here's an example of loading and using the Doraemon model with the Transformers library: 12 | 13 | ```python 14 | from transformers import AutoModel, AutoProcessor 15 | import requests 16 | from io import BytesIO 17 | from PIL import Image 18 | 19 | # Load model and processor 20 | pretrained_model_name_or_path = "user/repo" 21 | model = AutoModel.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, revision="v2") 22 | processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True, revision="v2") 23 | 24 | # Prepare images 25 | urls = [ 26 | "https://github.com/ultralytics/ultralytics/blob/main/ultralytics/assets/bus.jpg", 27 | "https://github.com/ultralytics/ultralytics/blob/main/ultralytics/assets/zidane.jpg", 28 | ] 29 | 30 | images = [] 31 | for url in urls: 32 | response = requests.get(url) 33 | image = Image.open(BytesIO(response.content)) 34 | images.append(image) 35 | 36 | # Process 
input 37 | batch_inputs = processor(images, return_tensors="pt") 38 | 39 | # Get output 40 | probs = model(batch_inputs["pixel_values"]) 41 | print("Probs:\n", probs) 42 | output = processor.postprocess(probs) 43 | print("Tagging:") 44 | for r in output: 45 | print(r) 46 | ``` 47 | 48 | ## Usage 49 | 50 | ### Local API 51 | 52 | #### Option A: Using Config File 53 | Define `model_path` in `config.json` pointing to the absolute path of the model weights file. 54 | 55 | #### Option B: Using Command Line 56 | ```bash 57 | python doraemon_modeling.py --model_path /path/to/xxx.pt 58 | ``` 59 | This will override the settings in `config.json`. 60 | 61 | ### HF Remote API 62 | 63 | #### Step 1: Create a Hugging Face Repository 64 | ```bash 65 | # Install Hugging Face Hub 66 | pip install huggingface_hub 67 | 68 | # Login for authentication 69 | huggingface-cli login 70 | ``` 71 | 72 | #### Step 2: Configure the Model 73 | Set `model_path` in `config.json` with the format `user/repo/filename:revision` 74 | 75 | Required components: 76 | - `user`: Your Hugging Face username 77 | - `repo`: Repository name 78 | - `filename`: Weights file name (e.g., `best.pt`) 79 | - `revision`: Version number (e.g., `v3`) 80 | 81 | #### Step 3: Push to Hugging Face 82 | Push the following files to your Hugging Face repository: 83 | - `config.json` 84 | - `doraemon_modeling.py` 85 | - `xxx.pt` 86 | 87 | #### Step 4: Call API 88 | ```bash 89 | python doraemon_modeling.py --pretrained_model_name_or_path user/repo --revision v3 90 | ``` 91 | 92 | > **Note**: When `--pretrained_model_name_or_path` points to a Hugging Face repository, the local `deploy/doraemon_modeling.py` content will be ignored as the code from the repository will be executed instead. 93 | 94 | ## Version Control 95 | Make sure to specify the correct `revision` parameter when loading the model to match the version of your deployed model. 96 | -------------------------------------------------------------------------------- /doraemon/models/representation/head/head_def.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from .arcface import ArcFace 3 | from .circleloss import CircleLoss 4 | from .mv_softmax import MV_Softmax 5 | from .magface import MagFace 6 | 7 | class HeadFactory: 8 | """Factory to produce head according to the head_conf.yaml 9 | 10 | Attributes: 11 | head_type(str): which head will be produce. 12 | head_param(dict): parsed params and it's value. 13 | """ 14 | def __init__(self, head_config): 15 | for k, v in head_config.items(): self.head_type, self.head_param = k, v 16 | 17 | def get_head(self): 18 | if self.head_type == 'arcface': 19 | feat_dim = self.head_param['feat_dim'] # dimension of the output features, e.g. 512 20 | num_class = self.head_param['num_class'] # number of classes in the training set. 21 | margin_arc = self.head_param['margin_arc'] # cos(theta + margin_arc). 22 | margin_am = self.head_param['margin_am'] # cos_theta - margin_am. 23 | scale = self.head_param['scale'] # the scaling factor for cosine values. 24 | head = ArcFace(feat_dim, num_class, margin_arc, margin_am, scale) 25 | 26 | elif self.head_type == 'magface': 27 | feat_dim = self.head_param['feat_dim'] # dimension of the output features, e.g. 512 28 | num_class = self.head_param['num_class'] # number of classes in the training set. 29 | margin_am = self.head_param['margin_am'] # cos_theta - margin_am. 30 | scale = self.head_param['scale'] # the scaling factor for cosine values. 
31 | l_a = self.head_param['l_a'] 32 | u_a = self.head_param['u_a'] 33 | l_margin = self.head_param['l_margin'] 34 | u_margin = self.head_param['u_margin'] 35 | lamda = self.head_param['lamda'] 36 | head = MagFace(feat_dim, num_class, margin_am, scale, l_a, u_a, l_margin, u_margin, lamda) 37 | 38 | elif self.head_type == 'circleloss': 39 | feat_dim = self.head_param['feat_dim'] # dimension of the output features, e.g. 512 40 | num_class = self.head_param['num_class'] # number of classes in the training set. 41 | margin = self.head_param['margin'] # O_p = 1 + margin, O_n = -margin. 42 | gamma = self.head_param['gamma'] # the scale facetor. 43 | head = CircleLoss(feat_dim, num_class, margin, gamma) 44 | 45 | elif self.head_type == 'mv-softmax': 46 | feat_dim = self.head_param['feat_dim'] # dimension of the output features, e.g. 512 47 | num_class = self.head_param['num_class'] # number of classes in the training set. 48 | is_am = self.head_param['is_am'] # am-softmax for positive samples. 49 | margin = self.head_param['margin'] # margin for positive samples. 50 | mv_weight = self.head_param['mv_weight'] # weight for hard negtive samples. 51 | scale = self.head_param['scale'] # the scaling factor for cosine values. 52 | head = MV_Softmax(feat_dim, num_class, is_am, margin, mv_weight, scale) 53 | 54 | else: 55 | raise NotImplemented("only arcface, magface, circleloss and mv-softmax are supported now !") 56 | return head 57 | -------------------------------------------------------------------------------- /tools/video_predict.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torchvision.models import get_model 4 | from PIL import Image 5 | import torchvision.transforms as T 6 | import argparse 7 | from dataset.transforms import create_AugTransforms 8 | import cv2 9 | import numpy as np 10 | import os 11 | 12 | def parse_opt(): 13 | parsers = argparse.ArgumentParser() 14 | parsers.add_argument('--video', default='./record/5000201400/4204742643555836267_0_2023-05-25-01-02-17_2023-05-25-01-04-01.mp4', type=str) 15 | parsers.add_argument('--pt', default='./best.pt', type=str) 16 | parsers.add_argument('--transforms', default='centercrop_resize to_tensor_without_div') 17 | parsers.add_argument('--imgsz', default='[[720, 720], [224, 224]]', type=str) 18 | parsers.add_argument('--output', default=False, type=bool) 19 | parsers.add_argument('--names', default='[0,2,4,6,7,8,10]', type=str) 20 | parsers.add_argument('--sample', default=0.5, type=float, help='retain ratio') 21 | parsers.add_argument('--fps', default=25, type=int, help='FPS') 22 | parsers.add_argument('--video_imgsz', default='[720, 1280]', type=str, help='h w') 23 | 24 | args = parsers.parse_args() 25 | return args 26 | 27 | def image_process(frame: np.array, transforms: T.Compose): 28 | img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 29 | img = Image.fromarray(img) 30 | return transforms(img).unsqueeze(0) 31 | 32 | def main(opt): 33 | 34 | # variable 35 | video_path = opt.video 36 | weight_path = opt.pt 37 | transforms = opt.transforms 38 | imgsz = opt.imgsz 39 | is_output = opt.output 40 | names = eval(opt.names) 41 | sample = opt.sample 42 | fps = opt.fps 43 | video_imgsz = eval(opt.video_imgsz) 44 | 45 | if is_output: 46 | filename = f'{os.path.splitext(video_path)[0]}_new.mp4' 47 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 48 | out = cv2.VideoWriter(filename, fourcc, 25, (video_imgsz[1], video_imgsz[0])) # width, height 49 | # image 50 | # 获得视频的格式 51 | 
videoCapture = cv2.VideoCapture(video_path) 52 | 53 | # model 54 | model = get_model('mobilenet_v2', width_mult = 0.25) 55 | model.classifier[-1] = torch.nn.Linear(model.classifier[-1].in_features, 7) 56 | weight = torch.load(weight_path, map_location='cpu')['model'] 57 | model.load_state_dict(weight) 58 | # eval 59 | model.eval() 60 | success, frame = videoCapture.read() 61 | while success: 62 | if random.random() > sample: 63 | success, frame = videoCapture.read() 64 | continue 65 | image = image_process(frame, create_AugTransforms(transforms, eval(imgsz))) 66 | 67 | result = torch.nn.functional.softmax(model(image), dim=-1)[0] 68 | idxes = result.argsort(0, descending=True) 69 | 70 | text = '\n'.join(f'{result[j].item():.2f} {names[j]}' for j in idxes).split('\n') 71 | 72 | cv2.putText(frame, str(text), (5, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, 73 | (0, 0, 255), 2) 74 | 75 | if is_output: out.write(frame) 76 | else: 77 | cv2.imshow('windows', frame) # 显示 78 | cv2.waitKey(int(1000 / int(fps))) # 延迟 79 | success, frame = videoCapture.read() # 获取下一帧 80 | 81 | videoCapture.release() 82 | if is_output: out.release() 83 | 84 | if __name__ == '__main__': 85 | opt = parse_opt() 86 | main(opt) -------------------------------------------------------------------------------- /tools/test_augment.py: -------------------------------------------------------------------------------- 1 | from doraemon import create_AugTransforms 2 | from PIL import Image 3 | import numpy as np 4 | import argparse 5 | 6 | def parse_opt(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-m','--img_path', type=str, default='/Users/wuji/Desktop/cluster 2/11/172546-B0CXT6VGDJ-52212-14-640-556.jpg', help='Path of raw image') 9 | parser.add_argument('-o', '--output_path', type=str, default='save_img.jpg', help='Path to save image') 10 | parser.add_argument('-H', '--height', default=4, help='Height of the jagged grid') 11 | parser.add_argument('-W', '--width', default=7, help='Width of the jagged grid') 12 | 13 | return parser.parse_args() 14 | 15 | def create_augs(): 16 | augs = { 17 | # 'resize_and_padding': dict(size = 224), 18 | # 'random_color_jitter': dict(brightness=0.1, contrast = 0.1, saturation = 0.1, hue = 0.1), 19 | # 'random_equalize': 'no_params', 20 | # 'random_crop_and_resize': dict(size = 224, scale=(0.7, 1.0)), 21 | # 'random_cutout': dict(n_holes=3, length=24, prob=0.1, color = (0, 255)), 22 | # 'random_grayscale': 'no_params', 23 | # 'random_gaussianblur': dict(kernel_size=5), 24 | # 'random_localgaussian': dict(ksize = (37, 37)), 25 | # 'random_rotate': dict(degrees = 20), 26 | # 'random_doubleflip': dict(prob=0.5), 27 | # 'random_horizonflip': dict(p=0.5), 28 | # 'random_autocontrast': dict(p=0.5), 29 | # 'random_adjustsharpness': dict(p=0.5), 30 | # 'pad2square': 'no_params', 31 | # 'resize': dict(size=224), 32 | } 33 | augs = [ 34 | {"random_choice": 35 | dict( 36 | transforms = [ 37 | dict(random_color_jitter = dict(brightness=0.1, contrast = 0.1, saturation = 0.1, hue = 0.1)), 38 | dict(random_cutout = dict(n_holes=3, length=12, prob=0.1, color = (0, 255))), 39 | dict(random_gaussianblur = dict(kernel_size=5)), 40 | dict(random_rotate = dict(degrees = 20)), 41 | dict(random_autocontrast = dict(p=0.5)), 42 | dict(random_adjustsharpness=dict(p=0.5)), 43 | dict(random_augmix=dict(severity=3)), 44 | ], 45 | ), 46 | }, 47 | {"random_choice": dict( 48 | transforms = [ 49 | dict(resize_and_padding = dict(size=224)), 50 | dict(random_crop_and_resize = dict(size = 224, scale=(0.7, 1)),) 51 | ], 
52 | ) 53 | }, 54 | {"random_horizonflip": dict(p=0.5)}, 55 | ] 56 | 57 | return augs 58 | 59 | def main(args): 60 | 61 | augs = create_augs() 62 | 63 | t = create_AugTransforms(augs) 64 | 65 | img = Image.open(args.img_path).convert("RGB") 66 | 67 | images = [] 68 | for i in range(1, args.height * args.width + 1): 69 | image = t(img) 70 | images.append(image) 71 | 72 | array_images = [np.array(image) for image in images] 73 | 74 | rows, cols = args.height, args.width 75 | grid = np.zeros((rows * array_images[0].shape[0], cols * array_images[0].shape[1], 3), dtype=np.uint8) 76 | 77 | for r in range(rows): 78 | for c in range(cols): 79 | image_index = r * cols + c 80 | if image_index < len(array_images): 81 | grid[r*array_images[0].shape[0]:(r+1)*array_images[0].shape[0], c*array_images[0].shape[1]:(c+1)*array_images[0].shape[1], :] = array_images[image_index] 82 | 83 | img_pillow = Image.fromarray(grid) 84 | img_pillow.show() 85 | img_pillow.save(args.output_path) 86 | 87 | if __name__ == '__main__': 88 | main(parse_opt()) -------------------------------------------------------------------------------- /configs/representation/face.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | task: face 3 | image_size: &imgsz 224 4 | load_from: null 5 | backbone: 6 | #----------------------------Transformer-----------------------------------------# 7 | # timm-swin_base_patch4_window7_224.ms_in22k_ft_in1k: # imgsz 224 8 | # timm-vit_base_patch8_224.dino: # imgsz 224 9 | # timm-vit_base_patch16_224.augreg2_in21k_ft_in1k # imgsz 224 10 | # timm-vit_large_patch16_224.mae: # imgsz 224 11 | # timm-vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k: # imgsz 224 12 | # timm-swinv2_base_window8_256.ms_in1k: # imgsz 256 13 | # timm-swinv2_large_window12to16_192to256.ms_in22k_ft_in1k: # imgsz 256 14 | # timm-vit_base_patch16_clip_224.laion2b_ft_in1k: # imgsz 224 15 | # timm-vit_large_patch14_dinov2.lvd142m: # imgsz 518 16 | # timm-vit_so400m_patch14_siglip_224.webli: # imgsz 224 17 | #--------------------------------CNN---------------------------------------------# 18 | # timm-wide_resnet101_2.tv2_in1k: # imgsz 224 19 | # timm-resnet50d.gluon_in1k: # imgsz 224 20 | # timm-resnext50_32x4d.a3_in1k: # imgsz 224 21 | # timm-resnest50d_4s2x40d.in1k: # imgsz 224 22 | # timm-legacy_seresnet50.in1k: # imgsz 224 23 | # timm-tf_mobilenetv3_large_minimal_100.in1k: # imgsz 224 24 | # timm-convnext_base.clip_laion2b_augreg_ft_in1k: # imgsz 224 25 | # timm-convnext_base.clip_laiona_augreg_ft_in1k_384: # imgsz 384 26 | # timm-convnext_large.fb_in22k_ft_in1k: # imgsz 224 27 | # timm-tf_efficientnetv2_l.in21k_ft_in1k: # imgsz 224 28 | timm-swin_base_patch4_window7_224.ms_in22k_ft_in1k: # imgsz 224 29 | pretrained: True 30 | image_size: *imgsz 31 | feat_dim: &featd 128 32 | head: 33 | arcface: 34 | feat_dim: *featd 35 | num_class: 74726 36 | margin_arc: 0.35 37 | margin_am: 0.0 38 | scale: 32 39 | data: 40 | # Choose ONE of the following data sources: 41 | 42 | # 1. HuggingFace Dataset 43 | root: wuji3/face-recognition 44 | 45 | # 2. 
Local Dataset 46 | # root: # Format: path/to/data with train/ and query/ gallery/ subdirs 47 | 48 | nw: 64 # if not multi-nw, set to 0 49 | train: 50 | bs: 320 # per gpu 51 | base_aug: null 52 | class_aug: null 53 | augment: # refer to utils/augment.py 54 | - random_choice: 55 | transforms: 56 | - random_color_jitter: 57 | brightness: 0.1 58 | contrast: 0.1 59 | saturation: 0.1 60 | hue: 0.1 61 | - random_gaussianblur: 62 | kernel_size: 5 63 | - random_horizonflip: 64 | p: 0.5 65 | - random_choice: 66 | transforms: 67 | - resize_and_padding: 68 | size: *imgsz 69 | training: True 70 | - random_crop_and_resize: 71 | size: *imgsz 72 | scale: [0.7, 1] 73 | - to_tensor: no_params 74 | - normalize: 75 | mean: [0.485, 0.456, 0.406] 76 | std: [0.229, 0.224, 0.225] 77 | aug_epoch: 9 78 | val: 79 | bs: 64 80 | pair_txt: 81 | augment: 82 | - resize_and_padding: 83 | size: *imgsz 84 | training: False 85 | - to_tensor: no_params 86 | - normalize: 87 | mean: [0.485, 0.456, 0.406] 88 | std: [0.229, 0.224, 0.225] 89 | hyp: 90 | epochs: 10 91 | lr0: 0.006 92 | lrf_ratio: null # decay to lrf_ratio * lr0, if None, 0.1 93 | momentum: 0.937 94 | weight_decay: 0.0005 95 | warmup_momentum: 0.8 96 | warm_ep: 1 97 | loss: 98 | ce: True 99 | label_smooth: 0.0 100 | optimizer: 101 | - sgd # sgd, adam or sam 102 | - False # Different layers in the model set different learning rates, in built/layer_optimizer 103 | scheduler: cosine_with_warm # linear or cosine 104 | -------------------------------------------------------------------------------- /scripts/validate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join as opj 3 | import argparse 4 | from pathlib import Path 5 | from doraemon import (valuate_classifier, 6 | valuate_cbir, 7 | valuate_face, 8 | FaceModelLoader, 9 | get_model, 10 | SmartDataProcessor, 11 | SmartLogger, 12 | ) 13 | import torch 14 | from prettytable import PrettyTable 15 | 16 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) 17 | ROOT = Path(os.path.dirname(__file__)) 18 | 19 | def parse_opt(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('model-path', default='run/exp/best.pt', help='Path to model configs') 22 | parser.add_argument('--ema', action='store_true', help='Exponential Moving Average for model weight') 23 | 24 | # classifier 25 | parser.add_argument('--eval_topk', default=5, type=int, help='Tell topk_acc, maybe top5, top3...') 26 | parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') 27 | 28 | return parser.parse_args() 29 | 30 | def main(opt): 31 | # device 32 | if LOCAL_RANK != -1: 33 | device = torch.device('cuda', LOCAL_RANK) 34 | else: 35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 36 | 37 | logger = SmartLogger(filename=None, level=1) if LOCAL_RANK in {-1,0} else None 38 | 39 | pt_file = Path(getattr(opt, 'model-path')) 40 | pt = torch.load(pt_file, weights_only=False) 41 | config = pt['config'] 42 | task: str = config['model']['task'] 43 | 44 | if task == 'classification': 45 | modelwrapper = get_model(config['model'], None, LOCAL_RANK) 46 | # checkpoint loading 47 | model = modelwrapper.load_weight(pt, ema=opt.ema, device=device) 48 | 49 | data_processor = SmartDataProcessor(config['data'], LOCAL_RANK, None, training = False) 50 | data_processor.val_dataset = data_processor.create_dataset('val', training = False, id2label=pt['id2label']) 51 | 52 | # set val dataloader 53 | dataloader = 
data_processor.set_dataloader(data_processor.val_dataset, nw=config['data']['nw'], bs=config['data']['val']['bs'], 54 | collate_fn=data_processor.val_dataset.collate_fn) 55 | 56 | conm_path = opj(os.path.dirname(pt_file), 'conm.png') 57 | 58 | thresh = config['hyp']['loss']['bce'][1] if config['hyp']['loss']['bce'][0] else 0 59 | valuate_classifier(model, dataloader, device, None, False, None, logger, thresh=thresh, top_k=opt.eval_topk, 60 | conm_path=conm_path) 61 | 62 | elif task in ('face', 'cbir'): 63 | # logger 64 | logger = SmartLogger(filename=None) 65 | 66 | # checkpoint loading 67 | logger.console(f'Loading Model, EMA is {opt.ema}') 68 | model_loader = FaceModelLoader(model_cfg=config['model']) 69 | model = model_loader.load_weight(model_path=pt_file, ema=opt.ema) 70 | 71 | logger.console('Evaluating...') 72 | if task == 'face': 73 | mean, std = valuate_face(model, config['data'], torch.device('cuda')) 74 | pretty_tabel = PrettyTable(["model_name", "mean accuracy", "standard error"]) 75 | pretty_tabel.add_row([os.path.basename(pt_file), mean, std]) 76 | 77 | logger.console('\n' + str(pretty_tabel)) 78 | else: 79 | metrics = valuate_cbir(model, 80 | config['data'], 81 | torch.device('cuda', LOCAL_RANK if LOCAL_RANK > 0 else 0), 82 | logger, 83 | vis=False) 84 | logger.console(metrics) 85 | 86 | else: 87 | raise ValueError(f'Unknown task {task}') 88 | 89 | if __name__ == '__main__': 90 | opt = parse_opt() 91 | main(opt) -------------------------------------------------------------------------------- /configs/representation/image-retrieval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | task: cbir 3 | image_size: &imgsz 112 4 | load_from: null 5 | backbone: 6 | #----------------------------Transformer-----------------------------------------# 7 | # timm-swin_base_patch4_window7_224.ms_in22k_ft_in1k: # imgsz 224 8 | # timm-vit_base_patch8_224.dino: # imgsz 224 9 | # timm-vit_base_patch16_224.augreg2_in21k_ft_in1k # imgsz 224 10 | # timm-vit_large_patch16_224.mae: # imgsz 224 11 | # timm-vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k: # imgsz 224 12 | # timm-swinv2_base_window8_256.ms_in1k: # imgsz 256 13 | # timm-swinv2_large_window12to16_192to256.ms_in22k_ft_in1k: # imgsz 256 14 | # timm-vit_base_patch16_clip_224.laion2b_ft_in1k: # imgsz 224 15 | # timm-vit_large_patch14_dinov2.lvd142m: # imgsz 518 16 | # timm-vit_so400m_patch14_siglip_224.webli: # imgsz 224 17 | #--------------------------------CNN---------------------------------------------# 18 | # timm-wide_resnet101_2.tv2_in1k: # imgsz 224 19 | # timm-resnet50d.gluon_in1k: # imgsz 224 20 | # timm-resnext50_32x4d.a3_in1k: # imgsz 224 21 | # timm-resnest50d_4s2x40d.in1k: # imgsz 224 22 | # timm-legacy_seresnet50.in1k: # imgsz 224 23 | timm-tf_mobilenetv3_large_minimal_100.in1k: # imgsz 224 24 | # timm-convnext_base.clip_laion2b_augreg_ft_in1k: # imgsz 224 25 | # timm-convnext_large.fb_in22k_ft_in1k: # imgsz 224 26 | # timm-tf_efficientnetv2_l.in21k_ft_in1k: # imgsz 224 27 | # timm-swin_base_patch4_window7_224.ms_in22k_ft_in1k: # imgsz 224 28 | pretrained: True 29 | image_size: *imgsz 30 | feat_dim: &featd 128 31 | head: 32 | arcface: 33 | feat_dim: *featd 34 | num_class: 5000 35 | margin_arc: 0.35 36 | margin_am: 0.0 37 | scale: 32 38 | data: 39 | # Choose ONE of the following data sources: 40 | 41 | # 1. HuggingFace Dataset 42 | root: wuji3/image-retrieval 43 | 44 | # 2. 
Local Dataset 45 | # root: # Format: path/to/data with train/ and query/ gallery/ subdirs 46 | 47 | nw: 64 # if not multi-nw, set to 0 48 | train: 49 | bs: 320 # per gpu 50 | base_aug: null 51 | class_aug: null 52 | augment: # refer to utils/augment.py 53 | - random_choice: 54 | transforms: 55 | - random_color_jitter: 56 | brightness: 0.1 57 | contrast: 0.1 58 | saturation: 0.1 59 | hue: 0.1 60 | - random_cutout: 61 | n_holes: 3 62 | length: 12 63 | prob: 0.1 64 | color: [0, 255] 65 | - random_gaussianblur: 66 | kernel_size: 5 67 | - random_rotate: 68 | degrees: 10 69 | - random_adjustsharpness: 70 | p: 0.5 71 | - random_horizonflip: 72 | p: 0.5 73 | - random_choice: 74 | transforms: 75 | - resize_and_padding: 76 | size: *imgsz 77 | training: True 78 | - random_crop_and_resize: 79 | size: *imgsz 80 | scale: [0.7, 1] 81 | p: [0.9, 0.1] 82 | - to_tensor: no_params 83 | - normalize: 84 | mean: [0.485, 0.456, 0.406] 85 | std: [0.229, 0.224, 0.225] 86 | aug_epoch: 5 87 | val: 88 | bs: 128 89 | metrics: 90 | metrics: [mrr, recall, precision, auc, ndcg] 91 | cutoffs: [1, 3, 5] 92 | augment: 93 | - resize_and_padding: 94 | size: *imgsz 95 | training: False 96 | - to_tensor: no_params 97 | - normalize: 98 | mean: [0.485, 0.456, 0.406] 99 | std: [0.229, 0.224, 0.225] 100 | hyp: 101 | epochs: 6 102 | lr0: 0.006 103 | lrf_ratio: null # decay to lrf_ratio * lr0, if None, 0.1 104 | momentum: 0.937 105 | weight_decay: 0.0005 106 | warmup_momentum: 0.8 107 | warm_ep: 1 108 | loss: 109 | ce: True 110 | label_smooth: 0.0 111 | optimizer: 112 | - sgd # sgd, adam or sam 113 | - True # Different layers in the model set different learning rates, in built/layer_optimizer 114 | scheduler: cosine_with_warm # linear or cosine 115 | -------------------------------------------------------------------------------- /doraemon/models/classifier/classify_model.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch.nn as nn 3 | from ...built.attention_based_pooler import atten_pool_replace 4 | import torch 5 | import torch.distributed as dist 6 | import timm 7 | from torch.nn.init import normal_, constant_ 8 | 9 | class VisionWrapper: 10 | def __init__(self, model_cfgs: dict, logger = None, rank = -1): 11 | self.logger = logger 12 | self.model_cfgs = model_cfgs 13 | self.rank = rank 14 | 15 | self.kwargs = model_cfgs['kwargs'] 16 | self.num_classes = model_cfgs['num_classes'] 17 | self.pretrained = model_cfgs['pretrained'] 18 | self.backbone_freeze = model_cfgs['backbone_freeze'] 19 | self.bn_freeze = model_cfgs['bn_freeze'] 20 | self.bn_freeze_affine = model_cfgs['bn_freeze_affine'] 21 | 22 | model_cfgs_copy = deepcopy(model_cfgs) 23 | _, model_cfgs_copy['choice'] = model_cfgs['name'].split('-') 24 | 25 | self.model = self.create_model(**model_cfgs_copy) 26 | # pool layer 27 | if model_cfgs['attention_pool']: 28 | self.model = atten_pool_replace(self.model) 29 | 30 | del model_cfgs_copy 31 | 32 | if not self.pretrained: self.reset_parameters() 33 | 34 | def create_model(self, choice: str, num_classes: int = 1000, pretrained: bool = False, 35 | backbone_freeze: bool = False, bn_freeze: bool = False, 36 | bn_freeze_affine: bool = False, **kwargs): 37 | # Only rank 0 downloads the pre-trained weights 38 | if pretrained and self.rank == 0: 39 | _ = timm.create_model( 40 | choice, 41 | pretrained=True, 42 | num_classes=num_classes, 43 | **kwargs['kwargs'] 44 | ) 45 | 46 | if self.rank >= 0: 47 | dist.barrier(device_ids=[self.rank]) 48 | 49 | model = 
timm.create_model( 50 | choice, 51 | pretrained=pretrained, 52 | num_classes=num_classes, 53 | **kwargs['kwargs'] 54 | ) 55 | 56 | if backbone_freeze: self.freeze_backbone() 57 | if bn_freeze: self.freeze_bn(bn_freeze_affine) 58 | 59 | return model 60 | 61 | def load_weight(self, load_from_path: str | dict, ema: bool = False, device: torch.device = None): 62 | if isinstance(load_from_path, str): 63 | checkpoint = torch.load(load_from_path, map_location='cpu', weights_only=False) 64 | if ema: 65 | weights = checkpoint['ema'].float().state_dict() 66 | else: 67 | weights = checkpoint['model'] 68 | elif isinstance(load_from_path, dict): 69 | weights = load_from_path['ema'].float().state_dict() if ema else load_from_path['model'] 70 | else: 71 | raise TypeError(f"load_from_path must be str or dict, got {type(load_from_path)}") 72 | 73 | self.model.load_state_dict(weights) 74 | return self.model.to(device) 75 | 76 | def init_parameters(self, m: nn.Module): 77 | if isinstance(m, nn.Conv2d): 78 | normal_(m.weight, mean=0, std=0.02) 79 | elif isinstance(m, nn.BatchNorm2d): 80 | constant_(m.weight, 1) 81 | constant_(m.bias, 0) 82 | elif isinstance(m, nn.Linear): 83 | normal_(m.weight, mean=0, std=0.02) 84 | constant_(m.bias, 0) 85 | 86 | def reset_parameters(self): 87 | self.model.apply(self.init_parameters) 88 | 89 | def freeze_backbone(self): 90 | # Get the classifier module 91 | classifier = self.model.get_classifier() 92 | 93 | # Freeze all parameters except the classifier 94 | for _, m in self.model.named_modules(): 95 | if m is not classifier: # Skip the classifier module 96 | for p in m.parameters(recurse=False): 97 | p.requires_grad_(False) 98 | 99 | if self.rank <= 0: # Only print on main process 100 | print('backbone freeze') 101 | 102 | def freeze_bn(self, bn_freeze_affine: bool = False): 103 | for m in self.model.modules(): 104 | if isinstance(m, nn.BatchNorm2d): 105 | m.eval() 106 | if bn_freeze_affine: 107 | m.weight.requires_grad_(False) 108 | m.bias.requires_grad_(False) -------------------------------------------------------------------------------- /doraemon/utils/plots.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFont 2 | from pathlib import Path 3 | import os 4 | import platform 5 | from urllib.error import URLError 6 | import torch 7 | 8 | def is_writeable(dir, test=False): 9 | # Return True if directory has write permissions, test opening a file with write permissions if test=True 10 | if not test: 11 | return os.access(dir, os.W_OK) # possible issues on Windows 12 | file = Path(dir) / 'tmp.txt' 13 | try: 14 | with open(file, 'w'): # open file with write permissions 15 | pass 16 | file.unlink() # remove file 17 | return True 18 | except OSError: 19 | return False 20 | 21 | def user_config_dir(dir='Doraemon', env_var='VISION_CONFIG_DIR'): 22 | # Return path of user configuration directory. Prefer environment variable if exists. Make dir if required. 
23 | env = os.getenv(env_var) 24 | if env: 25 | path = Path(env) # use environment variable 26 | else: 27 | cfg = {'Windows': 'AppData/Roaming', 'Linux': '.config', 'Darwin': 'Library/Application Support'} # 3 OS dirs 28 | path = Path.home() / cfg.get(platform.system(), '') # OS-specific config dir 29 | path = (path if is_writeable(path) else Path('/tmp')) / dir # GCP and AWS lambda fix, only /tmp is writeable 30 | path.mkdir(exist_ok=True) # make if required 31 | return path 32 | 33 | CONFIG_DIR = user_config_dir() 34 | 35 | def check_font(font='Arial.ttf', progress=False): 36 | # Download font to CONFIG_DIR if necessary 37 | font = Path(font) 38 | file = CONFIG_DIR / font.name 39 | if not font.exists() and not file.exists(): 40 | url = f'https://ultralytics.com/assets/{font.name}' 41 | # LOGGER.info(f'Downloading {url} to {file}...') 42 | torch.hub.download_url_to_file(url, str(file), progress=progress) 43 | 44 | def check_pil_font(font='Arial.ttf', size=10): 45 | # Return a PIL TrueType Font, downloading to CONFIG_DIR if necessary 46 | font = Path(font) 47 | font = font if font.exists() else (CONFIG_DIR / font.name) 48 | try: 49 | return ImageFont.truetype(str(font) if font.exists() else font.name, size) 50 | except Exception: # download if missing 51 | try: 52 | check_font(font) 53 | return ImageFont.truetype(str(font), size) 54 | except TypeError: 55 | pass 56 | except URLError: # not online 57 | return ImageFont.load_default() 58 | 59 | def is_ascii(s=''): 60 | # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7) 61 | s = str(s) # convert list, tuple, None, etc. to str 62 | return len(s.encode().decode('ascii', 'ignore')) == len(s) 63 | 64 | class Annotator: 65 | # cp misc/Arial.ttf ~/.config/Doraemon 66 | def __init__(self, im, font='Arial.ttf'): 67 | font_size = None 68 | non_ascii = not is_ascii('abc') 69 | self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) 70 | self.draw = ImageDraw.Draw(self.im) 71 | self.font = check_pil_font(font='Arial.Unicode.ttf' if non_ascii else font, size=font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)) 72 | 73 | def text(self, xy, text, txt_color=(255, 255, 255)): 74 | # Add text to image (PIL-only) 75 | self.draw.text(xy, text, fill=txt_color, font=self.font) 76 | 77 | def colorstr(*input): 78 | # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world') 79 | *args, string = input if len(input) > 1 else ('blue', 'bold', input[0]) # color arguments, string 80 | colors = { 81 | 'black': '\033[30m', # basic colors 82 | 'red': '\033[31m', 83 | 'green': '\033[32m', 84 | 'yellow': '\033[33m', 85 | 'blue': '\033[34m', 86 | 'magenta': '\033[35m', 87 | 'cyan': '\033[36m', 88 | 'white': '\033[37m', 89 | 'bright_black': '\033[90m', # bright colors 90 | 'bright_red': '\033[91m', 91 | 'bright_green': '\033[92m', 92 | 'bright_yellow': '\033[93m', 93 | 'bright_blue': '\033[94m', 94 | 'bright_magenta': '\033[95m', 95 | 'bright_cyan': '\033[96m', 96 | 'bright_white': '\033[97m', 97 | 'end': '\033[0m', # misc 98 | 'bold': '\033[1m', 99 | 'underline': '\033[4m'} 100 | return ''.join(colors[x] for x in args) + f'{string}' + colors['end'] 101 | -------------------------------------------------------------------------------- /doraemon/models/representation/README_CBIR.md: -------------------------------------------------------------------------------- 1 | #
Image Retrieval
2 | 3 | ## 📦 Data Preparation 4 | 5 | 6 | You can use either pre-prepared datasets or your own dataset for training. 7 | 8 | ### Option 1: Pre-prepared Datasets 9 | 10 | **Using HuggingFace Dataset (Recommended)** 11 | 12 | Dataset: [wuji3/image-retrieval](https://huggingface.co/datasets/wuji3/image-retrieval) 13 | ```yaml 14 | # In your config file (e.g., configs/faceX/cbir.yaml) 15 | data: 16 | root: wuji3/image-retrieval 17 | ``` 18 | 19 | ### Option 2: Custom Dataset 20 | 21 | #### Dataset Structure 22 | 23 | Organize your data in the following structure: 24 | ``` 25 | your_dataset/ 26 | ├── train/ 27 | │ ├── class1/ # Folder name = class/ID name 28 | │ │ ├── image1.jpg 29 | │ │ └── ... 30 | │ └── class2/ 31 | │ └── ... 32 | ├── gallery/ # Query database 33 | │ ├── class1/ 34 | │ │ └── ... 35 | │ └── class2/ 36 | │ └── ... 37 | └── query/ # Query images 38 | ├── class1/ 39 | │ └── ... 40 | └── class2/ 41 | └── ... 42 | ``` 43 | 44 | Note: 45 | - IDs in query set should be a subset of gallery set 46 | - Gallery set can contain additional ID categories 47 | - Each ID folder contains different images of the same identity 48 | 49 | #### Data Preparation Tools 50 | 51 | We provide a convenient tool for dataset construction: 52 | 53 | ```bash 54 | python tools/build_querygallery.py --src --frac 55 | ``` 56 | 57 | This tool will transform your original data: 58 | ``` 59 | data/ 60 | └── ID1/ 61 | ├── xxx1.jpg 62 | └── xxx2.jpg 63 | └── ID2/ 64 | ├── xxx3.jpg 65 | └── xxx4.jpg 66 | ``` 67 | 68 | Into retrieval format: 69 | ``` 70 | data/ 71 | └── data-query/ 72 | └── ID1/ 73 | └── xxx1.jpg 74 | └── ID2/ 75 | └── xxx3.jpg 76 | └── data-gallery/ 77 | └── ID1/ 78 | └── xxx2.jpg 79 | └── ID2/ 80 | └── xxx4.jpg 81 | ``` 82 | 83 | ## 🧊 Models 84 | 85 | ### Model Configuration 86 | 87 | The model configuration includes backbone and head components: 88 | 89 | ```yaml 90 | model: 91 | task: cbir 92 | image_size: &imgsz 224 93 | load_from: null 94 | backbone: 95 | timm-resnet50d.gluon_in1k: # Multiple backbones supported 96 | pretrained: True 97 | image_size: *imgsz 98 | feat_dim: &featd 128 99 | head: 100 | arcface: # Support multiple loss functions 101 | feat_dim: *featd 102 | num_class: 5000 103 | margin_arc: 0.35 104 | margin_am: 0.0 105 | scale: 32 106 | ``` 107 | 108 | ### Supported Loss Functions 109 | 110 | Supported heads: [ArcFace](https://arxiv.org/abs/1801.07698), [MagFace](https://arxiv.org/abs/2103.06627), [CircleLoss](https://arxiv.org/abs/2002.10857), [MV-Softmax](https://arxiv.org/abs/1912.00833) 111 | 112 | Configure the desired loss function in the `head` section of your config file. 113 | 114 | ### Available Models 115 | ```python 116 | import timm 117 | timm.list_models(pretrained=True) # ['beit_base_patch16_224.in22k_ft_in22k', 'swin_base_patch4_window7_224.ms_in22k_ft_in1k', 'vit_base_patch16_siglip_224.webli', ...] 
118 | ``` 119 | 120 | ## 🚀 Training 121 | 122 | ### Basic Training 123 | ```bash 124 | # Single GPU training 125 | python -m scripts.train configs/representation/image-retrieval.yaml 126 | 127 | # Multi-GPU training 128 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 -m scripts.train configs/representation/image-retrieval.yaml 129 | ``` 130 | 131 | **Options:** 132 | - `--print_freq 50`: Print log every 50 steps 133 | - `--save_freq 5`: Save checkpoint and validate every 5 epochs 134 | 135 | ## 📊 Evaluation 136 | 137 | ### Metrics 138 | We support multiple evaluation metrics: 139 | - MRR (Mean Reciprocal Rank) 140 | - Recall@K 141 | - Precision@K 142 | - AUC 143 | - NDCG@K 144 | 145 | Configure evaluation parameters in your config file: 146 | ```yaml 147 | metrics: 148 | metrics: [mrr, recall, precision, auc, ndcg] 149 | cutoffs: [1, 3, 5] # Evaluate performance at top1, top3, and top5 150 | ``` 151 | 152 | ### Run Evaluation 153 | ```bash 154 | python -m scripts.validate scripts/run/exp/Epoch_X.pt --ema 155 | ``` 156 | 157 | ## 🔍 Inference 158 | ```bash 159 | python -m scripts.infer scripts/run/exp/Epoch_X.pt 160 | ``` 161 | **Options:** 162 | - `--max_rank`: Visualize top k retrieval results (default: 10) 163 | 164 | 165 | ### 🖼️ Visualization 166 |

167 | 168 | 169 | 170 | Retrieval Results (Left: Evaluation, Right: Inference) 171 |
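The cutoffs configured above (`[1, 3, 5]`) mean each metric is reported at top-1, top-3 and top-5. As an illustration of how Recall@K and MRR can be read off a single ranked result — a simplified sketch, not the exact logic in `doraemon/engine/representation/eval_cbir.py` — consider:

```python
# Illustrative sketch only: Recall@K and MRR for one query, assuming ranked_hits[i] is 1
# when the i-th retrieved gallery image shares the query's ID (not Doraemon's actual code).
import numpy as np

def recall_at_k(ranked_hits: np.ndarray, n_relevant: int, k: int) -> float:
    # fraction of the query's relevant gallery images found in the top-k results
    return ranked_hits[:k].sum() / n_relevant if n_relevant else 0.0

def mrr(ranked_hits: np.ndarray) -> float:
    # reciprocal rank of the first correct result, 0 if nothing relevant is retrieved
    hits = np.flatnonzero(ranked_hits)
    return 1.0 / (hits[0] + 1) if hits.size else 0.0

ranked_hits = np.array([0, 1, 0, 0, 1])  # the 2nd and 5th retrieved images are correct
for k in (1, 3, 5):
    print(f"recall@{k} = {recall_at_k(ranked_hits, n_relevant=2, k=k):.2f}")
print(f"MRR = {mrr(ranked_hits):.2f}")
```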

-------------------------------------------------------------------------------- /doraemon/models/classifier/README.md: -------------------------------------------------------------------------------- 1 | #
Image Classification
2 | 3 | ## 📦 Data Preparation 4 | 5 | ### Quick Start with Pre-prepared Datasets 6 | 1. **Using HuggingFace Dataset (Recommended)** 7 | 8 | Dataset: [wuji3/oxford-iiit-pet](https://huggingface.co/datasets/wuji3/oxford-iiit-pet) 9 | ```yaml 10 | # In your config file (e.g., configs/classification/pet.yaml) 11 | data: 12 | root: wuji3/oxford-iiit-pet 13 | ``` 14 | 15 | 2. **Download Pre-prepared Dataset** 16 | - Oxford-IIIT Pet Dataset (37 pet breeds) 17 | - [Download from Baidu Cloud](https://pan.baidu.com/s/1PjM6kPoTyzNYPZkpmDoC6A) (Code: yjsl) **Recommended** 18 | - [Download from Official URL](https://s3.amazonaws.com/fast-ai-imageclas/oxford-iiit-pet.tgz) 19 | ```bash 20 | # After downloading: 21 | cd data 22 | tar -xf oxford-iiit-pet.tgz 23 | python split2dataset.py 24 | ``` 25 | 26 | ### Training with Your Own Dataset 27 | You can prepare your data in either single-label or multi-label format: 28 | 29 | #### Option 1: Single-label Format 30 | ``` 31 | your_dataset/ 32 | ├── train/ 33 | │ ├── class1/ # Folder name = class name 34 | │ │ ├── image1.jpg 35 | │ │ └── ... 36 | │ └── class2/ 37 | │ └── ... 38 | └── val/ 39 | ├── class1/ 40 | │ └── ... 41 | └── class2/ 42 | └── ... 43 | ``` 44 | 45 | #### Option 2: Multi-label Format (CSV) 46 | Create a CSV file with the following structure: 47 | ```csv 48 | image_path,tag1,tag2,tag3,train 49 | /path/to/image1.jpg,1,0,1,True # 1=has_tag, 0=no_tag 50 | /path/to/image2.jpg,0,1,0,True # True=training set 51 | ``` 52 | 53 | ### Data Preparation Helper 54 | Convert a folder of categorized images into the required training format: 55 | ```bash 56 | # If your data structure is: 57 | # your_dataset/ 58 | # ├── class1/ 59 | # │ ├── img1.jpg 60 | # │ └── img2.jpg 61 | # ├── class2/ 62 | # │ ├── img3.jpg 63 | # │ └── img4.jpg 64 | # └── ... 65 | 66 | python tools/data_prepare.py \ 67 | --root path/to/your/images \ 68 | --postfix jpg \ # Image format: jpg or png 69 | --frac 0.8 # Split ratio: 80% training, 20% validation 70 | ``` 71 | 72 | This script will automatically: 73 | 1. Create train/ and val/ directories 74 | 2. Split images from each class into training and validation sets 75 | 3. Maintain the class folder structure in both sets 76 | 77 | ## 🧊 Models 78 | 79 | ### Model Configuration 80 | ```yaml 81 | model: 82 | task: classification 83 | name: timm-swin_base_patch4_window7_224 # Format: timm-{model_name} 84 | image_size: 224 85 | num_classes: 35 86 | pretrained: True 87 | kwargs: {} # Additional parameters for model initialization 88 | ``` 89 | 90 | ### Available Models 91 | ```python 92 | import timm 93 | timm.list_models(pretrained=True) # ['beit_base_patch16_224.in22k_ft_in22k', 'swin_base_patch4_window7_224.ms_in22k_ft_in1k', 'vit_base_patch16_siglip_224.webli', ...] 
94 | ``` 95 | 96 | ## 🚀 Training 97 | 98 | ### Training Options 99 | 100 | ```bash 101 | # Single GPU training 102 | python -m scripts.train configs/recognition/pet.yaml 103 | 104 | # Multi-GPU training 105 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 -m scripts.train configs/recognition/pet.yaml [options] 106 | ``` 107 | 108 | **Options:** 109 | - `--resume path/to/model.pt`: Resume interrupted training 110 | - `--sync_bn`: Enable synchronized BatchNorm for multi-GPU 111 | 112 | ### Monitor Training 113 | ```bash 114 | # View real-time training log 115 | tail -f scripts/run/exp/log{timestamp}.log # e.g., log20241113-155144.log 116 | ``` 117 | 118 | ## 📊 Evaluation & Visualization 119 | 120 | ### Analyze Model Predictions 121 | 122 | ```bash 123 | python -m scripts.infer scripts/run/exp/best.pt --data [options] 124 | ``` 125 | 126 | **Options:** 127 | - `--infer-option {default, autolabel}`: 128 | - `default`: Infer + Visualize + CaseAnalysis 129 | - `autolabel`: Infer + Label 130 | - `--classes A B C`: Filter specific classes 131 | - `--split val`: Only analyze validation set 132 | - `--sampling 10`: Analyze a random subset of samples 133 | - `--ema`: Use Exponential Moving Average for model weight 134 | 135 | ### Validate Model Performance 136 | ```bash 137 | python -m scripts.validate scripts/run/exp/best.pt --ema 138 | ``` 139 | 140 | ## 🖼️ Example 141 | 142 |

143 | 144 | 145 | 146 | 147 | Visualization of Model Processes (Left: Training, Center: Evaluation, Right: Inference) 148 |
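Beyond `scripts.validate` and `scripts.infer`, a saved checkpoint can be probed directly for a quick sanity check. The sketch below is hypothetical and simplified: it assumes a plain timm backbone (`attention_pool: False`), uses a plain resize instead of the repo's `resize_and_padding`, and relies on the checkpoint keys used in `scripts/validate.py` (`config`, `model`/`ema`, `id2label`); the paths and image file are placeholders.

```python
# Hypothetical single-image check of a trained classifier checkpoint (not the supported
# entry point; see scripts.infer / scripts.validate and deploy/ for real usage).
import torch, timm
from PIL import Image
from torchvision import transforms as T

ckpt = torch.load("scripts/run/exp/best.pt", map_location="cpu", weights_only=False)
cfg = ckpt["config"]["model"]                                   # e.g. name: timm-..., num_classes: 35
model = timm.create_model(cfg["name"].split("-", 1)[1], num_classes=cfg["num_classes"])
model.load_state_dict(ckpt["model"])                            # or ckpt["ema"].float().state_dict()
model.eval()

tf = T.Compose([T.Resize((cfg["image_size"], cfg["image_size"])),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
x = tf(Image.open("some_pet.jpg").convert("RGB")).unsqueeze(0)  # placeholder image
with torch.no_grad():
    probs = model(x).softmax(dim=-1)[0]
top = int(probs.argmax())
print(ckpt["id2label"][top], float(probs[top]))                 # id2label key type may vary
```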

149 | -------------------------------------------------------------------------------- /doraemon/models/representation/face_model.py: -------------------------------------------------------------------------------- 1 | from .backbone.backbone_def import BackboneFactory 2 | from .head.head_def import HeadFactory 3 | import torch.nn as nn 4 | from torch.nn.init import normal_, constant_ 5 | import torch 6 | import torch.nn.functional as F 7 | import os 8 | import numpy as np 9 | 10 | class FaceTrainingWrapper: 11 | def __init__(self, model_cfg, logger = None): 12 | self.model = FaceTrainingModel(model_cfg) 13 | self.logger = logger 14 | 15 | def init_parameters(self, m: nn.Module): 16 | if isinstance(m, nn.Conv2d): 17 | normal_(m.weight, mean=0, std=0.02) 18 | elif isinstance(m, nn.BatchNorm2d): 19 | constant_(m.weight, 1) 20 | constant_(m.bias, 0) 21 | elif isinstance(m, nn.Linear): 22 | normal_(m.weight, mean=0, std=0.02) 23 | constant_(m.bias, 0) 24 | 25 | def reset_parameters(self): 26 | self.model.apply(self.init_parameters) 27 | 28 | class FaceTrainingModel(nn.Module): 29 | """Define a traditional faceX model which contains a backbone and a head. 30 | 31 | Attributes: 32 | backbone: the backbone of faceX model. 33 | head: the head of faceX model. 34 | """ 35 | 36 | def __init__(self, model_cfg): 37 | """Init faceX model by backbone factorcy and head factory. 38 | 39 | Args: 40 | backbone_factory: produce a backbone according to config files. 41 | head_factory: produce a head according to config files. 42 | """ 43 | super().__init__() 44 | backbone = BackboneFactory(model_cfg['backbone']).get_backbone() 45 | head = HeadFactory(model_cfg['head']).get_head() 46 | self.trainingwrapper = nn.ModuleDict({ 47 | 'backbone': backbone, 48 | 'head': head 49 | }) 50 | 51 | def forward(self, data, label): 52 | feat = self.trainingwrapper['backbone'](data) 53 | pred = self.trainingwrapper['head'](feat, label) 54 | return pred 55 | 56 | class FaceModelLoader: 57 | def __init__(self, model_cfg: dict): 58 | self.model = BackboneFactory(model_cfg['backbone']).get_backbone() 59 | 60 | def load_weight_default(self, model_path): 61 | """The default method to load a model. 62 | 63 | Args: 64 | model_path:: the path of the weight file. 65 | 66 | Returns: 67 | model: initialized model. 68 | """ 69 | self.model.load_state_dict(torch.load(model_path, weights_only=False)['state_dict'], strict=True) 70 | 71 | return self.model 72 | 73 | def load_weight(self, model_path, ema: bool = False): 74 | """The custom method to load a model, from a model having feature extractor and head. 75 | 76 | Args: 77 | model_path: the path of the weight file. 78 | 79 | Returns: 80 | model: initialized model. 81 | """ 82 | pretrained_dict = torch.load(model_path, weights_only=False)['ema'] if ema else torch.load(model_path, weights_only=False)['state_dict'] 83 | 84 | self.model.load_state_dict(pretrained_dict, strict=True) 85 | 86 | return self.model 87 | 88 | class FeatureExtractor: 89 | 90 | def __init__(self, model): 91 | self.model = model 92 | 93 | def extract_face(self, dataloader, device) -> dict: 94 | """Extract and return features. 95 | 96 | Args: 97 | model: initialized model. 98 | dataloader: load data to be extracted. 99 | 100 | Returns: 101 | image_name2feature: key is the name of image, value is feature of image. 
102 | """ 103 | model = self.model 104 | model.eval() 105 | model.to(device) 106 | 107 | image_name2feature = {} 108 | with torch.no_grad(): 109 | for batch_idx, (_, tensors, file_realpaths) in enumerate(dataloader): 110 | tensors = tensors.to(device) 111 | features = model(tensors) 112 | features = F.normalize(features) 113 | features = features.cpu().numpy() 114 | for realpath, feature in zip(file_realpaths, features): 115 | filename = os.path.join(os.path.basename(os.path.dirname(realpath)), os.path.basename(realpath)) 116 | image_name2feature[filename] = feature 117 | 118 | return image_name2feature 119 | 120 | def extract_cbir(self, dataloader, device) -> np.ndarray: 121 | """Extract and return features. 122 | 123 | Args: 124 | model: initialized model. 125 | dataloader: load data to be extracted. 126 | 127 | Returns: 128 | features: feature of image. 129 | """ 130 | model = self.model 131 | model.eval() 132 | model.to(device) 133 | 134 | features = [] 135 | with torch.no_grad(): 136 | for batch_idx, tensors in enumerate(dataloader): 137 | tensors = tensors.to(device) 138 | feature = model(tensors) 139 | feature = F.normalize(feature) 140 | feature = feature.cpu().numpy() 141 | 142 | features.append(feature) 143 | 144 | return np.concatenate(features, axis=0) -------------------------------------------------------------------------------- /doraemon/engine/optimizer.py: -------------------------------------------------------------------------------- 1 | from torch.optim import SGD, Adam 2 | from typing import Callable, List, Dict, Optional, Iterator, Union 3 | from functools import wraps 4 | from abc import ABCMeta, abstractmethod 5 | from torch.nn import Module 6 | import torch 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | __all__ = ['sgd', 10 | 'adam', 11 | 'sam', 12 | 'BaseSeperateLayer', 13 | 'create_Optimizer', 14 | 'list_optimizers'] 15 | 16 | OPTIMIZER = {} 17 | 18 | def register_optimizer(fn: Callable): 19 | key = fn.__name__ 20 | if key in OPTIMIZER: 21 | raise ValueError(f"An entry is already registered under the name '{key}'.") 22 | OPTIMIZER[key] = fn 23 | @wraps(fn) 24 | def wrapper(*args, **kwargs): 25 | return fn(*args, **kwargs) 26 | 27 | return wrapper 28 | 29 | class SAM(torch.optim.Optimizer): 30 | """ 31 | https://arxiv.org/abs/2010.01412 Sharpness-Aware Minimization for Efficiently Improving Generalization 32 | """ 33 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=True, lr: float = 0.01, momentum: float = 0.9, weight_decay: float = 5e-4, **kwargs): 34 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" 35 | 36 | defaults = dict(rho=rho, adaptive=adaptive, lr=lr, momentum=momentum, weight_decay=weight_decay, **kwargs) 37 | super(SAM, self).__init__(params, defaults) 38 | 39 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 40 | self.param_groups = self.base_optimizer.param_groups 41 | self.defaults.update(self.base_optimizer.defaults) 42 | 43 | @torch.no_grad() 44 | def first_step(self, zero_grad=False): 45 | grad_norm = self._grad_norm() 46 | for group in self.param_groups: 47 | scale = group["rho"] / (grad_norm + 1e-12) 48 | 49 | for p in group["params"]: 50 | if p.grad is None: continue 51 | self.state[p]["old_p"] = p.data.clone() 52 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) 53 | p.add_(e_w) # climb to the local maximum "w + e(w)" 54 | 55 | if zero_grad: self.zero_grad() 56 | 57 | @torch.no_grad() 58 | def second_step(self, zero_grad=False): 59 | for group in 
self.param_groups: 60 | for p in group["params"]: 61 | if p.grad is None: continue 62 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" 63 | 64 | self.base_optimizer.step() # do the actual "sharpness-aware" update 65 | 66 | if zero_grad: self.zero_grad() 67 | 68 | @torch.no_grad() 69 | def step(self, closure=None): 70 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" 71 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass 72 | 73 | self.first_step(zero_grad=True) 74 | closure() 75 | self.second_step() 76 | 77 | def _grad_norm(self): 78 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism 79 | norm = torch.norm( 80 | torch.stack([ 81 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) 82 | for group in self.param_groups for p in group["params"] 83 | if p.grad is not None 84 | ]), 85 | p=2 86 | ) 87 | return norm 88 | 89 | def load_state_dict(self, state_dict): 90 | super().load_state_dict(state_dict) 91 | self.base_optimizer.param_groups = self.param_groups 92 | 93 | def disable_running_stats(self, model): 94 | def _disable(module): 95 | if isinstance(module, _BatchNorm): 96 | module.backup_momentum = module.momentum 97 | module.momentum = 0 98 | 99 | model.apply(_disable) 100 | 101 | def enable_running_stats(self, model): 102 | def _enable(module): 103 | if isinstance(module, _BatchNorm) and hasattr(module, "backup_momentum"): 104 | module.momentum = module.backup_momentum 105 | 106 | model.apply(_enable) 107 | 108 | class BaseSeperateLayer(metaclass=ABCMeta): 109 | """ 110 | 用于对model的多个层分别设置具体的学习率 111 | """ 112 | def __init__(self, model: Module) -> None: 113 | self.model = model 114 | 115 | @abstractmethod 116 | def create_ParamSequence(self, alpha: Optional[float], lr: float) -> Union[Iterator, List[Dict]]: 117 | pass 118 | 119 | @register_optimizer 120 | def sgd(*args, **kwargs): 121 | return SGD(*args, **kwargs) 122 | 123 | @register_optimizer 124 | def adam(*args, **kwargs): 125 | return Adam(*args, **kwargs) 126 | 127 | @register_optimizer 128 | def sam(base_optimizer = SGD, *args, **kwargs): 129 | return SAM(base_optimizer=base_optimizer, *args, **kwargs) 130 | 131 | def create_Optimizer(optimizer: str, lr: float, weight_decay, momentum, params): 132 | # return partial(OPTIMIZER[optimizer], lr = lr, weight_decay = weight_decay, momentum = momentum) 133 | return OPTIMIZER[optimizer](params = params, lr = lr, weight_decay = weight_decay, momentum = momentum) 134 | 135 | def list_optimizers(): 136 | optimizers = [k for k, v in OPTIMIZER.items()] 137 | return sorted(optimizers) -------------------------------------------------------------------------------- /configs/recognition/pet.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | task: classification # Task type: classification 3 | #----------------------------Transformer-----------------------------------------# 4 | # timm-swin_base_patch4_window7_224.ms_in22k_ft_in1k: # imgsz 224 5 | # timm-vit_base_patch8_224.dino: # imgsz 224 6 | # timm-vit_base_patch16_224.augreg2_in21k_ft_in1k # imgsz 224 7 | # timm-vit_large_patch16_224.mae: # imgsz 224 8 | # timm-vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k: # imgsz 224 9 | # timm-swinv2_base_window8_256.ms_in1k: # imgsz 256 10 | # timm-swinv2_large_window12to16_192to256.ms_in22k_ft_in1k: # imgsz 256 11 | # 
timm-vit_base_patch16_clip_224.laion2b_ft_in1k: # imgsz 224 12 | # timm-vit_large_patch14_dinov2.lvd142m: # imgsz 518 13 | # timm-vit_so400m_patch14_siglip_224.webli: # imgsz 224 14 | #--------------------------------CNN---------------------------------------------# 15 | # timm-wide_resnet101_2.tv2_in1k: # imgsz 224 16 | # timm-resnet50d.gluon_in1k: # imgsz 224 17 | # timm-resnext50_32x4d.a3_in1k: # imgsz 224 18 | # timm-resnest50d_4s2x40d.in1k: # imgsz 224 19 | # timm-legacy_seresnet50.in1k: # imgsz 224 20 | # timm-tf_mobilenetv3_large_minimal_100.in1k: # imgsz 224 21 | # timm-convnext_base.clip_laion2b_augreg_ft_in1k: # imgsz 224 22 | # timm-convnext_base.clip_laiona_augreg_ft_in1k_384: # imgsz 384 23 | # timm-convnext_large.fb_in22k_ft_in1k: # imgsz 224 24 | # timm-tf_efficientnetv2_l.in21k_ft_in1k: # imgsz 224 25 | load_from: null 26 | name: timm-swin_base_patch4_window7_224.ms_in22k_ft_in1k # Model used please refer to timm.models 27 | image_size: &resize_size 224 # Input image size, using anchor for later reference 28 | kwargs: {} # Additional parameters passed to model initialization 29 | num_classes: 35 # Number of classification categories 30 | pretrained: True # Whether to use pre-trained weights 31 | backbone_freeze: False # Whether to freeze the backbone network 32 | bn_freeze: False # Whether to freeze Batch Normalization layers 33 | bn_freeze_affine: False # Whether to freeze affine transformation parameters of BN layers 34 | attention_pool: False # Whether to use attention pooling 35 | data: 36 | # Choose ONE of the following data sources: 37 | 38 | # 1. HuggingFace Dataset (Single-label) 39 | root: wuji3/oxford-iiit-pet 40 | 41 | # 2. Local Dataset (Single-label) 42 | # root: # Format: path/to/dataset with train/ and val/ subdirs 43 | 44 | # 3. 
CSV Dataset (Multi-label) 45 | # root: .csv # Format: CSV file with image_path and label columns 46 | 47 | nw: 16 # if not multi-nw, set to 0 48 | train: 49 | bs: 80 # Batch size per GPU 50 | base_aug: null # Base data augmentation 51 | class_aug: null # Class-specific data augmentation 52 | # Abyssinian: [1,2,3,4] 53 | # Birman: [1,2,3,4] 54 | augment: # Data augmentation strategy, refer to utils/augment.py 55 | - random_choice: 56 | transforms: 57 | - random_color_jitter: 58 | brightness: 0.1 59 | contrast: 0.1 60 | saturation: 0.1 61 | hue: 0.1 62 | - random_cutout: 63 | n_holes: 3 64 | length: 12 65 | prob: 0.1 66 | color: [0, 255] 67 | - random_gaussianblur: 68 | kernel_size: 5 69 | - random_rotate: 70 | degrees: 10 71 | - random_autocontrast: 72 | p: 0.5 73 | - random_adjustsharpness: 74 | p: 0.5 75 | - random_augmix: 76 | severity: 3 77 | - random_horizonflip: 78 | p: 0.5 79 | - random_choice: 80 | transforms: 81 | - resize_and_padding: 82 | size: *resize_size 83 | training: True 84 | - random_crop_and_resize: 85 | size: *resize_size 86 | scale: [0.7, 1] 87 | p: [0.9, 0.1] 88 | - to_tensor: no_params 89 | - normalize: 90 | mean: [0.485, 0.456, 0.406] 91 | std: [0.229, 0.224, 0.225] 92 | aug_epoch: 14 # Number of epochs to apply data augmentation 93 | val: 94 | bs: 320 95 | augment: 96 | - resize_and_padding: 97 | size: *resize_size 98 | training: False 99 | - to_tensor: no_params 100 | - normalize: 101 | mean: [0.485, 0.456, 0.406] 102 | std: [0.229, 0.224, 0.225] 103 | hyp: 104 | epochs: 15 # Total number of training epochs 105 | lr0: 0.006 # Initial learning rate 106 | lrf_ratio: null # Learning rate decay ratio, null means using default value 0.1 107 | momentum: 0.937 # Optimizer momentum 108 | weight_decay: 0.0005 # Weight decay 109 | warmup_momentum: 0.8 # Momentum during warmup phase 110 | warm_ep: 1 # Number of warmup epochs 111 | loss: 112 | ce: True # Whether to use cross-entropy loss 113 | bce: # Binary cross-entropy, config: [use or not, weight, multi-label or not] 114 | - False 115 | - [0.5, 0.5, 0.5, 0.5, 0.5] # tag1, tag2, tag3, tag4, tag5 116 | label_smooth: 0.05 # Label smoothing coefficient 117 | strategy: 118 | prog_learn: False # Whether to use progressive learning 119 | mixup: # Mixup data augmentation configuration 120 | ratio: 0.0 121 | duration: 10 122 | focal: # Only With BCE, Focal Loss, config: [use or not, alpha, gamma] 123 | - False 124 | - 0.25 125 | - 1.5 126 | ohem: # Only With CE, Online Hard Example Mining, config: [use or not, min_kept, thresh_prob, ignore_index] 127 | - False 128 | - 8 129 | - 0.7 130 | - 255 131 | optimizer: 132 | - sgd # Optimizer type: SGD, Adam, or SAM 133 | - False # Whether to set different learning rates for different layers of the model 134 | scheduler: cosine_with_warm # Learning rate scheduler: cosine annealing with warmup -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | #
DORAEMON: Deep Object Recognition And Embedding Model Of Networks
2 | 3 |

4 | 5 |

6 | 7 |

8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |

16 | 17 | ## 🚀 Quick Start 18 | 19 | Installation Guide 20 | 21 | ```bash 22 | # Create and activate environment 23 | python -m venv doraemon 24 | source doraemon/bin/activate 25 | 26 | # Install Doraemon 27 | pip install doraemon-torch 28 | 29 | # If you need to install in editable mode (for development) 30 | pip install -e . 31 | ``` 32 | 33 | ## 📢 What's New 34 | 35 | - 🎁 2025.11.07: The [Doraemon paper](https://arxiv.org/abs/2511.04394) is released; please cite it if you find the project useful for your research or development. 36 | - 🎁 2025.03.16: Doraemon v0.1.0 released 37 | - 🎁 2024.10.01: Content-Based Image Retrieval (CBIR): We collected a product dataset from Kaggle & TianChi and provide a complete pipeline for training, end-to-end validation, and visualization. Please check [ImageRetrieval.md](doraemon/models/representation/README_CBIR.md) 38 | - 🎁 2024.04.01: Face Recognition: Based on a cleaned MS-Celeb-1M-v1c with over 70,000 IDs and 3.6 million images, validated with LFW. Includes loss functions such as ArcFace, CircleLoss, and MagFace. 39 | - 🎁 2023.06.01: Image Classification (IC): Demonstrated on the Oxford-IIIT Pet dataset. Supports different learning rates for different layers, hard example mining, multi-label and single-label training, bad case analysis, GradCAM visualization, automatic labeling to aid semi-supervised training, and category-specific data augmentation. Refer to [ImageClassification.md](doraemon/models/classifier/README.md) 40 | 41 | ## ✨ Highlights 42 | - [Optimization Algorithms](doraemon/engine/optimizer.py): Various optimization techniques to enhance model training efficiency, including SGD, Adam, and SAM (Sharpness-Aware Minimization). 43 | 44 | - [Data Augmentation](doraemon/dataset/transforms.py): A variety of data augmentation techniques to improve model robustness, such as CutOut, Color-Jitter, and Copy-Paste. 45 | 46 | - [Regularization](doraemon/engine/optimizer.py): Techniques to prevent overfitting and improve model generalization, including Label Smoothing, OHEM, Focal Loss, and Mixup. 47 | 48 | - [Visualization](doraemon/utils/cam.py): Integrated visualization tool to understand model decision-making, featuring GradCAM. 49 | 50 | - [Personalized Data Augmentation](doraemon/built/class_augmenter.py): Apply exclusive data augmentation to specific classes with Class-Specific Augmentation. 51 | 52 | - [Personalized Hyperparameter Tuning](doraemon/built/layer_optimizer.py): Apply different learning rates to specific layers using Layer-Specific Learning Rates. 53 | 54 | ## 🚀 Deployment API 55 | 56 | Doraemon offers incredibly simple yet powerful deployment options: 57 | 58 | - **Local API Inference**: Deploy models with just a single weight file (*.pt) - one-command setup for high-performance local inference 59 | - **Seamless HuggingFace Integration**: Effortlessly deploy to the Hugging Face ecosystem with full support for: 60 | - `AutoModel.from_pretrained()` 61 | - `AutoProcessor.from_pretrained()` 62 | - And all standard Hugging Face API interfaces 63 | 64 | For detailed deployment instructions and ready-to-use examples, see our [Deployment Guide](deploy/README.md). 65 | 66 | ## 📚 Tutorials 67 | 68 | For detailed guidance on specific tasks, please refer to the following resources: 69 | 70 | - **Image Classification**: If you are working on image classification tasks, please refer to [Doc: Image Classification](doraemon/models/classifier/README.md). 
71 | 72 | - **Image Retrieval**: For image retrieval tasks, please refer to [Doc: Image Retrieval](doraemon/models/representation/README_CBIR.md). 73 | 74 | - **Face Recognition**: Stay tuned. 75 | 76 | ## 📊 Datasets 77 | 78 | Doraemon integrates the following datasets, allowing users to quickly start training: 79 | 80 | - **Image Retrieval**: Available at [Ecommerce Product](https://huggingface.co/datasets/wuji3/image-retrieval) 81 | - **Face Recognition**: Available at [MS-Celeb-1M-v1c](https://huggingface.co/datasets/wuji3/face-recognition) 82 | - **Image Classification**: Available at [Oxford-IIIT Pet](https://huggingface.co/datasets/wuji3/oxford-iiit-pet) 83 | 84 | ## 🧩 Supported Models 85 | 86 | **Doraemon** now supports 1000+ models through integration with Timm: 87 | 88 | - All models from `timm.list_models(pretrained=True)` 89 | - Including CLIP, SigLIP, DeiT, BEiT, MAE, EVA, DINO and more 90 | 91 | [Model Performance Benchmarks](https://github.com/huggingface/pytorch-image-models/tree/main/results) can help you select the most suitable model by comparing: 92 | - Inference speed 93 | - Training efficiency 94 | - Accuracy across different datasets 95 | - Parameter count vs performance trade-offs 96 | 97 | > For detailed benchmark results, see [@huggingface/pytorch-image-models#1933](https://github.com/huggingface/pytorch-image-models/issues/1933) 98 | 99 | ## Citation 100 | 101 | If you find **Doraemon** useful for your research or development, please cite the following paper: 102 | 103 | ``` 104 | @misc{du2025visual, 105 | title={DORAEMON: A Unified Library for Visual Object Modeling and Representation Learning at Scale}, 106 | author={Ke Du and Yimin Peng and Chao Gao and Fan Zhou and Siqiao Xue}, 107 | year={2025}, 108 | journal={arXiv preprint arXiv:2511.04394}, 109 | url={https://arxiv.org/abs/2511.04394}, 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /scripts/infer.py: -------------------------------------------------------------------------------- 1 | from doraemon import (increment_path, 2 | create_AugTransforms, 3 | colorstr, 4 | PredictImageDatasets, 5 | Visualizer, 6 | get_model, 7 | SmartLogger, 8 | FaceModelLoader, 9 | valuate_cbir) 10 | 11 | from torch.utils.data import DataLoader 12 | import os 13 | import argparse 14 | from pathlib import Path 15 | import torch 16 | import time 17 | from tqdm import tqdm 18 | 19 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) 20 | ROOT = Path(os.path.dirname(__file__)) 21 | 22 | def parse_opt(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('model-path', default='run/exp8', help='Path to model configs') 25 | 26 | # classification 27 | parser.add_argument('--data', default = ROOT / 'data', help='Target data directory') 28 | parser.add_argument('--infer-option', choices=['default', 'autolabel'], default='default', 29 | help='default: Infer + Visualize + CaseAnalysis, autolabel: Infer + Label') 30 | parser.add_argument('--split', default=None, type=str, help='Split to visualize') 31 | parser.add_argument('--classes', default=None, nargs='+', help='Which class to check') 32 | parser.add_argument('--ema', action='store_true', help = 'Exponential Moving Average for model weight') 33 | parser.add_argument('--sampling', default=None, type=int, help='Sample n images for visualization') 34 | 35 | # CBIR 36 | parser.add_argument('--max_rank', default=10, type=int, help='Visualize top k retrieval results') 37 | parser.add_argument('--root', default = None, help = 
'Prediction root path for cbir dataset (If need change from cfgs)') 38 | 39 | # Unless specific needs, it is generally not modified below. 40 | parser.add_argument('--show_path', default = ROOT / 'inference') 41 | parser.add_argument('--name', default = 'exp') 42 | parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify') 43 | 44 | return parser.parse_args() 45 | 46 | if __name__ == '__main__': 47 | if LOCAL_RANK != -1: 48 | device = torch.device('cuda', LOCAL_RANK) 49 | else: 50 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 51 | 52 | opt = parse_opt() 53 | visual_dir = increment_path(Path(opt.show_path) / opt.name) 54 | 55 | pt_file = Path(getattr(opt, 'model-path')) 56 | pt = torch.load(pt_file, weights_only=False) 57 | config = pt['config'] 58 | task: str = config['model']['task'] 59 | 60 | logger = SmartLogger(filename=None, level=1) if LOCAL_RANK in {-1,0} else None 61 | 62 | if task == 'classification': 63 | modelwrapper = get_model(config['model'], None, LOCAL_RANK) 64 | model = modelwrapper.load_weight(pt, ema=opt.ema, device=device) 65 | 66 | if opt.classes is None: 67 | opt.classes = list(pt['label2id']) 68 | 69 | dataset = PredictImageDatasets(opt.data, 70 | transforms=create_AugTransforms(config['data']['val']['augment']), 71 | sampling=opt.sampling, 72 | classes=opt.classes, 73 | split=opt.split, 74 | require_gt= opt.infer_option == 'default') 75 | dataloader = DataLoader(dataset, 76 | shuffle=False, 77 | pin_memory=True, 78 | num_workers=config['data']['nw'], 79 | batch_size=1, 80 | collate_fn=PredictImageDatasets.collate_fn) 81 | 82 | t0 = time.time() 83 | Visualizer.predict_images(model, 84 | dataloader, 85 | device, 86 | visual_dir, 87 | pt['id2label'], 88 | logger, 89 | config['hyp']['loss']['bce'][1] if config['hyp']['loss']['bce'][0] else 0, 90 | opt.infer_option 91 | ) 92 | 93 | logger.console(f'\nPredicting complete ({(time.time() - t0) / 60:.3f} minutes)' 94 | f"\nResults saved to {colorstr('bold', visual_dir)}") 95 | elif task in ('face', 'cbir'): 96 | # logger 97 | logger = SmartLogger(filename=None) 98 | 99 | # checkpoint loading 100 | logger.console(f'Loading Model, EMA Is {opt.ema}') 101 | model_loader = FaceModelLoader(model_cfg=config['model']) 102 | model = model_loader.load_weight(model_path=pt_file, ema=opt.ema) 103 | 104 | if opt.root is not None: 105 | config['data']['root'] = opt.root 106 | 107 | config['data']['val']['metrics']['cutoffs'] = [opt.max_rank] 108 | metrics, retrieval_results, scores, ground_truths, queries, query_dataset, gallery_dataset = valuate_cbir(model, 109 | config['data'], 110 | device, 111 | logger, 112 | vis=True) 113 | 114 | for idx, q in tqdm(enumerate(queries), total=len(queries), desc='Visualizing', position=0): 115 | Visualizer.visualize_results(q, 116 | retrieval_results[idx], 117 | scores[idx], 118 | ground_truths[idx], 119 | visual_dir, 120 | opt.max_rank, 121 | query_dataset, 122 | gallery_dataset 123 | ) 124 | 125 | logger.console(f'Metrics: {metrics}') 126 | 127 | else: 128 | raise ValueError(f'Unknown task {task}') -------------------------------------------------------------------------------- /doraemon/engine/representation/eval_face.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from ...models.representation.face_model import FeatureExtractor 4 | from ...dataset.basedataset import PredictImageDatasets 5 | from ...dataset.transforms import 
create_AugTransforms 6 | from torch.utils.data import DataLoader 7 | 8 | def process_pairtxt(pair_txt: str, imgdir: str): 9 | assert os.path.isfile(pair_txt), f'please check the path of {pair_txt}' 10 | 11 | pair_array = np.loadtxt(pair_txt, dtype=str) 12 | unique_face_images = np.unique(pair_array[:,:2].flatten()).tolist() 13 | unique_face_realpath = [os.path.join(imgdir,'val', path) for path in unique_face_images] 14 | 15 | pair_list = pair_array.tolist() 16 | 17 | return unique_face_realpath, pair_list 18 | class Evaluator: 19 | 20 | def __init__(self, feature_extractor): 21 | """ 22 | Args: 23 | feature_extractor: a feature extractor. 24 | """ 25 | self.feature_extractor = feature_extractor 26 | 27 | def test(self, pair_list, feature_dataloader, device): 28 | # check pair_list 29 | Evaluator.check_nps(pair_list) 30 | image_name2feature = self.feature_extractor.extract_face(feature_dataloader, device) 31 | mean, std = self.test_one_model(pair_list, image_name2feature) 32 | return mean, std 33 | 34 | def test_one_model(self, test_pair_list, image_name2feature, is_normalize=True): 35 | """Get the accuracy of a model. 36 | 37 | Args: 38 | test_pair_list: the pair list given by PairsParser. 39 | image_name2feature: the map of image name and its feature. 40 | is_normalize: whether the feature is normalized. 41 | 42 | Returns: 43 | mean: estimated mean accuracy. 44 | std: standard error of the mean. 45 | """ 46 | nps = len(test_pair_list) 47 | nps_one_group = nps // 10 48 | subsets_score_list = np.zeros((10, nps_one_group), dtype=np.float32) 49 | subsets_label_list = np.zeros((10, nps_one_group), dtype=np.int8) 50 | for index, cur_pair in enumerate(test_pair_list): 51 | cur_subset = index // 600 52 | cur_id = index % 600 53 | image_name1 = os.path.normpath(cur_pair[0]) 54 | image_name2 = os.path.normpath(cur_pair[1]) 55 | label = cur_pair[2] if type(cur_pair[2]) is int else int(cur_pair[2]) 56 | subsets_label_list[cur_subset][cur_id] = label 57 | feat1 = image_name2feature[image_name1] 58 | feat2 = image_name2feature[image_name2] 59 | if not is_normalize: 60 | feat1 = feat1 / np.linalg.norm(feat1) 61 | feat2 = feat2 / np.linalg.norm(feat2) 62 | cur_score = np.dot(feat1, feat2) 63 | subsets_score_list[cur_subset][cur_id] = cur_score 64 | 65 | subset_train = np.array([True] * 10) 66 | accu_list = [] 67 | for subset_idx in range(10): 68 | test_score_list = subsets_score_list[subset_idx] 69 | test_label_list = subsets_label_list[subset_idx] 70 | subset_train[subset_idx] = False 71 | train_score_list = subsets_score_list[subset_train].flatten() 72 | train_label_list = subsets_label_list[subset_train].flatten() 73 | subset_train[subset_idx] = True 74 | best_thres = self.getThreshold(train_score_list, train_label_list) 75 | positive_score_list = test_score_list[test_label_list == 1] 76 | negative_score_list = test_score_list[test_label_list == 0] 77 | true_pos_pairs = np.sum(positive_score_list > best_thres) 78 | true_neg_pairs = np.sum(negative_score_list < best_thres) 79 | accu_list.append((true_pos_pairs + true_neg_pairs) / 600) 80 | mean = np.mean(accu_list) 81 | std = np.std(accu_list, ddof=1) / np.sqrt(10) # ddof=1, divide by 9. 82 | return mean, std 83 | 84 | def getThreshold(self, score_list, label_list, num_thresholds=1000): 85 | """Get the best threshold by train_score_list and train_label_list. 86 | Args: 87 | score_list(ndarray): the score list of all pairs. 88 | label_list(ndarray): the label list of all pairs. 89 | num_thresholds(int): the number of thresholds used to compute the ROC. 
90 | Returns: 91 | best_thres(float): the best threshold that computed by train set. 92 | """ 93 | pos_score_list = score_list[label_list == 1] 94 | neg_score_list = score_list[label_list == 0] 95 | pos_pair_nums = pos_score_list.size 96 | neg_pair_nums = neg_score_list.size 97 | score_max = np.max(score_list) 98 | score_min = np.min(score_list) 99 | score_span = score_max - score_min 100 | step = score_span / num_thresholds 101 | threshold_list = score_min + step * np.array(range(1, num_thresholds + 1)) 102 | fpr_list = [] 103 | tpr_list = [] 104 | for threshold in threshold_list: 105 | fpr = np.sum(neg_score_list > threshold) / neg_pair_nums # FP / [(FP + TN): all negative] 106 | tpr = np.sum(pos_score_list > threshold) / pos_pair_nums # TP / [(TP + FN): all positive] 107 | fpr_list.append(fpr) 108 | tpr_list.append(tpr) 109 | fpr = np.array(fpr_list) 110 | tpr = np.array(tpr_list) 111 | best_index = np.argmax(tpr - fpr) # top-left in ROC-Curve 112 | best_thres = threshold_list[best_index] 113 | return best_thres 114 | 115 | @staticmethod 116 | def check_nps(pair_list): 117 | """check the number of pairs is a multiple of 10""" 118 | assert len(pair_list) % 10 == 0, 'make sure the number of rows is a multiple of 10 in pair.txt' 119 | 120 | def valuate(model, 121 | data_cfg, 122 | device): 123 | 124 | # feature extractor 125 | feature_extractor = FeatureExtractor(model) 126 | 127 | # process pairtxt 128 | test_images_path, pair_list = process_pairtxt(data_cfg['val']['pair_txt'], data_cfg['root']) 129 | 130 | # dataloader 131 | feature_dataset = PredictImageDatasets(transforms=create_AugTransforms(data_cfg['val']['augment'])) 132 | feature_dataset.imgs_path = test_images_path 133 | feature_dataloader = DataLoader(feature_dataset, shuffle=False, pin_memory=True, num_workers=data_cfg['nw'], 134 | batch_size=data_cfg['val']['bs'], 135 | collate_fn=PredictImageDatasets.collate_fn) 136 | 137 | evaluator = Evaluator(feature_extractor) 138 | mean, std = evaluator.test(pair_list, feature_dataloader, device) 139 | 140 | return mean, std 141 | -------------------------------------------------------------------------------- /doraemon/utils/cam.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Callable, Optional, List 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from timm.models import VisionTransformer, \ 7 | SwinTransformer, \ 8 | ResNet, \ 9 | MobileNetV3, \ 10 | ConvNeXt, \ 11 | SwinTransformerV2, \ 12 | SENet, \ 13 | EfficientNet 14 | 15 | from PIL.Image import Image as ImageType 16 | from torchvision.transforms import Compose 17 | from ..dataset.transforms import SPATIAL_TRANSFORMS 18 | from ..dataset.transforms import PadIfNeed, Reverse_PadIfNeed, ResizeAndPadding2Square, ReverseResizeAndPadding2Square 19 | import cv2 20 | 21 | from pytorch_grad_cam import GradCAM, \ 22 | ScoreCAM, \ 23 | GradCAMPlusPlus, \ 24 | AblationCAM, \ 25 | XGradCAM, \ 26 | EigenCAM, \ 27 | EigenGradCAM, \ 28 | LayerCAM, \ 29 | FullGrad 30 | 31 | from pytorch_grad_cam.utils.image import show_cam_on_image 32 | from pytorch_grad_cam.ablation_layer import AblationLayerVit 33 | 34 | # -----------------ClassifierOutputTarget is used to specify the targets, specifically which class the model is looking at---------------------- # 35 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 36 | # ---------------------------------------------------------------------------------------------- # 37 | 38 | class 
ClassActivationMaper: 39 | 40 | METHODS = \ 41 | {"gradcam": GradCAM, 42 | "scorecam": ScoreCAM, 43 | "gradcam++": GradCAMPlusPlus, 44 | "ablationcam": AblationCAM, 45 | "xgradcam": XGradCAM, 46 | "eigencam": EigenCAM, 47 | "eigengradcam": EigenGradCAM, 48 | "layercam": LayerCAM, 49 | "fullgrad": FullGrad} 50 | 51 | def __init__(self, model: nn.Module, method: str, device, transforms): 52 | self.model = model 53 | self.device = device 54 | 55 | target_layers, reshape_transform = self._create_target_layers_and_transform(model) 56 | 57 | # cam 58 | if method not in self.METHODS: 59 | raise Exception(f"Method {method} not implemented") 60 | if method == "ablationcam": 61 | self.cam = ClassActivationMaper.METHODS[method](model=model, 62 | target_layers=target_layers, 63 | reshape_transform=reshape_transform, 64 | use_cuda = device != torch.device('cpu'), 65 | ablation_layer=AblationLayerVit()) 66 | else: 67 | self.cam = ClassActivationMaper.METHODS[method](model=model, 68 | target_layers=target_layers, 69 | use_cuda = device != torch.device('cpu'), 70 | reshape_transform=reshape_transform) 71 | 72 | self.spatial_transforms, reversed_fun = ClassActivationMaper.pickup_spatial_transforms(transforms) 73 | 74 | if reversed_fun is not None: 75 | self.reverse_pad2square = reversed_fun 76 | 77 | def __call__(self, 78 | image, 79 | input_tensor: torch.Tensor, 80 | dsize: Tuple[int, int],# w, h 81 | targets: Optional[List[ClassifierOutputTarget]] = None) -> np.array: 82 | grayscale_cam = self.cam(input_tensor=input_tensor, 83 | targets=targets, 84 | eigen_smooth=False, 85 | aug_smooth=False) 86 | 87 | grayscale_cam = grayscale_cam[0, :] 88 | 89 | if not isinstance(image, ImageType): 90 | raise ValueError("Only images read by PIL.Image are allowed") 91 | 92 | image = self.spatial_transforms(image) 93 | image = np.array(image, dtype=np.float32) 94 | 95 | cam_image = show_cam_on_image(image / 255, grayscale_cam) 96 | 97 | if hasattr(self, 'reverse_pad2square'): 98 | if max(dsize) != max(cam_image.shape): 99 | if isinstance(self.reverse_pad2square, Reverse_PadIfNeed): 100 | cam_image = cv2.resize(cam_image, (max(dsize), max(dsize)), cv2.INTER_LINEAR) 101 | elif isinstance(self.reverse_pad2square, ReverseResizeAndPadding2Square): 102 | pass 103 | else: 104 | raise ValueError(f"{type(self.reverse_pad2square)} not support reverse function") 105 | cam_image = self.reverse_pad2square(cam_image, dsize) 106 | return cam_image 107 | 108 | def _create_target_layers_and_transform(self, model: nn.Module) -> Tuple[list, Optional[Callable]]: 109 | 110 | if isinstance(model, (SwinTransformer, SwinTransformerV2)): 111 | return [model.norm], lambda tensor: torch.permute(tensor, dims=[0, 3, 1, 2]) 112 | 113 | elif isinstance(model, VisionTransformer): 114 | # Get patch and image size information 115 | patch_size = model.patch_embed.patch_size 116 | if isinstance(patch_size, int): 117 | patch_size = (patch_size, patch_size) 118 | 119 | img_size = model.patch_embed.img_size 120 | if isinstance(img_size, int): 121 | img_size = (img_size, img_size) 122 | 123 | # Calculate feature map size 124 | feature_size = (img_size[0] // patch_size[0], 125 | img_size[1] // patch_size[1]) 126 | 127 | def reshape_transform(tensor): 128 | # Remove CLS token 129 | tensor = tensor[:, 1:, :] 130 | 131 | # Reshape to [batch_size, height, width, channels] 132 | B, _, C = tensor.shape 133 | H, W = feature_size 134 | tensor = tensor.reshape(B, H, W, C) 135 | 136 | # Convert to [batch_size, channels, height, width] 137 | tensor = tensor.permute(0, 3, 
1, 2) 138 | return tensor 139 | 140 | return [model.blocks[-1].norm1], reshape_transform 141 | 142 | elif isinstance(model, MobileNetV3): 143 | return [model.blocks[-1][0].conv], None 144 | 145 | elif isinstance(model, (SENet, ResNet)): 146 | return [model.layer4[-1].conv3], None 147 | 148 | elif isinstance(model, ConvNeXt): 149 | return [model.norm_pre], None 150 | 151 | elif isinstance(model, EfficientNet): 152 | return [model.bn2], None 153 | 154 | else: 155 | raise KeyError(f'{type(model)} not support yet') 156 | 157 | @staticmethod 158 | def pickup_spatial_transforms(transforms: Compose): 159 | sequence = [] 160 | reversed_fun = None 161 | for t in transforms.transforms: 162 | if type(t) in SPATIAL_TRANSFORMS: 163 | sequence.append(t) 164 | if type(t) is PadIfNeed: 165 | reversed_fun = Reverse_PadIfNeed(mode=t.mode) 166 | elif type(t) is ResizeAndPadding2Square: 167 | reversed_fun = ReverseResizeAndPadding2Square(size=t.size) 168 | 169 | return Compose(sequence), reversed_fun -------------------------------------------------------------------------------- /deploy/doraemon_modeling.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedModel, PretrainedConfig 2 | from transformers.processing_utils import ProcessorMixin 3 | from transformers.feature_extraction_utils import BatchFeature 4 | from typing import Optional 5 | import torch 6 | from huggingface_hub import hf_hub_download 7 | from doraemon import create_AugTransforms 8 | from timm import create_model 9 | import torch.nn.functional as F 10 | import numpy as np 11 | import os 12 | 13 | class DoraemonConfig(PretrainedConfig): 14 | model_type = "doraemon" 15 | 16 | def __init__( 17 | self, 18 | **kwargs 19 | ): 20 | 21 | super().__init__(**kwargs) 22 | model_path = kwargs.pop("model_path", "") 23 | if model_path: 24 | if not os.path.exists(model_path): 25 | # user/repo/filename:revision 26 | repo_id, filename = model_path.rsplit("/", 1) 27 | filename, revision = filename.split(":") 28 | model_path = hf_hub_download(repo_id=repo_id, 29 | filename=filename, 30 | cache_dir=None, 31 | revision=revision, 32 | force_download=False) 33 | self.model_path = model_path 34 | pt = torch.load(model_path, weights_only=False) 35 | self.label2id = pt.get("label2id", self.label2id) 36 | self.id2label = pt.get("id2label", self.id2label) 37 | self.transforms = pt["config"].get("data", {}).get("val", {}).get("augment", {}) 38 | self.task = pt["config"].get("model", {}).get("task", None) 39 | self.num_classes = pt["config"].get("model", {}).get("num_classes", None) 40 | threshold = 0 41 | if pt["config"].get('hyp', {}).get('loss', {}).get('bce', [False])[0]: 42 | threshold = pt["config"]['hyp']['loss']['bce'][1] 43 | if isinstance(threshold, (int, float)): 44 | threshold = [threshold] * self.num_classes 45 | assert len(threshold) == self.num_classes and isinstance(threshold, list), "threshold must be a list of length num_classes" 46 | self.threshold = threshold 47 | self.timm_model = pt["config"].get("model", {}).get("name").split("-")[1] 48 | 49 | @classmethod 50 | def get_config_dict(cls, pretrained_model_name_or_path, **kwargs): 51 | config_dict, kwargs = super().get_config_dict(pretrained_model_name_or_path, **kwargs) 52 | for key in list(kwargs.keys()): 53 | if key in config_dict: 54 | value = kwargs.pop(key) 55 | if value is not None: 56 | config_dict[key] = value 57 | 58 | return config_dict, kwargs 59 | 60 | class DoraemonProcessor(ProcessorMixin): 61 | attributes = [] 62 | 
config_class = DoraemonConfig 63 | 64 | def __init__(self, config: Optional[DoraemonConfig] = None, **kwargs): 65 | super().__init__() 66 | self.config = config 67 | self.transforms = create_AugTransforms(config.transforms) 68 | self.threshold = np.array(config.threshold) 69 | 70 | def __call__( 71 | self, 72 | image, 73 | return_tensors: Optional[str] = "pt", 74 | **kwargs 75 | ) -> BatchFeature: 76 | 77 | image_tensors = self.preprocess(image) 78 | 79 | return BatchFeature( 80 | data={ 81 | "pixel_values": image_tensors, 82 | }, 83 | tensor_type=return_tensors 84 | ) 85 | 86 | def preprocess(self, images, *args, **kwargs): 87 | if not isinstance(images, list): 88 | images = [images] 89 | 90 | for idx, im in enumerate(images): 91 | images[idx] = self.transforms(im) 92 | return torch.stack(images, dim=0) 93 | 94 | def postprocess(self, probs): 95 | batch_size = probs.shape[0] 96 | results = [] 97 | 98 | if self.config.threshold != 0: 99 | above_threshold = probs > self.config.threshold 100 | 101 | for b in range(batch_size): 102 | result = {self.config.id2label[i]: 0 for i in range(len(self.config.id2label))} 103 | for i in np.where(above_threshold[b])[0]: 104 | result[self.config.id2label[i]] = 1 105 | results.append(result) 106 | else: 107 | indices = np.argmax(probs, axis=1) 108 | for b in range(batch_size): 109 | idx = indices[b] 110 | results.append({self.config.id2label[idx]: float(probs[b, idx])}) 111 | 112 | return results 113 | 114 | @classmethod 115 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 116 | config: DoraemonConfig = kwargs.get("config", None) 117 | if config is None: 118 | config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) 119 | return cls(config=config) 120 | 121 | class DoraemonClassifier(PreTrainedModel): 122 | config_class = DoraemonConfig 123 | 124 | def __init__(self, config: Optional[DoraemonConfig] = None, **kwargs): 125 | super().__init__(config) 126 | self.model = create_model(config.timm_model, pretrained=False, num_classes=config.num_classes) 127 | 128 | def forward(self, pixel_values): 129 | with torch.inference_mode(): 130 | output = self.model(pixel_values) 131 | if self.config.threshold != 0: 132 | output = torch.sigmoid(output) 133 | else: 134 | output = F.softmax(output, dim=1) 135 | 136 | return output.cpu().numpy() 137 | 138 | @classmethod 139 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 140 | config = kwargs.get("config", None) 141 | if config is None: 142 | config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) 143 | 144 | wrapper = cls(config=config) 145 | 146 | model_path = config.model_path 147 | 148 | pt = torch.load(model_path, weights_only=False) 149 | wrapper.model.load_state_dict(pt['ema'].state_dict() if pt.get("ema", False) else pt['model']) 150 | wrapper.model.eval() 151 | 152 | return wrapper 153 | if __name__ == "__main__": 154 | from transformers import AutoModel, AutoProcessor 155 | import requests 156 | from io import BytesIO 157 | from PIL import Image 158 | import argparse 159 | 160 | def parse_args(): 161 | parser = argparse.ArgumentParser(description='Doraemon Model Inference Deploy With Huggingface') 162 | parser.add_argument('--pretrained_model_name_or_path', type=str, default='./', 163 | help='pretrained model name or path') 164 | parser.add_argument('--revision', type=str, default=None, 165 | help='model revision') 166 | parser.add_argument('--model_path', type=str, default=None, 167 | help='If given, it will be override 
the model_path in config.json') 168 | args = parser.parse_args() 169 | 170 | return args 171 | 172 | args = parse_args() 173 | print(args) 174 | model = AutoModel.from_pretrained(args.pretrained_model_name_or_path, 175 | revision=args.revision, 176 | model_path=args.model_path, 177 | trust_remote_code=True) 178 | processor = AutoProcessor.from_pretrained(args.pretrained_model_name_or_path, 179 | revision=args.revision, 180 | model_path=args.model_path, 181 | trust_remote_code=True) 182 | 183 | urls = [ 184 | "https://github.com/ultralytics/ultralytics/blob/main/ultralytics/assets/bus.jpg?raw=true", 185 | "https://github.com/ultralytics/ultralytics/blob/main/ultralytics/assets/zidane.jpg?raw=true", 186 | ] 187 | 188 | images = [] 189 | for url in urls: 190 | response = requests.get(url) 191 | image = Image.open(BytesIO(response.content)) 192 | images.append(image) 193 | 194 | batch_inputs = processor(images, return_tensors="pt") 195 | 196 | probs = model(batch_inputs["pixel_values"]) 197 | print("Probs:\n", probs) 198 | output = processor.postprocess(probs) 199 | print("Tagging:") 200 | for r in output: 201 | print(r) -------------------------------------------------------------------------------- /doraemon/engine/procedure/visualizer.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from ...utils.plots import Annotator 3 | import platform 4 | import shutil 5 | import os 6 | import torch.nn.functional as F 7 | import cv2 8 | from typing import Union 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from functools import partial 12 | 13 | class Visualizer: 14 | 15 | @staticmethod 16 | def predict_images(model, 17 | dataloader, 18 | device, 19 | visual_path, 20 | class_indices: dict, 21 | logger, 22 | thresh: Union[float, list[float]], 23 | infer_option: str = 'default', 24 | ): 25 | """ 26 | Args: 27 | infer_option: 28 | - 'default': Infer + Visualize + GradCAM + Badcase 29 | - 'autolabel': Infer + Label 30 | """ 31 | os.makedirs(visual_path, exist_ok=True) 32 | is_single_label = isinstance(thresh, (int, float)) and thresh == 0 33 | 34 | # Determine classification head type and activation function once 35 | class_head = 'ce' if is_single_label else 'bce' 36 | activation_fn = partial(F.softmax, dim=0) if class_head == 'ce' else partial(F.sigmoid) 37 | 38 | target_classes = dataloader.dataset.classes if dataloader.dataset.classes is not None else None 39 | if target_classes and not isinstance(target_classes, list): 40 | target_classes = [target_classes] 41 | 42 | # Get the index and threshold for each target class 43 | target_indices = [] 44 | target_thresholds = [] 45 | if not is_single_label and isinstance(thresh, list): 46 | for target_class in target_classes: 47 | target_idx = None 48 | for idx, class_name in class_indices.items(): 49 | if class_name == target_class: 50 | target_idx = idx 51 | target_indices.append(idx) 52 | break 53 | if target_idx is None: 54 | raise ValueError(f"Target class {target_class} not found in class indices") 55 | 56 | # Get and validate the threshold of the target class 57 | target_thresh = thresh[target_idx] 58 | if not isinstance(target_thresh, float): 59 | raise ValueError(f"Invalid threshold type for target class: {type(target_thresh)}. 
Must be float") 60 | target_thresholds.append(target_thresh) 61 | 62 | # Initialize CAM if in default mode 63 | if infer_option == 'default': 64 | from ...utils.cam import ClassActivationMaper 65 | cam = ClassActivationMaper(model, method='gradcam', device=device, transforms=dataloader.dataset.transforms) 66 | 67 | # eval mode 68 | model.eval() 69 | n = len(dataloader) 70 | 71 | fixed_class_length = 15 72 | progress_width = len(str(n)) 73 | 74 | image_postfix_table = dict() 75 | for i, (img, inputs, img_path, gt_labels) in enumerate(dataloader): 76 | img = img[0] 77 | img_path = img_path[0] 78 | gt_label = gt_labels[0] if gt_labels is not None else None 79 | 80 | if infer_option == 'default': 81 | cam_image = cam(image=img, input_tensor=inputs, dsize=img.size) 82 | cam_image = cv2.resize(cam_image, img.size, interpolation=cv2.INTER_LINEAR) 83 | 84 | # system 85 | if platform.system().lower() == 'windows': 86 | annotator = Annotator(img, font=r'C:/WINDOWS/FONTS/SIMSUN.TTC') # windows 87 | else: 88 | annotator = Annotator(img) # linux 89 | 90 | # transforms 91 | inputs = inputs.to(device) 92 | # forward 93 | logits = model(inputs).squeeze() 94 | 95 | # post process using pre-determined activation function 96 | probs = activation_fn(logits) 97 | top5i = probs.argsort(0, descending=True)[:5].tolist() 98 | 99 | text = '\n'.join(f'{class_indices[j]:<{fixed_class_length}} {probs[j].item():.2f}' for j in top5i) 100 | 101 | formatted_predictions = ' '.join(f'{class_indices[j]:<{fixed_class_length}}{probs[j].item():.2f}' for j in top5i) 102 | logger.console(f"[{i+1:>{progress_width}}|{n:<{progress_width}}] {os.path.basename(img_path):<20} {formatted_predictions}") 103 | 104 | annotator.text((32, 32), text, txt_color=(0, 0, 0)) 105 | 106 | # Save predictions and ground truth 107 | save_dir = os.path.join(visual_path, 'labels') 108 | os.makedirs(save_dir, exist_ok=True) 109 | image_postfix_table[os.path.basename(os.path.splitext(img_path)[0] + '.txt')] = { 110 | 'ext': os.path.splitext(img_path)[1], 111 | 'gt': gt_label 112 | } 113 | with open(os.path.join(save_dir, os.path.basename(os.path.splitext(img_path)[0] + '.txt')), 'a') as f: 114 | f.write(text + '\n') 115 | 116 | if infer_option == 'default': 117 | img = np.hstack([cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR), cam_image]) 118 | cv2.imwrite(os.path.join(visual_path, os.path.basename(img_path)), img) 119 | 120 | # Process badcases only in default mode 121 | if infer_option == 'default': 122 | badcase_root = os.path.join(visual_path, 'badcase') 123 | if target_classes: 124 | for target_class in target_classes: 125 | os.makedirs(os.path.join(badcase_root, target_class), exist_ok=True) 126 | else: 127 | os.makedirs(badcase_root, exist_ok=True) 128 | 129 | for txt in glob.glob(os.path.join(visual_path, 'labels', '*.txt')): 130 | with open(txt, 'r') as f: 131 | lines = f.readlines() 132 | gt = image_postfix_table[os.path.basename(txt)]['gt'] 133 | 134 | if gt is None: 135 | continue 136 | 137 | if is_single_label: 138 | pred_class = lines[0].rsplit(' ',1)[0] 139 | is_badcase = pred_class != gt 140 | target_class = gt # For determining save path 141 | else: 142 | # Multi-label case, check each target class 143 | is_badcase = False 144 | target_class = None # For determining save path 145 | for target, thresh in zip(target_classes, target_thresholds): 146 | found_correct_pred = False 147 | for line in lines: 148 | class_name, prob = line.rsplit(' ', 1) 149 | prob = float(prob) 150 | if class_name == target: 151 | if prob < thresh: 152 | is_badcase 
= True 153 | target_class = target # Record the class causing the badcase 154 | found_correct_pred = True 155 | break 156 | if not found_correct_pred: 157 | is_badcase = True 158 | target_class = target 159 | 160 | if is_badcase: 161 | try: 162 | source_path = os.path.join(visual_path, 163 | os.path.basename(txt).replace('.txt', 164 | image_postfix_table[os.path.basename(txt)]['ext'])) 165 | if target_class and target_classes: 166 | dest_path = os.path.join(badcase_root, target_class) 167 | else: 168 | dest_path = badcase_root 169 | shutil.move(source_path, dest_path) 170 | except FileNotFoundError: 171 | print(f'FileNotFoundError->{txt}') 172 | 173 | @staticmethod 174 | def visualize_results(query, 175 | retrieval_results, 176 | scores, 177 | ground_truths, 178 | savedir, 179 | max_rank=5, 180 | query_dataset=None, 181 | gallery_dataset=None 182 | ): 183 | 184 | os.makedirs(savedir, exist_ok=True) 185 | 186 | fig, axes = plt.subplots(2, max_rank + 1, figsize=(3 * (max_rank + 1), 12)) 187 | 188 | for ax in axes.ravel(): 189 | ax.set_axis_off() 190 | # Display the query image in the first position of the second row 191 | query_img = query_dataset.get_image(query) 192 | ax = fig.add_subplot(2, max_rank + 1, max_rank + 2) 193 | ax.imshow(query_img) 194 | ax.set_title('Query') 195 | ax.axis("off") 196 | 197 | # Display the ground truth images 198 | for i in range(min(5, len(ground_truths))): 199 | gt_img = gallery_dataset.get_image(ground_truths[i]) 200 | ax = fig.add_subplot(2, max_rank + 1, i + 1) 201 | ax.imshow(gt_img) 202 | ax.set_title('Ground Truth') 203 | ax.axis("off") 204 | 205 | # Display the retrieval images 206 | for i in range(max_rank): 207 | retrieval_img = gallery_dataset.get_image(retrieval_results[i]) 208 | 209 | score = scores[i] 210 | is_tp = retrieval_results[i] in ground_truths 211 | label = 'true' if is_tp else 'false' 212 | color = (1, 0, 0) 213 | 214 | ax = fig.add_subplot(2, max_rank + 1, (max_rank + 1) + i + 2) 215 | if is_tp: 216 | ax.add_patch(plt.Rectangle(xy=(0, 0), width=retrieval_img.width - 1, 217 | height=retrieval_img.height - 1, edgecolor=color, 218 | fill=False, linewidth=8)) 219 | ax.imshow(retrieval_img) 220 | ax.set_title('{:.4f}/{}'.format(score, label)) 221 | ax.axis("off") 222 | 223 | #plt.tight_layout() 224 | image_id = os.path.basename(os.path.dirname(query)) 225 | image_name = os.path.basename(query) 226 | image_unique = image_id + '_' + image_name 227 | fig.savefig(os.path.join(savedir, image_unique)) 228 | plt.close(fig) -------------------------------------------------------------------------------- /doraemon/engine/procedure/eval_recog.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.amp import autocast 4 | from tqdm import tqdm 5 | from typing import Callable, Optional, Union, List 6 | from torchmetrics import Precision, Recall, F1Score 7 | from torch import Tensor 8 | import itertools 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from typing import Sequence 12 | from prettytable import PrettyTable 13 | 14 | 15 | __all__ = ['valuate'] 16 | 17 | class ConfusedMatrix: 18 | def __init__(self, nc: int): 19 | self.nc = nc 20 | self.mat = None 21 | 22 | def update(self, gt: Tensor, pred: Tensor): 23 | if self.mat is None: self.mat = torch.zeros((self.nc, self.nc), dtype=torch.int64, device = gt.device) 24 | 25 | idx = gt * self.nc + pred 26 | self.mat += torch.bincount(idx, minlength=self.nc).reshape(self.nc, self.nc) 27 | 28 | def 
save_conm(self, cm: np.ndarray, classes: Sequence, save_path: str, cmap=plt.cm.cool): 29 | """ 30 | - cm : the computed confusion matrix values 31 | - classes : the class names corresponding to each row and column of the matrix 32 | - normalize : True shows percentages, False shows counts 33 | """ 34 | ax = plt.gca() 35 | ax.tick_params(axis="x", top=True, labeltop=True, bottom=False, labelbottom=False) 36 | plt.imshow(cm, interpolation='nearest', cmap=cmap) 37 | plt.colorbar() 38 | tick_marks = [x for x in range(len(classes))] 39 | plt.xticks(tick_marks, classes, rotation=0, fontsize=10) 40 | plt.yticks(tick_marks, classes, fontsize=10) 41 | fmt = '.2f' 42 | for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): 43 | plt.text(j, i, format(cm[i, j], fmt), 44 | horizontalalignment="center", 45 | color="black") 46 | plt.tight_layout() 47 | plt.ylabel('GT', fontsize=12) 48 | plt.xlabel('Predict', fontsize=12) 49 | ax.xaxis.set_label_position('top') 50 | plt.gcf().subplots_adjust(top=0.9) 51 | plt.savefig(save_path) 52 | 53 | def valuate(model: nn.Module, dataloader, device: torch.device, pbar, 54 | is_training: bool = False, lossfn: Optional[Callable] = None, 55 | logger = None, thresh: Union[float, List[float]] = 0, 56 | top_k: int = 5, conm_path: str = None): 57 | """ 58 | Evaluate model performance 59 | 60 | Args: 61 | ... 62 | thresh: float or List[float], threshold for each class in multi-label classification 63 | - float: same threshold for all classes 64 | - List[float]: specific threshold for each class 65 | ... 66 | """ 67 | is_single_label = isinstance(thresh, (int, float)) and thresh == 0 68 | 69 | # Check threshold type 70 | if isinstance(thresh, (list, tuple, np.ndarray)): 71 | # Multi-threshold case: each class uses a different threshold 72 | assert len(thresh) == len(dataloader.dataset.id2label), \ 73 | f'Number of thresholds ({len(thresh)}) must match number of classes ({len(dataloader.dataset.id2label)})' 74 | thresh = torch.tensor(thresh, device=device) 75 | # Verify that all thresholds are within the valid range 76 | assert (thresh > 0).all() and (thresh < 1).all(), \ 77 | 'For multi-label (BCE), all thresholds should be in (0, 1)' 78 | elif isinstance(thresh, (int, float)): 79 | if is_single_label: 80 | # Single-label classification (Softmax) case 81 | pass 82 | else: 83 | # Multi-label classification (BCE), use the same threshold 84 | assert 0 < thresh < 1, 'For multi-label (BCE), threshold should be in (0, 1)' 85 | thresh = torch.full((len(dataloader.dataset.id2label),), 86 | thresh, 87 | device=device) 88 | else: 89 | raise ValueError(f'Unsupported threshold type: {type(thresh)}. 
' 90 | f'Expected float or list/tuple/ndarray of floats.') 91 | 92 | # eval mode 93 | model.eval() 94 | 95 | n = len(dataloader) # number of batches 96 | action = 'validating' 97 | desc = f'{pbar.desc[:-36]}{action:>36}' if pbar else f'{action}' 98 | bar = tqdm(dataloader, desc, n, not is_training, bar_format='{l_bar}{bar:10}{r_bar}', position=0) 99 | pred, targets, loss = [], [], 0 100 | 101 | with torch.no_grad(): 102 | with autocast('cuda', enabled=(device != torch.device('cpu'))): 103 | for images, labels in bar: 104 | images, labels = images.to(device, non_blocking=True), labels.to(device) 105 | y = model(images) 106 | if is_single_label: 107 | pred.append(y.argsort(1, descending=True)[:, :top_k]) 108 | targets.append(labels) 109 | else: 110 | # Get prediction probabilities using sigmoid 111 | pred_prob = y.sigmoid() 112 | # Predict using threshold for each class 113 | pred.append(pred_prob >= thresh) 114 | # Convert to hard labels 115 | hard_labels = (labels >= 0.5).float() 116 | targets.append(hard_labels) 117 | if lossfn: 118 | loss += lossfn(y, labels) 119 | 120 | loss /= n 121 | pred, targets = torch.cat(pred), torch.cat(targets) 122 | 123 | if not is_training and is_single_label and len(dataloader.dataset.id2label) <= 10: 124 | conm = ConfusedMatrix(len(dataloader.dataset.id2label)) 125 | conm.update(targets, pred[:, 0]) 126 | conm.save_conm(conm.mat.detach().cpu().numpy(), dataloader.dataset.id2label, conm_path if conm_path is not None else 'conm.png') 127 | 128 | if is_single_label: 129 | correct = (targets[:, None] == pred).float() 130 | acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy 131 | top1, top5 = acc.mean(0).tolist() 132 | 133 | if not is_training: 134 | table = PrettyTable(['Class', 'Samples', 'Top1', f'Top{top_k}']) 135 | 136 | for i, c in dataloader.dataset.id2label.items(): 137 | acc_i = acc[targets == i] 138 | top1i, top5i = acc_i.mean(0).tolist() 139 | table.add_row([c, acc_i.shape[0], f'{top1i:.3f}', f'{top5i:.3f}']) 140 | 141 | table.add_row(['MEAN', acc.shape[0], f'{top1:.3f}', f'{top5:.3f}']) 142 | 143 | logger.console('\n' + str(table)) 144 | else: 145 | table = PrettyTable(['Class', 'Samples', 'Top1', f'Top{top_k}']) 146 | for i, c in dataloader.dataset.id2label.items(): 147 | acc_i = acc[targets == i] 148 | top1i, top5i = acc_i.mean(0).tolist() 149 | table.add_row([c, acc_i.shape[0], f'{top1i:.3f}', f'{top5i:.3f}']) 150 | table.add_row(['MEAN', acc.shape[0], f'{top1:.3f}', f'{top5:.3f}']) 151 | logger.log('\n' + str(table)) 152 | else: 153 | num_classes = len(dataloader.dataset.id2label) 154 | # Compute precision, recall, and F1-score for each class 155 | precisioner = Precision(task='multilabel', threshold=0.5, num_labels=num_classes, average=None).to(device) 156 | recaller = Recall(task='multilabel', threshold=0.5, num_labels=num_classes, average=None).to(device) 157 | f1scorer = F1Score(task='multilabel', threshold=0.5, num_labels=num_classes, average=None).to(device) 158 | 159 | # Compute precision, recall, and F1-score for each class 160 | precision = precisioner(pred.float(), targets) 161 | recall = recaller(pred.float(), targets) 162 | f1score = f1scorer(pred.float(), targets) 163 | 164 | cls_numbers = targets.sum(0).int().tolist() 165 | 166 | if is_training: 167 | table = PrettyTable(['Class', 'Samples', 'Precision', 'Recall', 'F1-Score', 'Threshold']) 168 | for i, c in dataloader.dataset.id2label.items(): 169 | table.add_row([ 170 | c, 171 | cls_numbers[i], 172 | f'{precision[i].item():.3f}', 173 | 
f'{recall[i].item():.3f}', 174 | f'{f1score[i].item():.3f}', 175 | f'{thresh[i].item():.3f}' if isinstance(thresh, torch.Tensor) else f'{thresh:.3f}' 176 | ]) 177 | table.add_row([ 178 | 'MEAN', 179 | sum(cls_numbers), 180 | f'{precision.mean().item():.3f}', 181 | f'{recall.mean().item():.3f}', 182 | f'{f1score.mean().item():.3f}', 183 | '-' 184 | ]) 185 | logger.log('\n' + str(table)) 186 | else: 187 | table = PrettyTable(['Class', 'Samples', 'Precision', 'Recall', 'F1-Score', 'Threshold']) 188 | 189 | for i, c in dataloader.dataset.id2label.items(): 190 | table.add_row([ 191 | c, 192 | cls_numbers[i], 193 | f'{precision[i].item():.3f}', 194 | f'{recall[i].item():.3f}', 195 | f'{f1score[i].item():.3f}', 196 | f'{thresh[i].item():.3f}' if isinstance(thresh, torch.Tensor) else f'{thresh:.3f}' 197 | ]) 198 | 199 | table.add_row([ 200 | 'MEAN', 201 | sum(cls_numbers), 202 | f'{precision.mean().item():.3f}', 203 | f'{recall.mean().item():.3f}', 204 | f'{f1score.mean().item():.3f}', 205 | '-' 206 | ]) 207 | 208 | # Display the table 209 | logger.console('\n' + str(table)) 210 | 211 | if pbar: 212 | if is_single_label: 213 | pbar.desc = f'{pbar.desc[:-36]}{loss:>12.3g}{top1:>12.3g}{top5:>12.3g}' 214 | else: 215 | pbar.desc = f'{pbar.desc[:-36]}{loss:>12.3g}{precision.mean().item():>12.3g}{recall.mean().item():>12.3g}{f1score.mean().item():>12.3g}' 216 | 217 | # filename = 'train_results.txt' if is_training else 'val_results.txt' 218 | # with open(filename, 'w') as f: 219 | # if is_single_label: 220 | # # Save each sample's prediction and true label 221 | # for i in range(len(targets)): 222 | # f.write(f'Sample {i}: GT={targets[i].item()}, Pred={pred[i, 0].item()}, Top{top_k}={[pred[i, j].item() for j in range(top_k)]}\n') 223 | # else: 224 | # # Save each sample's multi-label prediction and true label 225 | # for i in range(len(targets)): 226 | # gt = targets[i].cpu().numpy() 227 | # pd = pred[i].cpu().numpy() 228 | # gt_indices = np.where(gt == 1)[0] 229 | # pred_indices = np.where(pd == 1)[0] 230 | # f.write(f'Sample {i}:\n') 231 | # f.write(f' GT classes: {gt_indices.tolist()}\n') 232 | # f.write(f' Pred classes: {pred_indices.tolist()}\n') 233 | 234 | if lossfn: 235 | if is_single_label: 236 | return top1, top5, loss 237 | else: 238 | return precision.mean().item(), recall.mean().item(), f1score.mean().item(), loss 239 | else: 240 | if is_single_label: 241 | return top1, top5 242 | else: 243 | return precision.mean().item(), recall.mean().item(), f1score.mean().item() 244 | -------------------------------------------------------------------------------- /doraemon/utils/checks.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from pathlib import Path 3 | import os 4 | from datasets import load_dataset 5 | 6 | def check_cfgs_common(cfgs): 7 | def find_normalize(augment_list): 8 | for augment in augment_list: 9 | if 'normalize' in augment: 10 | return augment['normalize'] 11 | return None 12 | 13 | hyp_cfg = cfgs['hyp'] 14 | data_cfg = cfgs['data'] 15 | model_cfg = cfgs['model'] 16 | 17 | # Check loss configuration 18 | assert reduce(lambda x, y: int(x) + int(y[0]), list(hyp_cfg['loss'].values())) == 1, \ 19 | 'Loss configuration error: Only one loss type should be enabled. Set either ce: true or bce: [true, ...] in hyp.loss' 20 | 21 | # Check optimizer 22 | assert hyp_cfg['optimizer'][0] in {'sgd', 'adam', 'sam'}, \ 23 | 'Invalid optimizer selection. 
Please choose from: sgd, adam, or sam' 24 | 25 | # Check scheduler and warm-up settings 26 | valid_schedulers = {'linear', 'cosine', 'linear_with_warm', 'cosine_with_warm'} 27 | assert hyp_cfg['scheduler'] in valid_schedulers, \ 28 | 'Invalid scheduler selection. Supported options: linear, cosine, linear_with_warm, cosine_with_warm' 29 | 30 | assert hyp_cfg['warm_ep'] >= 0 and isinstance(hyp_cfg['warm_ep'], int) and hyp_cfg['warm_ep'] < hyp_cfg['epochs'], \ 31 | f'Invalid warm-up epochs: must be a non-negative integer less than total epochs ({hyp_cfg["epochs"]})' 32 | 33 | if hyp_cfg['warm_ep'] == 0: 34 | assert hyp_cfg['scheduler'] in {'linear', 'cosine'}, \ 35 | 'When warm-up is disabled (warm_ep: 0), only linear or cosine scheduler is supported' 36 | if hyp_cfg['warm_ep'] > 0: 37 | assert hyp_cfg['scheduler'] in {'linear_with_warm', 'cosine_with_warm'}, \ 38 | 'When using warm-up (warm_ep > 0), scheduler must be either linear_with_warm or cosine_with_warm' 39 | 40 | # Check normalization settings 41 | train_normalize = find_normalize(data_cfg['train']['augment']) 42 | val_normalize = find_normalize(data_cfg['val']['augment']) 43 | 44 | # Check backbone configuration 45 | if 'backbone' in model_cfg: 46 | backbone_cfg = next(iter(model_cfg['backbone'].items())) 47 | backbone_name, backbone_params = backbone_cfg 48 | else: 49 | backbone_name = model_cfg['name'] 50 | backbone_params = { 51 | 'pretrained': model_cfg.get('pretrained', False), 52 | 'image_size': model_cfg.get('image_size') 53 | } 54 | 55 | # Verify model type is timm 56 | assert backbone_name.startswith('timm-'), \ 57 | "Only timm models are supported. Model name must start with 'timm-'" 58 | 59 | # Check normalization requirements based on pretrained status 60 | is_pretrained = backbone_params.get('pretrained', False) 61 | if is_pretrained: 62 | if train_normalize is None or val_normalize is None: 63 | raise ValueError('Pretrained models require normalization in both training and validation augmentations') 64 | if train_normalize['mean'] != val_normalize['mean'] or train_normalize['std'] != val_normalize['std']: 65 | raise ValueError('Inconsistent normalization parameters: mean and std must be identical for training and validation') 66 | 67 | # Check image size for backbone 68 | assert 'image_size' in backbone_params, \ 69 | f'Image size must be specified for {backbone_name}' 70 | assert backbone_params['image_size'] == model_cfg['image_size'], \ 71 | f'Image size mismatch: {backbone_params["image_size"]} in backbone config vs {model_cfg["image_size"]} in model config' 72 | 73 | def check_cfgs_face(cfgs): 74 | """ 75 | Check configurations specific to face recognition tasks. 76 | 77 | Args: 78 | cfgs: Configuration dictionary 79 | """ 80 | check_cfgs_common(cfgs=cfgs) 81 | 82 | model_cfg = cfgs['model'] 83 | data_cfg = cfgs['data'] 84 | 85 | # Determine data source type 86 | is_local = os.path.isdir(data_cfg['root']) 87 | 88 | # Check number of classes based on data source 89 | if is_local: 90 | train_classes = [x for x in os.listdir(Path(data_cfg['root'])/'train') 91 | if not (x.startswith('.') or x.startswith('_'))] 92 | num_classes = len(train_classes) 93 | else: 94 | try: 95 | dataset = load_dataset(data_cfg['root'], split='train') 96 | num_classes = len(set(dataset['label'])) 97 | except Exception as e: 98 | raise ValueError(f"Dataset loading error: Unable to load HuggingFace dataset from {data_cfg['root']}. 
Details: {str(e)}") 99 | 100 | # Check model configuration 101 | head_key = next(iter(model_cfg['head'].keys())) 102 | model_classes = model_cfg['head'][head_key]['num_class'] 103 | 104 | assert model_classes == num_classes, \ 105 | f'Model configuration error: Number of classes mismatch. Expected {num_classes} from dataset, but got {model_classes} in model configuration' 106 | 107 | # # Check face recognition specific configurations 108 | # if cfgs['model']['task'] == 'face': 109 | # pair_txt_path = data_cfg['val']['pair_txt'] 110 | # 111 | # # Verify pair text file existence 112 | # if not os.path.isfile(pair_txt_path): 113 | # raise ValueError(f'Validation data error: Pair text file not found at {pair_txt_path}') 114 | 115 | # # Validate pair list format 116 | # from doraemon.engine.representation.eval_face import Evaluator 117 | # try: 118 | # with open(pair_txt_path) as f: 119 | # pair_list = [line.strip() for line in f.readlines()] 120 | # Evaluator.check_nps(pair_list) 121 | # except Exception as e: 122 | # raise ValueError(f'Pair list validation error: Invalid format in {pair_txt_path}. Details: {str(e)}') 123 | 124 | def check_cfgs_cbir(cfgs): 125 | """ 126 | Check configurations specific to CBIR (Content-Based Image Retrieval) tasks. 127 | 128 | Args: 129 | cfgs: Configuration dictionary 130 | """ 131 | check_cfgs_common(cfgs=cfgs) 132 | 133 | model_cfg = cfgs['model'] 134 | data_cfg = cfgs['data'] 135 | 136 | # Determine data source type 137 | is_local = os.path.isdir(data_cfg['root']) 138 | 139 | # Check number of classes based on data source 140 | if is_local: 141 | train_classes = [x for x in os.listdir(Path(data_cfg['root'])/'train') 142 | if not (x.startswith('.') or x.startswith('_'))] 143 | num_classes = len(train_classes) 144 | else: 145 | try: 146 | dataset = load_dataset(data_cfg['root'], split='train') 147 | num_classes = len(set(dataset['label'])) 148 | except Exception as e: 149 | raise ValueError(f"Dataset loading error: Unable to load HuggingFace dataset from {data_cfg['root']}. Details: {str(e)}") 150 | 151 | # Check model configuration 152 | head_key = next(iter(model_cfg['head'].keys())) 153 | model_classes = model_cfg['head'][head_key]['num_class'] 154 | 155 | assert model_classes == num_classes, \ 156 | f'Model configuration error: Number of classes mismatch. Expected {num_classes} from dataset, but got {model_classes} in model configuration' 157 | 158 | def check_cfgs_classification(cfgs): 159 | """ 160 | Check configurations specific to classification tasks. 161 | 162 | Args: 163 | cfgs: Configuration dictionary 164 | """ 165 | check_cfgs_common(cfgs=cfgs) 166 | 167 | model_cfg = cfgs['model'] 168 | data_cfg = cfgs['data'] 169 | hyp_cfg = cfgs['hyp'] 170 | 171 | # Determine data source type 172 | is_csv = data_cfg['root'].endswith('.csv') 173 | is_local = os.path.isdir(data_cfg['root']) 174 | 175 | # Check loss configuration based on data source 176 | if is_csv: 177 | if hyp_cfg['loss']['ce']: 178 | raise ValueError('Loss configuration error: Multi-label tasks (CSV format) require BCE loss. Please set ce: false in hyp.loss') 179 | if not hyp_cfg['loss']['bce'][0]: 180 | raise ValueError('Loss configuration error: Multi-label tasks (CSV format) require BCE loss. Please set bce: [true, ...] in hyp.loss') 181 | else: 182 | if not hyp_cfg['loss']['ce']: 183 | raise ValueError('Loss configuration error: Single-label tasks (folder structure/HuggingFace) require CE loss. 
Please set ce: true in hyp.loss') 184 | if hyp_cfg['loss']['bce'][0]: 185 | raise ValueError('Loss configuration error: Single-label tasks (folder structure/HuggingFace) do not support BCE loss. Please set bce: [false, ...] in hyp.loss') 186 | 187 | # Check num_classes 188 | if is_local: 189 | train_classes = [x for x in os.listdir(Path(data_cfg['root'])/'train') 190 | if not (x.startswith('.') or x.startswith('_'))] 191 | num_classes = len(train_classes) 192 | elif is_csv: 193 | import pandas as pd 194 | df = pd.read_csv(data_cfg['root']) 195 | class_columns = [col for col in df.columns if col not in ['image_path', 'train']] 196 | num_classes = len(class_columns) 197 | if hyp_cfg['loss']['bce'][0]: 198 | assert num_classes == len(hyp_cfg['loss']['bce'][1]), \ 199 | f'Loss configuration error: Number of classes mismatch. Dataset (CSV) provides {num_classes} classes, but bce[1] in hyp.loss has {len(hyp_cfg["loss"]["bce"][1])} entries' 200 | else: 201 | try: 202 | dataset = load_dataset(data_cfg['root'], split='train') 203 | num_classes = len(set(dataset['label'])) 204 | except Exception as e: 205 | raise ValueError(f"Dataset loading error: Unable to load HuggingFace dataset from {data_cfg['root']}. Details: {str(e)}") 206 | 207 | assert model_cfg['num_classes'] == num_classes, \ 208 | f'Model configuration error: Number of classes mismatch. Expected {num_classes} from dataset, but got {model_cfg["num_classes"]} in model configuration' 209 | 210 | # Check model configuration 211 | assert model_cfg['name'].split('-')[0] == 'timm', \ 212 | 'Model name error: Format should be [timm-ModelName] for timm models' 213 | 214 | if model_cfg['kwargs'] and model_cfg['pretrained']: 215 | for k in model_cfg['kwargs'].keys(): 216 | if k not in {'dropout', 'attention_dropout', 'stochastic_depth_prob'}: 217 | raise KeyError('Model kwargs error: When using pretrained models, only [dropout, attention_dropout, stochastic_depth_prob] are allowed') 218 | 219 | # Check training strategies 220 | if hyp_cfg['strategy']['focal'][0]: 221 | assert hyp_cfg['loss']['bce'][0], \ 222 | 'Strategy configuration error: Focal loss requires BCE loss. Please enable BCE loss' 223 | 224 | if hyp_cfg['strategy']['ohem'][0]: 225 | assert not hyp_cfg['loss']['bce'][0], \ 226 | 'Strategy configuration error: OHEM is not compatible with BCE loss.
Please disable BCE loss' 227 | 228 | # Check mixup configuration 229 | mix_ratio, mix_duration = hyp_cfg['strategy']['mixup']["ratio"], hyp_cfg['strategy']['mixup']["duration"] 230 | 231 | # Basic ratio check 232 | assert 0 <= mix_ratio <= 1, 'Mixup configuration error: ratio must be in [0,1]' 233 | 234 | # Only check duration when mixup is enabled 235 | if mix_ratio > 0: 236 | assert 0 < mix_duration <= hyp_cfg['epochs'], \ 237 | f'Mixup configuration error: when mixup is enabled (ratio > 0), duration must be in (0,{hyp_cfg["epochs"]}]' 238 | 239 | hyp_cfg['strategy']['mixup'] = [mix_ratio, mix_duration] 240 | 241 | def check(task, cfgs): 242 | if task == 'face': check_cfgs_face(cfgs) 243 | elif task == 'cbir': check_cfgs_cbir(cfgs) 244 | elif task == 'classification': check_cfgs_classification(cfgs) 245 | else: raise ValueError(f'{task} is not supported') -------------------------------------------------------------------------------- /doraemon/engine/representation/eval_cbir.py: -------------------------------------------------------------------------------- 1 | from ...dataset.basedataset import CBIRDatasets 2 | from ...dataset.transforms import create_AugTransforms 3 | from ...dataset.dataprocessor import SmartDataProcessor 4 | from ...models.representation.face_model import FeatureExtractor 5 | from ...utils.logger import SmartLogger 6 | import torch 7 | from torch.utils.data import DataLoader 8 | import numpy as np 9 | from tqdm import tqdm 10 | import faiss 11 | from typing import Optional 12 | from sklearn.metrics import roc_auc_score, ndcg_score 13 | 14 | class CBIRMetrics: 15 | def __init__(self, 16 | cutoffs: list[int] = [1,10, 100]): 17 | _cutoffs = cutoffs.copy() 18 | self.cutoffs = _cutoffs 19 | self.metrics = {} 20 | 21 | def compute_mrr(self, 22 | preds: list[list[str]], 23 | labels: list[list[str]]): 24 | cutoffs = self.cutoffs.copy() 25 | mrrs = np.zeros(len(cutoffs)) 26 | 27 | for pred, label in zip(preds, labels): 28 | jump = False 29 | for i, x in enumerate(pred, 1): 30 | if x in label: 31 | for k, cutoff in enumerate(cutoffs): 32 | if i <= cutoff: 33 | mrrs[k] += 1 / i 34 | jump = True 35 | if jump: 36 | break 37 | mrrs /= len(preds) 38 | for i, cutoff in enumerate(cutoffs): 39 | mrr = mrrs[i] 40 | self.metrics[f"MRR@{cutoff}"] = mrr 41 | 42 | def compute_recall(self, 43 | preds: list[list[str]], 44 | labels: list[list[str]]): 45 | cutoffs = self.cutoffs.copy() 46 | recalls = np.zeros(len(cutoffs)) 47 | for pred, label in zip(preds, labels): 48 | for k, cutoff in enumerate(cutoffs): 49 | recall = np.intersect1d(label, pred[:cutoff]) 50 | recalls[k] += len(recall) / len(label) 51 | recalls /= len(preds) 52 | for i, cutoff in enumerate(cutoffs): 53 | recall = recalls[i] 54 | self.metrics[f"Recall@{cutoff}"] = recall 55 | 56 | def compute_precision(self, 57 | preds: list[list[str]], 58 | labels: list[list[str]]): 59 | cutoffs = self.cutoffs.copy() 60 | precisions = np.zeros(len(cutoffs)) 61 | cutoffs = self.cutoffs.copy() 62 | for pred, label in zip(preds, labels): 63 | for k, cutoff in enumerate(cutoffs): 64 | precision = np.intersect1d(label, pred[:cutoff]) 65 | precisions[k] += len(precision) / min(cutoff, len(label)) 66 | precisions /= len(preds) 67 | for i, cutoff in enumerate(cutoffs): 68 | self.metrics[f"Precision@{cutoff}"] = precisions[i] 69 | 70 | def compute_auc(self, 71 | preds: list[list[str]], 72 | labels: list[list[str]], 73 | preds_scores: list[list[float]]): 74 | pred_hard_encodings = self.encode_pred2hard(preds=preds, labels=labels) 75 | 76 | 
pred_hard_encodings1d = np.asarray(pred_hard_encodings).flatten() 77 | preds_scores1d = preds_scores.flatten() 78 | auc = roc_auc_score(pred_hard_encodings1d, preds_scores1d) 79 | 80 | self.metrics[f'AUC@{self.cutoffs[-1]}'] = auc 81 | 82 | def compute_ndcg(self, 83 | preds: list[list[str]], 84 | labels: list[list[str]], 85 | preds_scores: list[list[float]]): 86 | cutoffs = self.cutoffs.copy() 87 | pred_hard_encodings = self.encode_pred2hard(preds=preds, labels=labels) 88 | for _, cutoff in enumerate(cutoffs): 89 | nDCG = ndcg_score(pred_hard_encodings, preds_scores, k=cutoff) 90 | self.metrics[f"nDCG@{cutoff}"] = nDCG 91 | 92 | def encode_pred2hard(self, 93 | preds: list[list[str]], 94 | labels: list[list[str]]) -> list[list[int]]: 95 | pred_hard_encodings = [] 96 | for pred, label in zip(preds, labels): 97 | pred_hard_encoding = np.isin(pred, label).astype(int).tolist() 98 | pred_hard_encodings.append(pred_hard_encoding) 99 | 100 | return pred_hard_encodings 101 | 102 | def reset(self): 103 | self.metrics.clear() 104 | 105 | 106 | def index(extractor: FeatureExtractor, 107 | gallery_dataloader: DataLoader, 108 | device: torch.device, 109 | logger: SmartLogger, 110 | index_factory: str = "Flat", 111 | # need memmap 112 | memmap_feat_dim: Optional[int] = None, 113 | memmap_dtype: torch.dtype = torch.float16, 114 | memmap_save_path: Optional[str] = None, 115 | memmap_load_embedding: bool = False, 116 | ): 117 | 118 | """ 119 | 1. Encode the entire corpus into dense embeddings; 120 | 2. Create faiss index; 121 | 3. Optionally save embeddings. 122 | """ 123 | 124 | if memmap_load_embedding: 125 | gallery_embeddings = np.memmap( 126 | memmap_save_path, 127 | mode="r", 128 | dtype=memmap_dtype 129 | ).reshape(-1, memmap_feat_dim) 130 | 131 | else: 132 | gallery_embeddings = extractor.extract_cbir(gallery_dataloader, device) 133 | 134 | if memmap_save_path is not None: 135 | logger.console(f"saving embeddings at {memmap_save_path}...") 136 | memmap = np.memmap( 137 | memmap_save_path, 138 | shape=gallery_embeddings.shape, 139 | mode="w+", 140 | dtype=gallery_embeddings.dtype 141 | ) 142 | 143 | length = gallery_embeddings.shape[0] 144 | # add in batch 145 | save_batch_size = 10000 146 | if length > save_batch_size: 147 | for i in tqdm(range(0, length, save_batch_size), leave=False, desc="Saving Embeddings"): 148 | j = min(i + save_batch_size, length) 149 | memmap[i: j] = gallery_embeddings[i: j] 150 | else: 151 | memmap[:] = gallery_embeddings 152 | 153 | dim = gallery_embeddings.shape[-1] 154 | # create faiss index 155 | 156 | faiss_index = faiss.index_factory(dim, index_factory, faiss.METRIC_INNER_PRODUCT) 157 | logger.console(f"Creating CPU FAISS index with dimension {dim}...") 158 | 159 | # if device.type == 'cuda': 160 | # # co = faiss.GpuClonerOptions() 161 | # co = faiss.GpuMultipleClonerOptions() 162 | # co.useFloat16 = True 163 | # # faiss_index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), 0, faiss_index, co) 164 | # faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co) 165 | 166 | # NOTE: faiss only accepts float32 167 | logger.console("Adding embeddings...") 168 | gallery_embeddings = gallery_embeddings.astype(np.float32) 169 | faiss_index.train(gallery_embeddings) 170 | faiss_index.add(gallery_embeddings) 171 | return faiss_index 172 | 173 | def search(extractor: FeatureExtractor, 174 | query_dataloader: DataLoader, 175 | faiss_index: faiss.Index, 176 | device: torch.device, 177 | logger: SmartLogger, 178 | k:int = 100, 179 | batch_size: int = 256, 180 | ): 181 | """ 
182 | 1. Encode queries into dense embeddings; 183 | 2. Search through faiss index 184 | """ 185 | query_embeddings = extractor.extract_cbir(query_dataloader, device) 186 | query_size = query_embeddings.shape[0] 187 | 188 | all_scores = [] 189 | all_indices = [] 190 | 191 | logger.console('Searching ...') 192 | for i in range(0, query_size, batch_size): 193 | j = min(i + batch_size, query_size) 194 | query_embedding = query_embeddings[i: j] 195 | score, indice = faiss_index.search(query_embedding.astype(np.float32), k=k) 196 | all_scores.append(score) 197 | all_indices.append(indice) 198 | 199 | all_scores = np.concatenate(all_scores, axis=0) 200 | all_indices = np.concatenate(all_indices, axis=0) 201 | 202 | return all_scores, all_indices 203 | 204 | def compute_metrics(preds, 205 | preds_scores, 206 | labels, 207 | metrics = ['mrr', 'precision', 'recall', 'auc', 'ndcg'], 208 | cutoffs=[1, 3, 10]): 209 | 210 | metrics_engine = CBIRMetrics(cutoffs=cutoffs) 211 | 212 | for m in metrics: 213 | if m == 'mrr': 214 | metrics_engine.compute_mrr(preds=preds, labels=labels) 215 | elif m == 'precision': 216 | metrics_engine.compute_precision(preds=preds, labels=labels) 217 | elif m == 'recall': 218 | metrics_engine.compute_recall(preds=preds, labels=labels) 219 | elif m == 'auc': 220 | metrics_engine.compute_auc(preds=preds, labels=labels, preds_scores=preds_scores) 221 | elif m == 'ndcg': 222 | metrics_engine.compute_ndcg(preds=preds, labels=labels, preds_scores=preds_scores) 223 | else: 224 | raise ValueError(f'{m} is not supported') 225 | 226 | return metrics_engine.metrics 227 | 228 | def valuate(model, 229 | data_cfg: dict, 230 | device: torch.device, 231 | logger: SmartLogger, 232 | vis: bool = False): 233 | """ 234 | Arguments: 235 | vis(bool): for cbir visualization 236 | """ 237 | 238 | query_dataset, gallery_dataset = CBIRDatasets.build(root=data_cfg['root'], 239 | transforms=create_AugTransforms(data_cfg['val']['augment'])) 240 | 241 | query_dataloader = SmartDataProcessor.set_dataloader(query_dataset, 242 | bs=data_cfg['val']['bs'], 243 | nw=data_cfg['nw'], 244 | shuffle=False) # must be False, otherwise metrics are computed wrong 245 | 246 | gallery_dataloader = SmartDataProcessor.set_dataloader(gallery_dataset, 247 | bs=data_cfg['val']['bs'], 248 | nw=data_cfg['nw'], 249 | shuffle=False) # must be False, otherwise metrics are computed wrong 250 | 251 | feature_extractor = FeatureExtractor(model) 252 | 253 | faiss_index = index( 254 | extractor=feature_extractor, 255 | gallery_dataloader=gallery_dataloader, 256 | device=device, 257 | logger=logger 258 | ) 259 | 260 | cutoffs = data_cfg['val']['metrics']['cutoffs'] 261 | scores, indices = search( 262 | extractor=feature_extractor, 263 | query_dataloader=query_dataloader, 264 | faiss_index=faiss_index, 265 | device=device, 266 | logger=logger, 267 | k = cutoffs[-1], 268 | batch_size=data_cfg['val']['bs'], 269 | ) 270 | 271 | retrieval_results = [] 272 | for indice in indices: 273 | # filter invalid indices 274 | indice = indice[indice != -1].tolist() 275 | retrieval_results.append(gallery_dataset.gallery[indice]['gallery']) 276 | 277 | ground_truths = [] 278 | for pos in query_dataset.data['pos']: 279 | ground_truths.append(pos) 280 | 281 | metrics = compute_metrics(retrieval_results, 282 | scores, 283 | ground_truths, 284 | metrics=data_cfg['val']['metrics']['metrics'], 285 | cutoffs=cutoffs) 286 | 287 | for k, v in metrics.items(): 288 | metrics[k] = float(v) 289 | 290 | if vis: 291 | return metrics, retrieval_results, scores, 
ground_truths, query_dataset.data['query'], query_dataset, gallery_dataset 292 | 293 | return metrics 294 | 295 | -------------------------------------------------------------------------------- /doraemon/engine/procedure/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | from torch import Tensor 4 | from .eval_recog import valuate 5 | from ..representation.eval_face import valuate as valuate_face 6 | from ..representation.eval_cbir import valuate as valuate_cbir 7 | from typing import Callable 8 | import math 9 | from ..optimizer import SAM 10 | from torch.utils.tensorboard import SummaryWriter 11 | import os 12 | from copy import deepcopy 13 | 14 | __all__ = ['Trainer'] 15 | 16 | def make_divisible(x: int, divisor = 32): 17 | # Returns nearest x divisible by divisor 18 | return math.ceil(x / divisor) * divisor 19 | 20 | def print_imgsz(images: torch.Tensor): 21 | h, w = images.shape[-2:] 22 | return [h,w] 23 | 24 | def mixup_data(x, y, device, lam): 25 | '''Returns mixed inputs, pairs of targets''' 26 | batch_size = x.size()[0] 27 | # to device 28 | index = torch.randperm(batch_size).to(device) 29 | 30 | mixed_x = lam * x + (1 - lam) * x[index, :] 31 | y_a, y_b = y, y[index] 32 | return mixed_x, y_a, y_b 33 | 34 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 35 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 36 | 37 | class Trainer: 38 | def __init__(self, 39 | model, 40 | train_dataloader, 41 | val_dataloader, 42 | optimizer, 43 | scaler, 44 | device: torch.device, 45 | epochs: int, 46 | logger, 47 | rank: int, 48 | scheduler, 49 | ema, 50 | sampler = None, 51 | thresh = 0, 52 | teacher = None, 53 | mixup_sampler = None, 54 | # face 55 | task: str = 'face', 56 | print_freq = 50, 57 | save_freq = 5, 58 | cfgs: dict = None, 59 | out_dir = None 60 | ): 61 | 62 | self.model = model 63 | self.train_dataloader = train_dataloader 64 | self.val_dataloader = val_dataloader 65 | self.optimizer = optimizer 66 | self.scaler = scaler 67 | self.device = device 68 | self.epochs = epochs 69 | self.logger = logger 70 | self.rank = rank 71 | self.scheduler = scheduler 72 | self.ema = ema 73 | self.sampler = sampler 74 | self.thresh = thresh 75 | self.teacher = teacher 76 | self.sam: bool = type(self.optimizer) is SAM 77 | self.distill: bool = teacher is not None 78 | self.mixup_sampler = mixup_sampler 79 | 80 | # face 81 | self.task = task 82 | self.print_freq = print_freq 83 | self.save_freq = save_freq 84 | self.data_cfg = cfgs['data'] 85 | self.model_cfg = cfgs['model'] 86 | self.hyp_cfg = cfgs['hyp'] 87 | if rank in (-1, 0) and out_dir is not None: 88 | self.writer = SummaryWriter(log_dir=out_dir) 89 | 90 | def train_one_epoch(self, epoch: int, criterion: Callable): 91 | # train mode 92 | self.model.train() 93 | 94 | cuda: bool = self.device != torch.device('cpu') 95 | 96 | if self.rank != -1: 97 | self.train_dataloader.sampler.set_epoch(epoch) 98 | pbar = enumerate(self.train_dataloader) 99 | if self.rank in {-1, 0}: 100 | pbar = tqdm(enumerate(self.train_dataloader), 101 | total=len(self.train_dataloader), 102 | bar_format='{l_bar}{bar:10}{r_bar}') 103 | 104 | tloss, fitness = 0., 0. 
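        # Mixup: when self.mixup_sampler is set, each batch below is mixed via mixup_data/mixup_criterion defined above:
        #   mixed_x = lam * x + (1 - lam) * x[perm]
        #   loss    = lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
        # lam comes from self.mixup_sampler.sample() (typically a Beta-distributed mixing ratio);
        # lam == 0 disables mixup entirely (see compute_loss).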
105 | 106 | for i, (images, labels) in pbar: # progress bar 107 | 108 | if self.mixup_sampler is not None: 109 | lam = self.mixup_sampler.sample() 110 | else: 111 | lam = 0 112 | 113 | images, labels = images.to(self.device, non_blocking=True), labels.to(self.device) 114 | if self.sampler is not None: # OHEM-Softmax 115 | with torch.no_grad(): 116 | valid = self.sampler.sample(self.model(images), labels) 117 | images, labels = images[valid], labels[valid] 118 | with torch.autocast(device_type=self.device.type, enabled=(self.device != torch.device('cpu'))): 119 | loss = self.compute_loss(images, labels, lam, criterion) 120 | 121 | if self.rank in {-1, 0}: 122 | tloss = (tloss * i + loss.item()) / (i + 1) # update mean losses 123 | mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if cuda else 0) # (GB) 124 | pbar.desc = f"{f'{epoch + 1}/{self.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36 125 | pbar.postfix = f'lr:{self.optimizer.param_groups[0]["lr"]:.5f}, imgsz:{print_imgsz(images)}' 126 | 127 | if i == len(pbar) - 1: # last batch 128 | self.logger.log(f'EPOCH:{epoch + 1:d} Train-Loss:{tloss:4f} LR:{self.optimizer.param_groups[0]["lr"]:.5f}') 129 | if self.thresh == 0: 130 | # val 131 | top1, top5, v_loss = valuate(self.ema.ema, self.val_dataloader, self.device, pbar, True, criterion, self.logger, 132 | self.thresh) 133 | self.logger.log(f'VAL-LOSS:{v_loss:4f}\n') 134 | else: 135 | self.logger.log(f'{"name":<8}{"nums":>8}{"precision":>15}{"recall":>10}{"f1score":>10}') 136 | # val 137 | precision, recall, f1score, v_loss = valuate(self.ema.ema, self.val_dataloader, self.device, pbar, True, 138 | criterion, self.logger, self.thresh) 139 | self.logger.log(f'VAL-Loss:{v_loss:4f}\n') 140 | 141 | fitness = top1 if self.thresh == 0 else f1score # fitness: top1 accuracy (thresh == 0) or f1-score (thresh > 0) 142 | 143 | self.scheduler.step() # step epoch-wise 144 | 145 | return fitness 146 | 147 | @staticmethod 148 | def update_sam(model: torch.nn.Module, inputs, targets, optimizer, lossfn, rank, ema=None, mixup=False, **kwargs): 149 | # first forward-backward step 150 | optimizer.enable_running_stats(model) 151 | if not mixup: 152 | loss = lossfn(model(inputs), targets) 153 | else: 154 | loss = mixup_criterion(lossfn, model(inputs), **kwargs) 155 | if rank >= 0: # multi-gpu 156 | with model.no_sync(): 157 | loss.mean().backward() 158 | else: 159 | loss.mean().backward() 160 | optimizer.first_step(zero_grad=True) 161 | 162 | # second forward-backward step 163 | optimizer.disable_running_stats(model) 164 | if not mixup: 165 | lossfn(model(inputs), targets).mean().backward() 166 | else: 167 | mixup_criterion(lossfn, model(inputs), **kwargs).mean().backward() 168 | optimizer.second_step(zero_grad=True) 169 | 170 | if ema: 171 | ema.update(model) 172 | 173 | return loss 174 | 175 | def compute_loss(self, images: Tensor, labels: Tensor, lam: float, criterion: Callable, face: bool = False): 176 | mixup: bool = lam > 0 177 | 178 | assert not (mixup and self.distill), 'mixup and distillation cannot be enabled at the same time' 179 | if mixup and self.sam: # close 180 | images, targets_a, targets_b = mixup_data(images, labels, self.device, lam) 181 | kwargs = dict(y_a=targets_a, y_b=targets_b, lam=lam) 182 | loss = Trainer.update_sam(self.model, images, labels, self.optimizer, criterion, self.rank, self.ema, mixup=True, **kwargs) 183 | elif mixup: # close 184 | images, targets_a, targets_b = mixup_data(images, labels, self.device, lam) 185 | loss = mixup_criterion(criterion, self.model(images), targets_a, targets_b, lam) 186 |
Trainer.update(self.model, loss, self.scaler, self.optimizer, self.ema) 187 | elif self.sam and self.distill: 188 | raise ValueError('Combining the SAM optimizer with knowledge distillation has not been implemented yet.') 189 | elif self.sam: # close 190 | loss = Trainer.update_sam(self.model, images, labels, self.optimizer, criterion, self.rank, self.ema, mixup=False) 191 | elif self.distill: 192 | raise ValueError('Knowledge distillation has not been implemented yet.') 193 | else: # close 194 | loss = criterion(self.model(images), labels) if not face else criterion(self.model(images, labels), labels) 195 | Trainer.update(self.model, loss, self.scaler, self.optimizer, self.ema) 196 | 197 | return loss 198 | 199 | # scale + backward + grad_clip + step + zero_grad 200 | @staticmethod 201 | def update(model, loss, scaler, optimizer, ema=None): 202 | # backward 203 | scaler.scale(loss).backward() 204 | 205 | # optimize 206 | scaler.unscale_(optimizer) # unscale gradients 207 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0) # clip gradients 208 | scaler.step(optimizer) 209 | scaler.update() 210 | 211 | optimizer.zero_grad() 212 | if ema: 213 | ema.update(model) 214 | 215 | def train_one_epoch_emb(self, criterion, cur_epoch, loss_meter): 216 | """Train one epoch using the standard training loop. 217 | """ 218 | 219 | # config 220 | config = {} 221 | config["model"] = self.model_cfg 222 | config["data"] = self.data_cfg 223 | config["hyp"] = self.hyp_cfg 224 | 225 | self.model.train() 226 | iters_per_epoch = len(self.train_dataloader) 227 | 228 | for batch_idx, (images, labels) in enumerate(self.train_dataloader): 229 | images, labels = images.to(self.device, non_blocking=True), labels.to(self.device) 230 | 231 | loss = self.compute_loss(images, labels, lam=0, criterion = criterion, face = True) 232 | 233 | global_batch_idx = cur_epoch * iters_per_epoch + batch_idx 234 | self.scheduler.step() # step batch-wise 235 | 236 | if self.rank in (-1, 0): 237 | loss_meter.update(loss.item(), images.shape[0]) 238 | 239 | if self.rank in (-1, 0) and batch_idx % self.print_freq == 0: 240 | loss_avg = loss_meter.avg 241 | lr = self.optimizer.param_groups[0]["lr"] 242 | self.logger.both('Epoch %d, iter %d/%d, lr %f, loss %f' % 243 | (cur_epoch+1, batch_idx+1, iters_per_epoch, lr, loss_avg)) 244 | self.writer.add_scalar('Train_loss', loss_avg, global_batch_idx) 245 | self.writer.add_scalar('Train_lr', lr, global_batch_idx) 246 | loss_meter.reset() 247 | 248 | if self.rank in (-1, 0) and (((cur_epoch * iters_per_epoch + batch_idx + 1) % (self.save_freq * iters_per_epoch)== 0) or ((cur_epoch * iters_per_epoch + batch_idx + 1) == (self.epochs * iters_per_epoch))): 249 | saved_name = 'Epoch_%d.pt' % (cur_epoch+1) 250 | if self.task == 'face': 251 | mean, std = valuate_face(self.ema.ema.trainingwrapper['backbone'], 252 | self.data_cfg, 253 | self.device) 254 | self.writer.add_scalar('Val_mean', mean, global_batch_idx) 255 | self.writer.add_scalar('Val_std', std, global_batch_idx) 256 | 257 | fitness = {'fitness': {'Val_mean': float(mean), 'Val_std': float(std)}} 258 | elif self.task == 'cbir': 259 | metrics = valuate_cbir(self.ema.ema.trainingwrapper['backbone'], 260 | self.data_cfg, 261 | self.device, 262 | self.logger) 263 | for k, v in metrics.items(): 264 | self.writer.add_scalar(f'Val_{k}', v, global_batch_idx) 265 | fitness = {'fitness': metrics} 266 | 267 | fitness['checkpoint'] = saved_name 268 | 269 | ckpt = { 270 | 'epoch': cur_epoch, 271 | 'batch_id': batch_idx, 272 | 'fitness': fitness, 273 | 'state_dict':
self.model.trainingwrapper['backbone'].state_dict() if self.rank == -1 else self.model.module.trainingwrapper['backbone'].state_dict(), 274 | 'ema': deepcopy(self.ema.ema.trainingwrapper['backbone'].state_dict()), 275 | 'updates': self.ema.updates, 276 | 'optimizer': self.optimizer.state_dict(), 277 | 'scheduler': self.scheduler.state_dict(), 278 | 'config': config 279 | } 280 | if self.device != torch.device('cpu'): 281 | ckpt['scaler'] = self.scaler.state_dict() 282 | 283 | torch.save(ckpt, os.path.join(self.writer.log_dir, saved_name)) 284 | self.logger.both(fitness) 285 | torch.cuda.empty_cache() --------------------------------------------------------------------------------
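For reference, the checkpoint dictionary assembled in train_one_epoch_emb can later be reloaded for evaluation or export. A minimal sketch, assuming a hypothetical checkpoint path and that the backbone is rebuilt from the stored config (the key names match the ckpt dict written above; everything else is illustrative):

import torch

# Hypothetical path; checkpoints are named 'Epoch_%d.pt' and written to the TensorBoard log dir.
ckpt = torch.load('runs/face/Epoch_10.pt', map_location='cpu')

print(ckpt['fitness'])        # {'fitness': {...validation metrics...}, 'checkpoint': 'Epoch_10.pt'}
cfgs = ckpt['config']         # {'model': ..., 'data': ..., 'hyp': ...} as saved during training
ema_state = ckpt['ema']       # EMA backbone weights, usually the ones to evaluate or deploy

# The backbone would be rebuilt from cfgs['model'] (e.g. through the repo's backbone factory)
# before calling: backbone.load_state_dict(ema_state)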