├── torch_frame ├── vision │ ├── tools │ │ ├── __init__.py │ │ └── object_detection.py │ ├── __init__.py │ └── augmentations │ │ ├── __init__.py │ │ ├── base.py │ │ ├── colors.py │ │ └── boxes.py ├── datasets │ ├── __init__.py │ ├── dataset_wrappers.py │ └── dataloader_wrappers.py ├── _get_logger.py ├── hooks │ ├── __init__.py │ ├── hookbase.py │ ├── checkpoint_hook.py │ ├── eval_hook.py │ └── logger_hook.py ├── utils │ ├── __init__.py │ ├── dist_utils.py │ ├── progress_bar.py │ ├── history_buffer.py │ ├── ema.py │ ├── test_speed.py │ ├── config_parser.py │ ├── logger.py │ ├── misc.py │ └── metric.py ├── __init__.py ├── ddp_trainer.py ├── lr_scheduler.py ├── accelerate_trainer.py └── trainer.py ├── setup.py └── README.md /torch_frame/vision/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .object_detection import * -------------------------------------------------------------------------------- /torch_frame/vision/__init__.py: -------------------------------------------------------------------------------- 1 | from .tools import * 2 | from .augmentations import * 3 | -------------------------------------------------------------------------------- /torch_frame/vision/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .colors import * 2 | from .boxes import * -------------------------------------------------------------------------------- /torch_frame/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_wrappers import * 2 | from .dataloader_wrappers import * 3 | -------------------------------------------------------------------------------- /torch_frame/_get_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | logging.basicConfig(level=logging.INFO) 5 | logger = logging.getLogger("torch_frame") 6 | 
-------------------------------------------------------------------------------- /torch_frame/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint_hook import CheckpointerHook 2 | from .eval_hook import EvalHook, EvalTotalHook 3 | from .hookbase import HookBase 4 | from .logger_hook import LoggerHook 5 | -------------------------------------------------------------------------------- /torch_frame/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .config_parser import ConfigArgumentParser, save_args 2 | from .misc import * 3 | from .logger import setup_logger 4 | from .history_buffer import HistoryBuffer 5 | from .progress_bar import ProgressBar 6 | from .ema import EMA 7 | -------------------------------------------------------------------------------- /torch_frame/__init__.py: -------------------------------------------------------------------------------- 1 | from ._get_logger import logger 2 | from .hooks import * 3 | from .lr_scheduler import LRWarmupScheduler 4 | from .utils import HistoryBuffer, misc, metric 5 | from .trainer import Trainer, MetricStorage 6 | from .ddp_trainer import DDPTrainer 7 | from .accelerate_trainer import AccelerateTrainer 8 | 9 | 10 | # 关闭opencv的多线程防止pytorch死锁 11 | import cv2 12 | cv2.setNumThreads(0) 13 | cv2.ocl.setUseOpenCL(False) 14 | -------------------------------------------------------------------------------- /torch_frame/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 包含ddp相关函数 3 | """ 4 | 5 | import torch.distributed as dist 6 | 7 | 8 | def is_dist_avail_and_initialized(): 9 | if not dist.is_available(): 10 | return False 11 | if not dist.is_initialized(): 12 | return False 13 | return True 14 | 15 | 16 | def get_world_size(): 17 | if not is_dist_avail_and_initialized(): 18 | return 1 19 | return dist.get_world_size() 20 | 21 | 22 | 
def get_rank(): 23 | if not is_dist_avail_and_initialized(): 24 | return 0 25 | return dist.get_rank() -------------------------------------------------------------------------------- /torch_frame/vision/augmentations/base.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | 5 | 6 | __all__ = [ 7 | "mixup" 8 | ] 9 | 10 | 11 | def mixup(img1: np.ndarray, img2: np.ndarray, 12 | mixup_scale: float = 0.1): 13 | def translate(image, h, w, alpha): 14 | y_offset = random.randint(0, dst_img.shape[0] - h) 15 | x_offset = random.randint(0, dst_img.shape[1] - w) 16 | dst_img[y_offset: h+y_offset, x_offset: w+x_offset] = image * alpha 17 | 18 | h1, w1 = img1.shape[:2] 19 | h2, w2 = img2.shape[:2] 20 | dst_img = np.zeros((max(h1, h2), max(w1, w2), 3), dtype=np.float) 21 | alpha = random.uniform(0.5 - mixup_scale, 0.5 + mixup_scale) 22 | img1 = img1.astype(np.float32) 23 | translate(img1, h1, w1, alpha) 24 | translate(img1, h2, w2, 1 - alpha) 25 | return dst_img.astype(np.uint8) -------------------------------------------------------------------------------- /torch_frame/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | __all__ = [ 5 | "RepeatDataset", 6 | ] 7 | 8 | 9 | # 源自https://github.com/open-mmlab/mmdetection/blob/master/mmdet/datasets/dataset_wrappers.py 10 | class RepeatDataset: 11 | """A wrapper of repeated dataset. 12 | 13 | The length of repeated dataset will be `times` larger than the original 14 | dataset. This is useful when the data loading time is long but the dataset 15 | is small. Using RepeatDataset can reduce the data loading time between 16 | epochs. 17 | 18 | Args: 19 | dataset (:obj:`Dataset`): The dataset to be repeated. 20 | times (int): Repeat times. 
21 | """ 22 | 23 | def __init__(self, dataset, times): 24 | self.dataset = dataset 25 | self.times = times 26 | 27 | self._ori_len = len(self.dataset) 28 | 29 | def __getitem__(self, idx): 30 | return self.dataset[idx % self._ori_len] 31 | 32 | def __len__(self): 33 | """Length after repetition.""" 34 | return self.times * self._ori_len -------------------------------------------------------------------------------- /torch_frame/datasets/dataloader_wrappers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataloader import DataLoader 2 | 3 | 4 | __all__ = [ 5 | "InfiniteDataLoader", 6 | ] 7 | 8 | 9 | # 源自https://github.com/AlibabaResearch/efficientteacher/blob/main/utils/datasets_ssod.py 10 | class InfiniteDataLoader(DataLoader): 11 | """ Dataloader that reuses workers 12 | 13 | Uses same syntax as vanilla DataLoader 14 | """ 15 | 16 | def __init__(self, *args, **kwargs): 17 | super().__init__(*args, **kwargs) 18 | object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler)) 19 | self.iterator = super().__iter__() 20 | 21 | def __len__(self): 22 | return len(self.batch_sampler.sampler) 23 | 24 | def __iter__(self): 25 | for i in range(len(self)): 26 | yield next(self.iterator) 27 | 28 | 29 | class _RepeatSampler(object): 30 | """ Sampler that repeats forever 31 | 32 | Args: 33 | sampler (Sampler) 34 | """ 35 | 36 | def __init__(self, sampler): 37 | self.sampler = sampler 38 | 39 | def __iter__(self): 40 | while True: 41 | yield from iter(self.sampler) -------------------------------------------------------------------------------- /torch_frame/utils/progress_bar.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | from .dist_utils import is_dist_avail_and_initialized, get_rank 3 | 4 | 5 | class ProgressBar: 6 | """tqdm基础上包了一层,主要兼容ddp的多卡""" 7 | def __init__(self, *args, **kwargs): 8 | if not is_dist_avail_and_initialized() or get_rank() 
== 0: 9 | self.pbar = tqdm.tqdm(*args, **kwargs) 10 | else: 11 | self.pbar = None 12 | 13 | def update(self, *args, **kwargs): 14 | if self.pbar is not None: 15 | self.pbar.update(*args, **kwargs) 16 | 17 | def close(self): 18 | if self.pbar is not None: 19 | self.pbar.close() 20 | 21 | def set_postfix(self, *args, **kwargs): 22 | if self.pbar is not None: 23 | self.pbar.set_postfix(*args, **kwargs) 24 | 25 | def set_description(self, *args, **kwargs): 26 | if self.pbar is not None: 27 | self.pbar.set_description(*args, **kwargs) 28 | 29 | def set_postfix_str(self, *args, **kwargs): 30 | if self.pbar is not None: 31 | self.pbar.set_postfix_str(*args, **kwargs) 32 | 33 | def set_description_str(self, *args, **kwargs): 34 | if self.pbar is not None: 35 | self.pbar.set_description_str(*args, **kwargs) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | 4 | # 如果readme文件中有中文,那么这里要指定encoding='utf-8',否则会出现编码错误 5 | with open(os.path.join(os.path.dirname(__file__), 'README.md'), encoding='utf-8') as readme: 6 | README = readme.read() 7 | 8 | # 允许setup.py在任何路径下执行 9 | os.chdir(os.path.normpath(os.path.join(os.path.abspath(__file__), os.pardir))) 10 | 11 | setuptools.setup( 12 | name="torch-frame", # 库名, 需要在pypi中唯一 13 | version="1.7.9", # 版本号 14 | author="Darkn Lxs", # 作者 15 | author_email="1187220556@qq.com", # 作看都将(方便使用索类现问图后成我我们) 16 | description="用于深度学习快速实现代码的框架", # 简介 17 | long_description="见readme", # 详细描述(一般会写在README.md中) 18 | long_description_content_type="text/markdown", # README.md中描述的语法(一般为markdown) 19 | url="https://github.com/darknli/Pytorch-Frame/tree/main/torch_frame", # 库/项目主页,放该项目的远程库地址即可 20 | packages=setuptools.find_packages(), # 默认值即可,这个是方便以后我们给库拓展新功能的 21 | classifiers=[ # 指定该库依赖的Python版本、license、操作系统之类的 22 | "Programming Language :: Python :: 3", 23 | "License :: OSI Approved :: MIT License", 24 | 
"Operating System :: OS Independent", 25 | ], 26 | install_requires=[ # 该库需要的依前库 27 | "termcolor", 28 | "numpy>=1.17", 29 | "opencv-python", 30 | "tabulate", 31 | "torch", 32 | "transformers>=4.25.1", 33 | "accelerate>=0.16.0", 34 | "diffusers", 35 | "pyyaml", 36 | "tqdm", 37 | "tensorboard" 38 | ], 39 | python_requires='>=3.6', 40 | ) 41 | 42 | # python setup.py sdist bdist_wheel -------------------------------------------------------------------------------- /torch_frame/utils/history_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | 4 | 5 | class HistoryBuffer: 6 | """ 7 | 该类保存了一些数, 并且可以得到平滑均值 8 | 9 | Example:: 10 | 11 | >>> his_buf = HistoryBuffer() 12 | >>> his_buf.update(0.1) 13 | >>> his_buf.update(0.2) 14 | >>> his_buf.avg 15 | 0.15 16 | """ 17 | 18 | def __init__(self, window_size: int = 20) -> None: 19 | """ 20 | Parameters 21 | ---------- 22 | window_size : int, default 20. 滑窗大小 23 | """ 24 | self._history = deque(maxlen=window_size) 25 | self._count: int = 0 26 | self._sum: float = 0.0 27 | 28 | def update(self, value: float) -> None: 29 | """ 30 | 在列表新增变量, 如果新增后超出窗口大小, 舍去最早的那个值 31 | """ 32 | self._history.append(value) 33 | self._count += 1 34 | self._sum += value 35 | 36 | @property 37 | def latest(self) -> float: 38 | return self._history[-1] 39 | 40 | @property 41 | def avg(self) -> float: 42 | return np.mean(self._history) 43 | 44 | @property 45 | def global_avg(self) -> float: 46 | return self._sum / self._count 47 | 48 | @property 49 | def global_sum(self) -> float: 50 | return self._sum 51 | 52 | def __le__(self, other): 53 | return self.avg <= other 54 | 55 | def __lt__(self, other): 56 | return self.avg < other 57 | 58 | def __ge__(self, other): 59 | return self.avg >= other 60 | 61 | def __gt__(self, other): 62 | return self.avg > other 63 | 64 | def __str__(self): 65 | return str(round(self.avg, 4)) 66 | 
-------------------------------------------------------------------------------- /torch_frame/utils/ema.py: --------------------------------------------------------------------------------
import math
from copy import deepcopy
import torch
import torch.nn as nn


class EMA:
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """
        Args:
            model (nn.Module): model to apply EMA.
            decay (float): ema decay rate.
            updates (int): counter of EMA updates.
24 | """ 25 | # Create EMA(FP32) 26 | self.model = deepcopy(model).eval() 27 | self.updates = updates 28 | # decay exponential ramp (to help early epochs) 29 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) 30 | for p in self.model.parameters(): 31 | p.requires_grad_(False) 32 | 33 | def update(self, model): 34 | # Update EMA parameters 35 | with torch.no_grad(): 36 | self.updates += 1 37 | d = self.decay(self.updates) 38 | msd = model.state_dict() 39 | for k, v in self.model.state_dict().items(): 40 | if v.dtype.is_floating_point: 41 | v *= d 42 | v += (1.0 - d) * msd[k].detach() 43 | -------------------------------------------------------------------------------- /torch_frame/utils/test_speed.py: -------------------------------------------------------------------------------- 1 | """ 2 | 该脚本主要用于测试计算资源对模型推理的速度 3 | """ 4 | from torchvision.models.resnet import resnet18, resnet50, resnet101, resnet152 5 | from torchvision.models.mobilenet import mobilenet_v2 6 | from torchvision.models.vgg import vgg16 7 | from .progress_bar import ProgressBar 8 | import torch 9 | import time 10 | import pandas as pd 11 | 12 | 13 | BENCHMARK = { 14 | (16, 3, 224, 224): { 15 | "mobilenet_v2": { 16 | "RTX-3090": 6, 17 | "V100": 12, 18 | "A100": 8, 19 | }, 20 | "resnet50": { 21 | "RTX-3090": 7, 22 | "V100": 9, 23 | "A100": 8 24 | }, 25 | "vgg16": { 26 | "RTX-3090": 9, 27 | "V100": 7, 28 | "A100": 5 29 | }, 30 | }, 31 | (64, 3, 224, 224): { 32 | "mobilenet_v2": { 33 | "RTX-3090": 16, 34 | "V100": 25, 35 | "A100": 24, 36 | }, 37 | "resnet50": { 38 | "RTX-3090": 19, 39 | "V100": 33, 40 | "A100": 24, 41 | }, 42 | "vgg16": { 43 | "RTX-3090": 29, 44 | "V100": 28, 45 | "A100": 22, 46 | }, 47 | } 48 | } 49 | 50 | 51 | def gpu_cnn_speed(model_type: str = None, input_size: tuple = (16, 3, 224, 224), gpu_id: int = None): 52 | """ 53 | 测试gpu在cnn上的性能 54 | 55 | Parameters 56 | ---------- 57 | model_type : str, default resnet50. 
模型类型 58 | * vgg16 59 | * resnet18 60 | * resnet50 61 | * resnet101 62 | * resnet152 63 | * mobilenet_v2 64 | input_size : tuple, default (16, 3, 224, 224). 模型输入尺寸 65 | gpu_id : int, default None. 显卡号,如果是None,且显卡可用,则默认为0 66 | """ 67 | if model_type is None: 68 | model_type = "resnet50" 69 | assert model_type in ("vgg16", "resnet18", "resnet50", "resnet101", "resnet152", "mobilenet_v2") 70 | 71 | if gpu_id is None: 72 | gpu_id = 0 73 | device = f"cuda:{gpu_id}" if torch.cuda.is_available() else "cpu" 74 | model = eval(model_type)().to(device) 75 | iters = 300 76 | 77 | begin = time.time() 78 | pbar = ProgressBar(total=iters, desc=model_type) 79 | for i in range(iters): 80 | x = torch.randn(input_size).to(device) 81 | _ = model(x) 82 | pbar.update(1) 83 | print(f"全程耗时: {int(time.time() - begin + 0.5)}s") 84 | print("参考各显卡benchmark") 85 | for bis, v in BENCHMARK.items(): 86 | print("*" * 50) 87 | print("benchmark input size:", bis) 88 | print(pd.DataFrame(v)) 89 | -------------------------------------------------------------------------------- /torch_frame/hooks/hookbase.py: -------------------------------------------------------------------------------- 1 | # from ..trainer import Trainer, MetricStorage 2 | # import numpy as np 3 | 4 | 5 | class HookBase: 6 | """ 7 | hooks的基类 8 | 9 | Hook类在Trainer类中被初始化。每个Hook可以在六个阶段执行, 对应Trainer的六个方法分别为: 10 | 11 | * before_train, 训练前 12 | * after_train, 训练后 13 | * before_epoch, 一轮epoch前 14 | * after_epoch, 一轮epoch后 15 | * before_iter, 一个iter前 16 | * after_iter, 一个iter后 17 | 目前在Hook中, 不能得到通过self.trainer得到类似model、optimizer等信息 18 | 19 | Examples: 20 | 21 | >>> hook.before_train() 22 | >>> for epoch in range(start_epoch, max_epochs): 23 | >>> hook.before_epoch() 24 | >>> for iter in range(epoch_len): 25 | >>> hook.before_iter() 26 | >>> train_one_iter() 27 | >>> hook.after_iter() 28 | >>> hook.after_epoch() 29 | >>> hook.after_train() 30 | """ 31 | 32 | # A weak reference to the trainer object. 
Set by the trainer when the hook is registered. 33 | trainer: "torch_frame.Trainer" = None 34 | 35 | def before_train(self) -> None: 36 | """整体训练前调用""" 37 | pass 38 | 39 | def after_train(self) -> None: 40 | """全部训练结束后调用""" 41 | pass 42 | 43 | def before_epoch(self) -> None: 44 | """epoch前调用""" 45 | pass 46 | 47 | def after_epoch(self) -> None: 48 | """epoch结束后调用""" 49 | pass 50 | 51 | def before_iter(self) -> None: 52 | """iter前调用""" 53 | pass 54 | 55 | def after_iter(self) -> None: 56 | """iter结束后调用""" 57 | pass 58 | 59 | @property 60 | def checkpointable(self) -> bool: 61 | """A hook is checkpointable when it has :meth:`state_dict` method. 62 | Its state will be saved into checkpoint. 63 | """ 64 | return callable(getattr(self, "state_dict", None)) 65 | 66 | @property 67 | def class_name(self) -> str: 68 | """返回类名""" 69 | return self.__class__.__name__ 70 | 71 | @property 72 | def metric_storage(self) -> "torch_frame.MetricStorage": 73 | return self.trainer.metric_storage 74 | 75 | def log(self, *args, **kwargs) -> None: 76 | self.trainer.log(*args, **kwargs) 77 | 78 | # belows are some helper functions that are often used in hook 79 | def every_n_epochs(self, n: int) -> bool: 80 | return self.trainer.epoch % n == 0 if n > 0 else False 81 | 82 | def every_n_iters(self, n: int) -> bool: 83 | return (self.trainer.iter + 1) % n == 0 if n > 0 else False 84 | 85 | def every_n_inner_iters(self, n: int) -> bool: 86 | return (self.trainer.inner_iter + 1) % n == 0 if n > 0 else False 87 | 88 | def is_last_epoch(self) -> bool: 89 | return self.trainer.epoch == self.trainer.max_epochs 90 | 91 | def is_last_iter(self) -> bool: 92 | return self.trainer.iter == self.trainer.max_iters - 1 93 | 94 | def is_last_inner_iter(self) -> bool: 95 | return self.trainer.inner_iter == self.trainer.epoch_len - 1 96 | -------------------------------------------------------------------------------- /torch_frame/vision/augmentations/colors.py: 
-------------------------------------------------------------------------------- 1 | from random import randint, random, uniform 2 | import numpy as np 3 | import cv2 4 | 5 | __all__ = [ 6 | "RandomBrightness", 7 | "RandomGammaCorrection", 8 | "RandomHueSaturation", 9 | "RandomContrast", 10 | ] 11 | 12 | 13 | class RandomBrightness: 14 | def __init__(self, low=-10, high=10, p=0.2): 15 | if low < -80 or high > 80: 16 | raise ValueError("亮度low不可以小于-80, high不可以大于80") 17 | self.low = low 18 | self.high = high 19 | self.p = p 20 | 21 | def __call__(self, image): 22 | if random() < self.p: 23 | return image 24 | image = image.astype(np.float32) 25 | brightness = randint(self.low, self.high) 26 | image[:, :, :] += brightness 27 | image = np.clip(image, 0, 255).astype(np.uint8) 28 | return image 29 | 30 | 31 | class RandomGammaCorrection: 32 | def __init__(self, lower=0.666, upper=1.5, p=0.5): 33 | self.lower = lower 34 | self.upper = upper 35 | self.p = p 36 | 37 | def __call__(self, image): 38 | if random() < self.p: 39 | return image 40 | cur_gamma = uniform(self.lower, self.upper) 41 | max_channels = (255, 255, 255) 42 | image = np.power(image / max_channels, cur_gamma) 43 | image = (image * 255).astype(np.uint8) 44 | return image 45 | 46 | 47 | class RandomHueSaturation: 48 | def __init__(self, hue_lower=-12, hue_upper=12, sat_lower=0.6, sat_upper=1.5, p=0.5): 49 | self.hue_lower = hue_lower 50 | self.hue_upper = hue_upper 51 | self.sat_lower = sat_lower 52 | self.sat_upper = sat_upper 53 | self.p = p 54 | 55 | def __call__(self, image): 56 | if random() < self.p: 57 | return image 58 | random_h = uniform(self.hue_lower, self.hue_upper) 59 | random_s = uniform(self.sat_lower, self.sat_upper) 60 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float32) 61 | h = image[:, :, 0] 62 | h += random_h 63 | h[h > 180.0] -= 180.0 64 | h[h < 0.0] += 180.0 65 | 66 | s = image[:, :, 1] 67 | s *= random_s 68 | # 一定要先转uint8 69 | image = cv2.cvtColor(image.astype(np.uint8), 
cv2.COLOR_HSV2BGR) 70 | return image 71 | 72 | 73 | class RandomContrast: 74 | def __init__(self, low=0.9, high=1.1, p=0.2): 75 | if low < 0.5 or high > 2: 76 | raise ValueError("对比度low不可以小于0.5, high不可以大于2.0") 77 | if low >= high: 78 | raise ValueError("对比度low不可以大于或等于high") 79 | self.low = low 80 | self.high = high 81 | self.p = p 82 | 83 | def __call__(self, image): 84 | if random() < self.p: 85 | return image 86 | image = image.astype(np.float32) 87 | contrast = uniform(self.low, self.high) 88 | image[:, :, :] *= contrast 89 | image = np.clip(image, 0, 255).astype(np.uint8) 90 | return image -------------------------------------------------------------------------------- /torch_frame/utils/config_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import argparse 4 | import yaml 5 | from copy import deepcopy 6 | from argparse import Namespace 7 | from typing import Optional, List 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class ConfigArgumentParser(argparse.ArgumentParser): 13 | """Argument parser that supports loading a YAML configuration file. 14 | 15 | A small issue: config file values are processed using ArgumentParser.set_defaults(..) 16 | which means "required" and "choices" are not handled as expected. For example, if you 17 | specify a required value in a config file, you still have to specify it again on the 18 | command line. The ``ConfigArgParse`` library (http://pypi.python.org/pypi/ConfigArgParse) 19 | can be used as a substitute. 
20 | """ 21 | 22 | def __init__(self, *args, **kwargs): 23 | self.config_parser = argparse.ArgumentParser(add_help=False) 24 | self.config_parser.add_argument("-c", "--config", default=None, metavar="FILE", 25 | help="where to load YAML configuration") 26 | self.option_names = [] 27 | super().__init__(*args, 28 | # Inherit options from config_parser 29 | parents=[self.config_parser], 30 | # Don't mess with format of description 31 | formatter_class=argparse.RawDescriptionHelpFormatter, 32 | **kwargs) 33 | 34 | def add_argument(self, *args, **kwargs): 35 | arg = super().add_argument(*args, **kwargs) 36 | self.option_names.append(arg.dest) 37 | return arg 38 | 39 | def parse_args(self, args=None): 40 | res, remaining_argv = self.config_parser.parse_known_args(args) 41 | 42 | if res.config is not None: 43 | with open(res.config, "r") as f: 44 | config_vars = yaml.safe_load(f) 45 | for key in config_vars: 46 | if key not in self.option_names: 47 | self.error(f"unexpected configuration entry: {key}") 48 | self.set_defaults(**config_vars) 49 | 50 | return super().parse_args(remaining_argv) 51 | 52 | 53 | def save_args(args: Namespace, filepath: str, excluded_fields: Optional[List[str]] = None) -> None: 54 | """Save args with some excluded fields to a ``.yaml`` file. 55 | 56 | Args: 57 | args (Namespace): The parsed arguments to be saved. 58 | filepath (str): A filepath ends with ".yaml". 59 | excluded_fields (list[str]): The names of some fields that are not saved. 60 | Defaults to ["config"]. 
61 | """ 62 | assert isinstance(args, Namespace) 63 | assert filepath.endswith(".yaml") 64 | os.makedirs(os.path.dirname(os.path.abspath(filepath)), exist_ok=True) 65 | save_dict = deepcopy(args.__dict__) 66 | for field in excluded_fields or ["config"]: 67 | save_dict.pop(field) 68 | with open(filepath, "w") as f: 69 | yaml.dump(save_dict, f) 70 | logger.info(f"Args is saved to {filepath}") 71 | -------------------------------------------------------------------------------- /torch_frame/hooks/checkpoint_hook.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import os.path as osp 4 | from typing import Any, Dict, List, Optional 5 | from types import LambdaType 6 | import logging 7 | from .hookbase import HookBase 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class CheckpointerHook(HookBase): 13 | """ 14 | 周期性保存参数 15 | """ 16 | 17 | def __init__(self, period: int = 1, 18 | max_to_keep: Optional[int] = None, 19 | save_metric: Optional[str] = None, 20 | max_first: bool = True, 21 | save_last: bool = True 22 | ) -> None: 23 | """ 24 | 初始化 25 | Parameters 26 | ---------- 27 | period : int, 保存checkpoints的周期 28 | max_to_keep : int, 保存checkpoints的数量, 更早期的checkpoints会被删除 29 | save_metric : int, default None. 30 | 保存模型的指标是哪个, 需要从trainer.metric_storage选择 31 | max_first : bool, default True. 
32 | 用于保存模型的指标是取最大还是最小作为最优模型 33 | save_last : bool, default True 34 | 是否保存最近一次的epoch的模型, 如果是True, 每轮将更新模型到latest.pth中 35 | """ 36 | self._period = period 37 | assert max_to_keep is None or max_to_keep > 0 38 | if max_to_keep is None and save_metric is None and not save_last: 39 | raise ValueError("创建了无效的`CheckpointerHook`对象,因为不会保存任何模型") 40 | self._max_to_keep = max_to_keep 41 | if save_metric is None: 42 | if max_to_keep is None: 43 | logger.warning("没有指定保存模型的指标,不会保存best模型") 44 | else: 45 | logger.warning("没有指定保存模型的指标,因此每period都将保存模型") 46 | self.save_metric = save_metric 47 | if max_first: 48 | self.cur_best = float("-inf") 49 | self.is_better = lambda a: a > self.cur_best 50 | else: 51 | self.cur_best = float("inf") 52 | self.is_better = lambda a: a < self.cur_best 53 | 54 | self.save_last = save_last 55 | 56 | self._recent_checkpoints: List[str] = [] 57 | 58 | def after_epoch(self) -> None: 59 | if self.every_n_epochs(self._period) or self.is_last_epoch(): 60 | self.save_model() 61 | 62 | def save_model(self): 63 | if self.save_last: 64 | self.trainer.save_checkpoint("latest.pth", False) 65 | 66 | # 如果当前epoch指标没有更好, 则不保存模型 67 | if self.save_metric is not None: 68 | if not self.is_better(self.trainer.metric_storage[self.save_metric]): 69 | return 70 | self.cur_best = self.trainer.metric_storage[self.save_metric].avg 71 | logger.info(f"{self.save_metric} update to {round(self.cur_best, 4)}") 72 | self.trainer.save_checkpoint("best.pth") 73 | 74 | if self._max_to_keep is not None and self._max_to_keep >= 1: 75 | epoch = self.trainer.epoch # ranged in [0, max_epochs - 1] 76 | checkpoint_name = f"epoch_{epoch}.pth" 77 | self.trainer.save_checkpoint(checkpoint_name) 78 | self._recent_checkpoints.append(checkpoint_name) 79 | if len(self._recent_checkpoints) > self._max_to_keep: 80 | # delete the oldest checkpoint 81 | file_name = self._recent_checkpoints.pop(0) 82 | file_path = osp.join(self.trainer.ckpt_dir, file_name) 83 | if os.path.exists(file_path): 84 | 
os.remove(file_path) 85 | 86 | def state_dict(self) -> Dict[str, Any]: 87 | state = {} 88 | for key, value in self.__dict__.items(): 89 | if key == "trainer" or isinstance(value, LambdaType): 90 | continue 91 | try: 92 | pickle.dumps(value) 93 | except BaseException: 94 | continue 95 | state[key] = value 96 | return state 97 | 98 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 99 | self.__dict__.update(state_dict) 100 | -------------------------------------------------------------------------------- /torch_frame/ddp_trainer.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer, nn, optim, Optional, List, HookBase, setup_logger, logger 2 | import os 3 | from torch.utils.data import DataLoader, Dataset, DistributedSampler, RandomSampler 4 | from .utils.dist_utils import * 5 | import warnings 6 | 7 | 8 | class DDPTrainer(Trainer): 9 | """ 10 | Trainer的DDP版本。需要注意的是,在调用该类前依然要执行init_process_group,为了避免重复打印或保存数据,应该在外部选择性 11 | 的根据gpu_id调用hooker, 在创建DDPTrainer对象时会自动根据环境判断是否使用ddp,无需手动设置 12 | 13 | 14 | Parameters 15 | --------- 16 | model : torch.nn.Module, 训练模型, 训练时的输出只能是以下三种: 17 | * torch.Tensor, 对于这种输出是模型backward用的loss, 在torch-frame框架中被称为total_loss 18 | * dict, 里面是模型的各路分支的loss,需要是标量, Trainer会自动将其求和得到total_loss, 再做backward 19 | * Tuple[Union[dict, torch.Tensor], dict]. 
前两种的混合体, 元组第一个输出是前面两种的任意一种; 20 | 第二个输出是非需要backward类的, Trainer不会把这个dict汇总到total_loss上 21 | optimizer : torch.optim.Optimizer, 优化器 22 | lr_scheduler : optim.lr_scheduler._LRScheduler, 学习率调节器 23 | dataset : torch.utils.data.Dataset, 训练集数据生成器, 不需要创建dataloader, 由DDPTrainer内部创建 24 | dataset_params : dict, dataset的参数, key是诸如batch_size的参数 25 | max_epochs : int, 训练的总轮数 26 | work_dir : str, 保存模型和日志的根目录地址 27 | clip_grad_norm : float, default 0.0 28 | 梯度裁剪的设置, 如果置为小于等于0, 则不作梯度裁剪 29 | enable_amp : bool, 使用混合精度 30 | warmup_method : str, default None 31 | warmup的类型, 包含以下四种取值 32 | * constant 33 | * linear 34 | * exp 35 | * None : 不使用warmup 36 | warmup_iters : int, default 1000, warmup最后的iter数 37 | warmup_factor : float, default 0.001 38 | warmup初始学习率 = warmup_factor * initial_lr 39 | hooks : List[HookBase], default None. 40 | hooks, 保存模型、输出评估指标、loss等用 41 | use_ema : bool, default False. 是否使用EMA技术 42 | ema_decay: float = 0.9999. EMA模型衰减系数 43 | create_new_dir : Optional[str], default time 44 | 存在同名目录时以何种策略创建目录 45 | * None, 直接使用同名目录 46 | * `time_s`, 如果已经存在同名目录, 则以时间(精确到秒)为后缀创建新目录 47 | * `time_m`, 如果已经存在同名目录, 则以时间(精确到分)为后缀创建新目录 48 | * `time_h`, 如果已经存在同名目录, 则以时间(精确到小时)为后缀创建新目录 49 | * `time_d`, 如果已经存在同名目录, 则以时间(精确到日)为后缀创建新目录 50 | * `count`, 如果已经存在同名目录, 则以序号为后缀创建新目录 51 | """ 52 | 53 | def __init__(self, 54 | model: nn.Module, 55 | optimizer: optim.Optimizer, 56 | lr_scheduler: optim.lr_scheduler._LRScheduler, 57 | dataset: Dataset, 58 | dataset_params: dict, 59 | max_epochs: int, 60 | work_dir: str = "work_dir", 61 | clip_grad_norm: float = 0.0, 62 | enable_amp=False, 63 | warmup_method: Optional[str] = None, 64 | warmup_iters: int = 1000, 65 | warmup_factor: float = 0.001, 66 | hooks: Optional[List[HookBase]] = None, 67 | use_ema: bool = False, 68 | ema_decay: float = 0.9999, 69 | create_new_dir: Optional[str] = "time_s" 70 | ): 71 | warnings.warn("为了更快训练,建议使用AccelerateTrainer来替代这个trainer", DeprecationWarning) 72 | self.use_dist = is_dist_avail_and_initialized() 73 | if self.use_dist: 
            num_tasks = get_world_size()
            global_rank = get_rank()
            sampler_trainer = DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank)
        else:
            sampler_trainer = RandomSampler(dataset)
        data_loader = DataLoader(dataset, sampler=sampler_trainer, **dataset_params)
        super(DDPTrainer, self).__init__(model, optimizer, lr_scheduler, data_loader, max_epochs, work_dir,
                                         clip_grad_norm, enable_amp, warmup_method, warmup_iters, warmup_factor,
                                         hooks, use_ema, ema_decay, create_new_dir)

    def _train_one_epoch(self) -> None:
        """执行模型一个epoch的全部操作"""
        if self.use_dist:
            # dist.barrier() 这里可能会造成阻塞
            self.data_loader.sampler.set_epoch(self.epoch)
        super(DDPTrainer, self)._train_one_epoch()
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
# Pytorch Frame
原代码来自https://github.com/machineko/coreml_torch_utils
,此为改版。在原版基础之上加入了大量功能

# 安装
❌ pip方式不再维护:
~~pip install torch-frame~~

✅ 推荐使用pip install git+https://github.com/darknli/Pytorch-Frame.git
# 单卡训练
使用Trainer训练,下面是代码示例(Trainer支持混合精度训练,可自行去Trainer类中翻阅)
```commandline
# 创建dataset和dataloader
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import MultiStepLR
from torch_frame import Trainer

train_dataset = Dataset(...)
train_dataloader = DataLoader(...)

# 创建网络相关对象
model = get_model(conf)
optimizer = Adam(model.parameters(), lr)
lr_scheduler = MultiStepLR(optimizer, ...)
25 | 26 | # 创建hooker,承载验证集部分和评估保存模型的任务 27 | # 这里也可以不做定制化创建,走Trainer默认的hooks,这个时候只支持log和checkpoint latest保存 28 | hooks = [EvalHook(...), LoggerHook(...)] 29 | 30 | # 创建Trainer对象并开始训练 31 | trainer = Trainer(model, optimizer, lr_scheduler, train_dataloader, num_epochs, "保存路径", hooks=hooks) 32 | trainer.train() # 开始正式训练 33 | ``` 34 | 也可以加载之前训练到一半的模型以及的训练状态,接着训练 35 | ```commandline 36 | trainer = Trainer(model, optimizer, lr_scheduler, train_dataloader, num_epochs, "保存路径", hooks=hooks) 37 | trainer.load_checkpoint("latest.pth") 38 | trainer.train(1, 1) 39 | ``` 40 | 也可以只加载模型参数 41 | ```commandline 42 | trainer = Trainer(model, optimizer, lr_scheduler, train_dataloader, num_epochs, "保存路径", hooks=hooks) 43 | weights = torch.load("best.pth") 44 | trainer.load_checkpoint(checkpoint=weights) 45 | trainer.train(1, 1) 46 | ``` 47 | 48 | # 多卡训练 49 | 使用DDPTrainer训练,在单卡的基础上扩展了多卡训练的能力,以multiprocessing的方式举例 50 | ```commandline 51 | from torch.utils.data import Dataset, DataLoader 52 | from torch.optim.lr_scheduler import MultiStepLR 53 | from torch_frame import DDPTrainer 54 | import torch.multiprocessing as mp 55 | import torch.distributed as dist 56 | 57 | def main(cur_gpu, args): 58 | if args.use_dist: 59 | rank = args.nprocs * args.machine_rank + cur_gpu 60 | dist.init_process_group(backend="nccl", init_method=args.url, world_size=args.world_size, rank=rank) 61 | args.train_batch_size = args.train_batch_size // args.world_size # 这里需要把一个batch平分到每个gpu的数量 62 | else: 63 | args.train_batch_size = args.batch_size 64 | 65 | # 创建网络相关对象 66 | model = get_model(conf) 67 | 68 | if args.use_dist: 69 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 70 | torch.cuda.set_device(cur_gpu) 71 | model.cuda(cur_gpu) 72 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cur_gpu], find_unused_parameters=True) 73 | else: 74 | device = "cuda" if torch.cuda.is_available() else "cpu" 75 | model.to(device) 76 | optimizer = Adam(model.parameters(), lr) 77 | lr_scheduler = 
MultiStepLR(optimizer, ...) 78 | 79 | train_dataset = Dataset(...) 80 | train_params = dict(batch_size=32, ...) 81 | 82 | # 这里的hooks和单卡同理 83 | if cur_gpu == 0: 84 | hooks = [EvalHook(...), LoggerHook(...)] 85 | else: 86 | hooks = [] 87 | 88 | trainer = DDPTrainer(model, optimizer, lr_scheduler, train_dataset, train_params, num_epochs, hooks=hooks, 89 | use_dist=args.use_dist) 90 | 91 | 92 | if __name__ == "__main__": 93 | nprocs = torch.cuda.device_count() 94 | if args.use_dist: 95 | mp.spawn(main, nprocs=nprocs, args=(args, )) 96 | else: 97 | main(0, args) 98 | 99 | ``` 100 | 101 | # Accelerate加速训练 102 | 基于accelerate库的trainer做训练,支持多卡,相对于上述两种训练,推荐下面这种训练方式。相对于普通的Trainer只需要做少量修改即可运行 103 | ```commandline 104 | # 创建dataset和dataloader 105 | from torch.utils.data import Dataset, DataLoader 106 | from torch_frame import AccelerateTrainer 107 | from torch.optim import Adam 108 | 109 | train_dataset = Dataset(...) 110 | train_dataloader = DataLoader(...) 111 | 112 | # 创建网络相关对象 113 | model = get_model(config) 114 | optimizer = Adam(model.parameters(), lr) 115 | lr_scheduler = "constant" # 这里建议用字符串而不是类似MultiStepLR的scheduler类,具体和diffusers一致 116 | 117 | # 创建hooker,承载验证集部分和评估保存模型的任务 118 | hooks = [EvalHook(...), LoggerHook(...)] 119 | 120 | # 创建Trainer对象并开始训练,支持混合精度fp16/bf16 121 | trainer = AccelerateTrainer(model, optimizer, lr_scheduler, train_dataloader, num_epochs, "保存路径", mixed_precision="fp16", hooks=hooks) 122 | trainer.train() # 开始正式训练 123 | ``` -------------------------------------------------------------------------------- /torch_frame/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class LRWarmupScheduler(_LRScheduler): 6 | """ 7 | 支持warmup的LR调节器。封装了LR scheduler, 可支持warmup, 每到iter结束后便调用。与pytorch不同的是, 该类的调用单位是iter 8 | 而非epoch 9 | 10 | ..
class LRWarmupScheduler(_LRScheduler):
    """Iteration-based LR scheduler with optional warmup.

    Wraps a standard epoch-based PyTorch scheduler; unlike plain PyTorch
    schedulers it must be stepped once per *iteration*, not per epoch.
    During the first ``warmup_iters`` iterations the regular learning rate
    is scaled by a warmup factor; the wrapped scheduler is advanced once
    every ``epoch_len`` iterations.

    Example::

        lr_scheduler = LRWarmupScheduler(
            StepLR(optimizer, step_size=10, gamma=0.1),
            epoch_len=9999,              # iterations per epoch
            warmup_method="linear",
            warmup_iters=1000,
            warmup_factor=0.001,
        )
        for epoch in range(start_epoch, max_epochs):
            for batch in data_loader:
                output = model(batch)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()      # called per iteration, not per epoch
    """

    def __init__(
        self,
        scheduler: _LRScheduler,
        epoch_len: int,
        warmup_method: Optional[str] = None,
        warmup_iters: int = 1000,
        warmup_factor: float = 0.001,
        last_epoch: int = -1,
    ):
        """
        Parameters
        ----------
        scheduler : _LRScheduler
            A standard PyTorch scheduler; must implement
            ``_get_closed_form_lr()`` when warmup is enabled.
        epoch_len : int
            Iterations per epoch; the wrapped scheduler advances at each
            epoch boundary.
        warmup_method : str, default None
            One of "constant", "linear", "exp"; None disables warmup.
        warmup_iters : int, default 1000
            Total number of warmup iterations.
        warmup_factor : float, default 0.001
            Initial warmup lr = ``warmup_factor`` * initial lr.
        last_epoch : int, default -1
            Index of the last iteration (name kept for PyTorch parity).
        """
        self.scheduler = scheduler
        self.epoch_len = epoch_len
        self.warmup_method = warmup_method
        self.warmup_iters = warmup_iters
        self.warmup_factor = warmup_factor

        if self._enable_warmup():
            assert warmup_method in ("constant", "linear", "exp"), (
                f"'{warmup_method}' is not a supported type for warmup, "
                "valid types are 'constant', 'linear' or 'exp'"
            )
            assert callable(
                getattr(scheduler, "_get_closed_form_lr", None)
            ), "`scheduler` must implement `_get_closed_form_lr()` method"
            assert warmup_iters > 0, "'warmup_iters' must be a positive integer"
            assert 0 < warmup_factor <= 1.0, "'warmup_ratio' must be in range (0,1]"

        # The LRs the wrapped scheduler would produce with no warmup applied.
        self.regular_lrs = scheduler.get_last_lr()

        super().__init__(scheduler.optimizer, last_epoch)

    def _enable_warmup(self) -> bool:
        """Whether a warmup method has been configured."""
        return self.warmup_method is not None

    def _reach_epoch_end(self) -> bool:
        """True when the current iteration completes an epoch."""
        return bool(self.last_epoch) and self.last_epoch % self.epoch_len == 0

    def _get_warmup_factor(self) -> float:
        """Multiplier applied to the regular LRs (1.0 once warmup is done)."""
        # `self.last_epoch` should be read as "last iteration" here.
        if not self._enable_warmup() or self.last_epoch >= self.warmup_iters:
            return 1.0

        progress = self.last_epoch / self.warmup_iters
        if self.warmup_method == "constant":
            return self.warmup_factor
        if self.warmup_method == "linear":
            return self.warmup_factor * (1 - progress) + progress
        return self.warmup_factor ** (1 - progress)  # "exp"

    def get_lr(self) -> List[float]:
        factor = self._get_warmup_factor()
        if self._reach_epoch_end():
            # Advance the wrapped scheduler exactly once per finished epoch;
            # its `last_epoch` really does count epochs.
            self.scheduler.last_epoch += 1
            self.regular_lrs = self.scheduler._get_closed_form_lr()
        return [factor * lr for lr in self.regular_lrs]

    def state_dict(self):
        """Serialize everything except the optimizer/scheduler objects."""
        state = {
            k: v for k, v in self.__dict__.items()
            if k not in ("optimizer", "scheduler")
        }
        state["scheduler_state_dict"] = self.scheduler.state_dict()
        return state

    def load_state_dict(self, state_dict):
        """Inverse of :meth:`state_dict`."""
        self.scheduler.load_state_dict(state_dict.pop("scheduler_state_dict"))
        self.__dict__.update(state_dict)
logger_initialized = {}


class _ColorfulFormatter(logging.Formatter):
    """Formatter that prefixes DEBUG/WARNING/ERROR records with a colored tag."""

    def formatMessage(self, record):
        message = super(_ColorfulFormatter, self).formatMessage(record)
        level = record.levelno
        if level == logging.DEBUG:
            return colored("DEBUG", "magenta") + " " + message
        if level == logging.WARNING:
            return colored("WARNING", "red", attrs=["blink"]) + " " + message
        if level in (logging.ERROR, logging.CRITICAL):
            return colored("ERROR", "red", attrs=["blink", "underline"]) + " " + message
        return message


def setup_logger(
    name: Optional[str] = None,
    output: Optional[str] = None,
    console_log_level: int = logging.INFO,
    file_log_level: int = logging.INFO,
    color: bool = False,
) -> logging.Logger:
    """Initialize and cache a logger.

    If the logger named ``name`` was already initialized by this function it
    is returned as-is. Otherwise it gets a console :class:`StreamHandler`
    and, when ``output`` is given, a :class:`FileHandler` — both attached on
    the master process only. Child loggers (``project.module1`` etc.) reuse
    the handlers of an initialized parent.

    Parameters
    ----------
    name : str, default None
        Logger name; None configures the root logger.
    output : str, default None
        Where to write the log file:
        * None: no file logging
        * ends with ``.txt``/``.log``: used as the file name
        * otherwise: treated as a directory, file becomes ``output/log.txt``
    console_log_level : int, default logging.INFO
        Level for console output.
    file_log_level : int, default logging.INFO
        Level for file output.
    color : bool, default False
        Colorize console output.

    Returns
    -------
    logging.Logger
        The initialized logger.
    """
    if name in logger_initialized:
        return logger_initialized[name]

    new_logger = logging.getLogger(name)  # root logger when name is None
    new_logger.setLevel(console_log_level)
    # Keep records from bubbling up to ancestor loggers (avoids duplicates).
    new_logger.propagate = False

    plain_formatter = logging.Formatter(
        "[%(asctime)s %(name)s %(levelname)s]: %(message)s", datefmt="%m/%d %H:%M:%S"
    )

    # Console and file handlers are attached on the master process only.
    if get_rank() == 0:
        console_handler = logging.StreamHandler(stream=sys.stdout)
        console_handler.setLevel(console_log_level)
        if color:
            console_handler.setFormatter(
                _ColorfulFormatter(
                    colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                    datefmt="%m/%d %H:%M:%S",
                )
            )
        else:
            console_handler.setFormatter(plain_formatter)
        new_logger.addHandler(console_handler)

        if output is not None:
            if output.endswith((".txt", ".log")):
                filename = output
            else:
                filename = os.path.join(output, "log.txt")
            # abspath() guarantees dirname() is never "" for bare file names,
            # which would make os.makedirs() fail.
            os.makedirs(os.path.dirname(os.path.abspath(filename)), exist_ok=True)

            file_handler = logging.FileHandler(filename)
            file_handler.setLevel(file_log_level)
            file_handler.setFormatter(plain_formatter)
            new_logger.addHandler(file_handler)

    logger_initialized[name] = new_logger
    return new_logger
class EvalHook(CheckpointerHook):
    """Periodic evaluator derived from `CheckpointerHook`, run at epoch end.

    Feeds every batch of ``dataloader`` to ``eval_func`` and logs the mean of
    each returned per-sample score under ``<prefix>_<metric>``.
    """

    def __init__(self,
                 dataloader: DataLoader,
                 eval_func: Callable,
                 period: int = 1,
                 max_to_keep: Optional[int] = None,
                 save_metric: Optional[str] = None,
                 max_first: bool = True,
                 save_last: bool = True,
                 prefix: str = "eval"
                 ):
        """
        Parameters
        ----------
        dataloader : DataLoader
            Dataloader of the evaluation data.
        eval_func : Callable
            Called as ``eval_func(model, batch)``; returns a ``Dict[str, list]``
            mapping a metric name to the per-sample scores of the batch.
        period : int, default 1
            Run the evaluation every ``period`` epochs.
        max_to_keep : Optional[int]
            Number of checkpoints to keep; older ones are deleted.
        save_metric : Optional[str], default None
            Metric from ``trainer.metric_storage`` used to select the best
            checkpoint; None disables metric-based selection.
        max_first : bool, default True
            Whether larger metric values count as better.
        save_last : bool, default True
            Refresh ``latest.pth`` every epoch.
        prefix : str, default "eval"
            Prefix prepended to logged metric names.
        """
        self.prefix = prefix + "_"
        # BUGFIX: `save_metric` defaults to None and `self.prefix + None`
        # raised TypeError; only prefix it when it is actually provided.
        metric_name = self.prefix + save_metric if save_metric is not None else None
        super(EvalHook, self).__init__(period, max_to_keep, metric_name, max_first, save_last)
        self._eval_func = eval_func
        self.dataloader = dataloader

    @torch.no_grad()
    def _do_eval(self):
        """Evaluate every batch and log the mean of each collected metric."""
        tot_res = {}
        # Remember the current train/eval mode so it can be restored afterwards.
        mode_model = self.trainer.model_evaluate.training
        self.trainer.model_evaluate.eval()
        with tqdm(self.dataloader, desc="eval") as pbar:
            for batch in pbar:
                res = self._eval_func(self.trainer.model_evaluate, batch)
                for k, v in res.items():
                    tot_res.setdefault(k, []).extend(v)
        self.trainer.model_evaluate.train(mode_model)
        if tot_res:
            rename_res = {self.prefix + k: np.mean(v) for k, v in tot_res.items()}
            self.log(self.trainer.epoch, **rename_res, smooth=False, window_size=1)

    def after_epoch(self):
        if self.every_n_epochs(self._period) or self.is_last_epoch():
            self._do_eval()
        # Runs every epoch so `save_last` can refresh latest.pth.
        self.save_model()


class EvalTotalHook(CheckpointerHook):
    """Periodic evaluator that accumulates first and evaluates once.

    Difference from `EvalHook`: instead of averaging per-batch results, every
    batch is fed to an ``eval_metric`` accumulator and the final metrics are
    computed in a single `evaluate` pass at the end.
    """

    def __init__(self,
                 dataloader: DataLoader,
                 eval_metric: object,
                 period: int = 1,
                 max_to_keep: Optional[int] = None,
                 save_metric: Optional[str] = None,
                 max_first: bool = True,
                 save_last: bool = True,
                 prefix: str = "eval"
                 ):
        """
        Parameters
        ----------
        dataloader : DataLoader
            Dataloader of the evaluation data.
        eval_metric : object
            Accumulator providing callable ``update(model, batch)`` (stores
            per-batch results), ``evaluate()`` (returns ``Dict[str, float]``)
            and ``reset()`` methods.
        period : int, default 1
            Run the evaluation every ``period`` epochs.
        max_to_keep : Optional[int]
            Number of checkpoints to keep; older ones are deleted.
        save_metric : Optional[str], default None
            Metric from ``trainer.metric_storage`` used to select the best
            checkpoint; None disables metric-based selection.
        max_first : bool, default True
            Whether larger metric values count as better.
        save_last : bool, default True
            Refresh ``latest.pth`` every epoch.
        prefix : str, default "eval"
            Prefix prepended to logged metric names.
        """
        self.prefix = prefix + "_"
        # BUGFIX: same None-safe prefixing as in EvalHook.
        metric_name = self.prefix + save_metric if save_metric is not None else None
        super(EvalTotalHook, self).__init__(period, max_to_keep, metric_name, max_first, save_last)
        # `callable(...)` is the idiomatic form of hasattr + isinstance checks;
        # `reset` is also required because `_do_eval` calls it.
        assert callable(getattr(eval_metric, "update", None))
        assert callable(getattr(eval_metric, "evaluate", None))
        assert callable(getattr(eval_metric, "reset", None))
        self._eval_metric = eval_metric
        self.dataloader = dataloader

    @torch.no_grad()
    def _do_eval(self):
        """Accumulate all batches, evaluate once, then reset the accumulator."""
        mode_model = self.trainer.model_evaluate.training
        self.trainer.model_evaluate.eval()
        with tqdm(self.dataloader, desc="eval") as pbar:
            for batch in pbar:
                self._eval_metric.update(self.trainer.model_evaluate, batch)
        self.trainer.model_evaluate.train(mode_model)
        tot_res = self._eval_metric.evaluate()
        self._eval_metric.reset()
        if tot_res:
            rename_res = {self.prefix + k: np.mean(v) for k, v in tot_res.items()}
            self.log(self.trainer.epoch, **rename_res, smooth=False, window_size=1)

    def after_epoch(self):
        if self.every_n_epochs(self._period) or self.is_last_epoch():
            self._do_eval()
        # Runs every epoch so `save_last` can refresh latest.pth.
        self.save_model()
15 | """写入评估指标到控制台和tensorboard""" 16 | 17 | def __init__(self, period: int = 50, tb_log_dir: Optional[str] = None, modes: Optional[list] = None, **kwargs) -> None: 18 | """ 19 | Parameters 20 | ---------- 21 | period : int, default 50. 写入的周期 22 | tb_log_dir : str, default None. 如果没有特别指定的话,日志默认写到trainer设置的目录 23 | modes : list, default None 24 | 通过指标的关键字来判定输出哪类数据 25 | kwargs : torch.utils.tensorboard.SummaryWriter的其他参数 26 | """ 27 | self._period = period 28 | self.kwargs = kwargs 29 | # metric name -> the latest iteration written to tensorboard file 30 | self._last_write: Dict[str, int] = {} 31 | 32 | if modes is None: 33 | modes = ["train", "eval"] 34 | if "train" not in modes: 35 | modes.insert(0, "train") 36 | self.modes = {m + "_" for m in modes} 37 | self.tb_log_dir = tb_log_dir 38 | 39 | def before_train(self) -> None: 40 | self._train_start_time = time.perf_counter() 41 | if self.tb_log_dir is None: 42 | self.tb_log_dir = self.trainer.work_dir 43 | self._tb_writer = SummaryWriter(self.tb_log_dir, **self.kwargs) 44 | 45 | def after_train(self) -> None: 46 | self._tb_writer.close() 47 | total_train_time = time.perf_counter() - self._train_start_time 48 | total_hook_time = total_train_time - self.metric_storage["iter_time"].global_sum 49 | logger.info( 50 | "Total train time: {} ({} on hooks)".format( 51 | str(datetime.timedelta(seconds=int(total_train_time))), 52 | str(datetime.timedelta(seconds=int(total_hook_time))), 53 | ) 54 | ) 55 | 56 | def after_epoch(self) -> None: 57 | self._write_console() 58 | self._write_tensorboard() 59 | 60 | def _write_console(self) -> None: 61 | # These fields ("data_time", "iter_time", "lr", "loss") may does not 62 | # exist when user overwrites `self.trainer.train_one_iter()` 63 | data_time = self.metric_storage["data_time"].avg if "data_time" in self.metric_storage else None 64 | iter_time = self.metric_storage["iter_time"].avg if "iter_time" in self.metric_storage else None 65 | lr = self.metric_storage["lr"].latest if "lr" 
in self.metric_storage else None 66 | 67 | if iter_time is not None: 68 | eta_seconds = iter_time * (self.trainer.max_iters - self.trainer.cur_iter - 1) 69 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 70 | else: 71 | eta_string = None 72 | 73 | if torch.cuda.is_available(): 74 | max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 75 | else: 76 | max_mem_mb = None 77 | 78 | exclude = ("data_time", "iter_time", "lr") 79 | keys_dict = {mode: set() for mode in self.modes} 80 | for key in self.metric_storage: 81 | if key in exclude: 82 | continue 83 | for mode in self.modes: 84 | if key.startswith(mode): 85 | keys_dict[mode].add(key) 86 | break 87 | else: 88 | keys_dict["train_"].add(key) 89 | 90 | process_string = f"Epoch: [{self.trainer.epoch}][{self.trainer.inner_iter}/{self.trainer.epoch_len - 1}]" 91 | 92 | space = " " * 2 93 | logger.info("----------") 94 | logger.info( 95 | "{process}{eta}{iter_time}{data_time}{lr}{memory}".format( 96 | process=process_string, 97 | eta=space + f"ETA: {eta_string}" if eta_string is not None else "", 98 | iter_time=space + f"iter_time: {iter_time:.4f}" if iter_time is not None else "", 99 | data_time=space + f"data_time: {data_time:.4f} " if data_time is not None else "", 100 | lr=space + f"lr: {lr:.5g}" if lr is not None else "", 101 | memory=space + f"max_mem: {max_mem_mb:.0f}M" if max_mem_mb is not None else "", 102 | ) 103 | ) 104 | 105 | for keys in keys_dict.values(): 106 | key_list = sorted(list(keys), key=lambda x: "total" not in x) 107 | info = " ".join([f"{k}: {self.metric_storage[k]}" for k in key_list]) 108 | if info == "": 109 | continue 110 | logger.info(info) 111 | 112 | def _write_tensorboard(self) -> None: 113 | for key, (iter, value) in self.metric_storage.values_maybe_smooth.items(): 114 | if key not in self._last_write or iter > self._last_write[key]: 115 | for mode in self.modes: 116 | if key.startswith(mode): 117 | key = f"{mode.strip('_')}/{key}" 118 | break 119 | else: 120 | 
key = f"train/{key}" 121 | self._tb_writer.add_scalar(key, value, iter) 122 | self._last_write[key] = iter 123 | 124 | def after_iter(self) -> None: 125 | if self.every_n_inner_iters(self._period): 126 | # self._write_console() 127 | self._write_tensorboard() 128 | -------------------------------------------------------------------------------- /torch_frame/utils/misc.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | from collections import defaultdict 6 | from typing import Any, Dict 7 | 8 | import datetime 9 | import numpy as np 10 | import torch 11 | from tabulate import tabulate 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | __all__ = [ 16 | "set_random_seed", 17 | "collect_env", 18 | "symlink", 19 | "create_small_table", 20 | "get_workspace" 21 | ] 22 | 23 | 24 | def collect_env() -> str: 25 | """Collect the information of the running environments. 26 | 27 | The following information are contained. 28 | 29 | - sys.platform: The variable of ``sys.platform``. 30 | - Python: Python version. 31 | - Numpy: Numpy version. 32 | - CUDA available: Bool, indicating if CUDA is available. 33 | - GPU devices: Device type of each GPU. 34 | - PyTorch: PyTorch version. 35 | - PyTorch compiling details: The output of ``torch.__config__.show()``. 36 | - TorchVision (optional): TorchVision version. 37 | - OpenCV (optional): OpenCV version. 38 | 39 | Returns: 40 | str: A string describing the running environment. 
41 | """ 42 | env_info = [] 43 | env_info.append(("sys.platform", sys.platform)) 44 | env_info.append(("Python", sys.version.replace("\n", ""))) 45 | env_info.append(("Numpy", np.__version__)) 46 | 47 | cuda_available = torch.cuda.is_available() 48 | env_info.append(("CUDA available", cuda_available)) 49 | 50 | if cuda_available: 51 | devices = defaultdict(list) 52 | for k in range(torch.cuda.device_count()): 53 | devices[torch.cuda.get_device_name(k)].append(str(k)) 54 | for name, device_ids in devices.items(): 55 | env_info.append(("GPU " + ",".join(device_ids), name)) 56 | 57 | env_info.append(("PyTorch", torch.__version__)) 58 | 59 | try: 60 | import torchvision 61 | 62 | env_info.append(("TorchVision", torchvision.__version__)) 63 | except ModuleNotFoundError: 64 | pass 65 | 66 | try: 67 | import cv2 68 | 69 | env_info.append(("OpenCV", cv2.__version__)) 70 | except ModuleNotFoundError: 71 | pass 72 | 73 | torch_config = torch.__config__.show() 74 | env_str = tabulate(env_info) + "\n" + torch_config 75 | return env_str 76 | 77 | 78 | def set_random_seed(seed: int, rank: int = 0) -> None: 79 | """Set random seed. 80 | 81 | Args: 82 | seed (int): Nonnegative integer. 83 | rank (int): Process rank in the distributed training. Defaults to 0. 84 | """ 85 | assert seed >= 0, f"Got invalid seed value {seed}." 86 | seed += rank 87 | torch.manual_seed(seed) 88 | torch.cuda.manual_seed(seed) 89 | torch.cuda.manual_seed_all(seed) 90 | np.random.seed(seed) # Numpy module. 91 | random.seed(seed) # Python random module. 92 | torch.backends.cudnn.benchmark = False # 保证每次卷积的算子都是固定的,而非使用最高效的方法 93 | torch.backends.cudnn.deterministic = True 94 | 95 | os.environ["PYTHONHASHSEED"] = str(seed) 96 | 97 | 98 | def symlink(src: str, dst: str, overwrite: bool = True, **kwargs) -> None: 99 | """Create a symlink, dst -> src. 100 | 101 | Args: 102 | src (str): Path to source. 103 | dst (str): Path to target. 104 | overwrite (bool): If True, remove existed target. Defaults to True. 
105 | """ 106 | if os.path.lexists(dst) and overwrite: 107 | os.remove(dst) 108 | os.symlink(src, dst, **kwargs) 109 | 110 | 111 | def create_small_table(small_dict: Dict[str, Any]) -> str: 112 | """Create a small table using the keys of ``small_dict`` as headers. 113 | This is only suitable for small dictionaries. 114 | 115 | Args: 116 | small_dict (dict): A result dictionary of only a few items. 117 | 118 | Returns: 119 | str: The table as a string. 120 | """ 121 | keys, values = tuple(zip(*small_dict.items())) 122 | table = tabulate( 123 | [values], 124 | headers=keys, 125 | tablefmt="pipe", 126 | floatfmt=".3f", 127 | stralign="center", 128 | numalign="center", 129 | ) 130 | return table 131 | 132 | 133 | def get_workspace(work_dir, create_new_dir): 134 | if create_new_dir not in (None, "time_s", "time_m", "time_h", "time_d", "count"): 135 | logger.warning("create_new_dir参数输入错误, 使用`time_s`为其赋值") 136 | create_new_dir = "time_s" 137 | if os.path.exists(work_dir): 138 | if create_new_dir == "time_s": 139 | now = datetime.datetime.now() 140 | now_format = now.strftime("%Y-%m-%d %H_%M_%S") 141 | work_dir = f"{work_dir}_{now_format}" 142 | elif create_new_dir == "time_m": 143 | now = datetime.datetime.now() 144 | now_format = now.strftime("%Y-%m-%d %H_%M") 145 | work_dir = f"{work_dir}_{now_format}" 146 | elif create_new_dir == "time_h": 147 | now = datetime.datetime.now() 148 | now_format = now.strftime("%Y-%m-%d %H") 149 | work_dir = f"{work_dir}_{now_format}" 150 | elif create_new_dir == "time_d": 151 | now = datetime.datetime.now() 152 | now_format = now.strftime("%Y-%m-%d") 153 | work_dir = f"{work_dir}_{now_format}" 154 | elif create_new_dir == "count": 155 | for i in range(10000): 156 | if not os.path.exists(f"{work_dir}_{i}"): 157 | break 158 | work_dir = f"{work_dir}_{i}" 159 | return work_dir 160 | -------------------------------------------------------------------------------- /torch_frame/utils/metric.py: 
"""
Standard evaluation metrics. Unlike ``history_buffer`` there is no smoothing
window: metrics are computed exactly over all collected data.
"""
from abc import abstractmethod
from functools import partial
from typing import Dict, Optional, Callable

import numpy as np


class BaseMetric:
    """Metric base class; subclasses implement `update`, `evaluate`, `reset`."""

    @abstractmethod
    def update(self, *args, **kwargs):
        ...

    @abstractmethod
    def evaluate(self):
        ...

    @abstractmethod
    def reset(self):
        ...


class ModelMetric(BaseMetric):
    """
    Model metric that splits `update` into model inference and result
    collection. Subclasses implement `inference`, `collection`, `evaluate`
    and `reset`.
    """

    def update(self, model, inputs):
        result = self.inference(model, inputs)
        self.collection(result)

    @abstractmethod
    def inference(self, model, inputs):
        ...

    @abstractmethod
    def collection(self, item):
        ...

    @abstractmethod
    def evaluate(self):
        ...

    @abstractmethod
    def reset(self):
        ...


class ObjDetMAPMetric(ModelMetric):
    """AP/mAP evaluator for object detection."""

    def __init__(self,
                 class2idx: Dict[str, int],
                 is_xyxy: bool = False,
                 ovthresh: float = 0.5,
                 postprocess: Optional[Callable] = None,
                 confthre: float = 0.7,
                 nmsthre: float = 0.45,
                 ):
        """
        Parameters
        ----------
        class2idx : Dict[str, int]
            Class-name -> index mapping, excluding the background class.
        is_xyxy : bool, default False
            Ground-truth box format:
            * False: gt boxes are [cx, cy, w, h]
            * True: gt boxes are [x1, y1, x2, y2]
        ovthresh : float, default 0.5
            IoU threshold above which a prediction counts as a true positive.
        postprocess : Optional[Callable], default None
            Post-processing applied to raw model outputs.
            * None: use the built-in Yolo-style `det_postprocess`
            * Callable: must return a list/np.ndarray whose rows are
              [x1, y1, x2, y2, score, cls]
        confthre : float, default 0.7
            Confidence threshold for the default `det_postprocess`;
            ignored when `postprocess` is given.
        nmsthre : float, default 0.45
            NMS threshold for the default `det_postprocess`;
            ignored when `postprocess` is given.
        """
        self.class2idx = class2idx
        self.is_xyxy = is_xyxy
        self.ovthresh = ovthresh
        # BUGFIX: this previously built `{cls: idx ...}` — an exact copy of
        # class2idx — instead of the inverse index -> class-name mapping.
        self.idx2class = {idx: cls for cls, idx in class2idx.items()}
        if postprocess is not None:
            self.postprocess = postprocess
        else:
            # Imported lazily so this module stays importable when the
            # vision subpackage (and its cv2 dependency chain) is absent.
            from ..vision import det_postprocess
            self.postprocess = partial(det_postprocess,
                                       num_classes=len(self.class2idx),
                                       conf_thre=confthre,
                                       nms_thre=nmsthre,
                                       class_agnostic=True,
                                       merge_score=True,
                                       ret_np=True
                                       )
        # Per-class lists of per-image box dicts, filled by `collection`.
        self.all_gt_boxes = {cls: [] for cls in self.class2idx}
        self.all_pred_boxes = {cls: [] for cls in self.class2idx}

    def inference(self, model, inputs):
        """Run the model on a (image, gt) batch; returns (outputs, gt)."""
        image, gt = inputs
        outputs = model(image)
        return outputs, gt

    def collection(self, item):
        """Post-process predictions and store pred/gt boxes per class."""
        outputs, gt = item
        outputs = self.postprocess(outputs)
        for i, one_image_boxes in enumerate(outputs):
            if one_image_boxes is None:
                # No detections at all for this image: empty entry per class.
                for boxes in self.all_pred_boxes.values():
                    boxes.append({})
                continue
            for cls_name, cls_idx in self.class2idx.items():
                mask = one_image_boxes[..., -1] == cls_idx
                if not np.any(mask):
                    self.all_pred_boxes[cls_name].append({})
                else:
                    image_cls_boxes = {
                        "confidence": one_image_boxes[mask][:, -2],
                        "bboxes": one_image_boxes[mask][:, :4],
                    }
                    self.all_pred_boxes[cls_name].append(image_cls_boxes)
        for i, one_image_boxes in enumerate(gt):
            # Drop all-zero padding rows, then convert to numpy.
            one_image_boxes = one_image_boxes[one_image_boxes.sum(-1) > 0].numpy()
            if not self.is_xyxy:
                # Convert [cls, cx, cy, w, h] -> [cls, x1, y1, x2, y2].
                cxcywh = one_image_boxes[:, 1:]
                one_image_boxes = np.stack([
                    one_image_boxes[:, 0],
                    cxcywh[:, 0] - cxcywh[:, 2] / 2,
                    cxcywh[:, 1] - cxcywh[:, 3] / 2,
                    cxcywh[:, 0] + cxcywh[:, 2] / 2,
                    cxcywh[:, 1] + cxcywh[:, 3] / 2,
                ], -1)
            for cls_name, cls_idx in self.class2idx.items():
                mask = one_image_boxes[..., 0] == cls_idx
                if not np.any(mask):
                    self.all_gt_boxes[cls_name].append({})
                else:
                    image_cls_boxes = {
                        "labels": one_image_boxes[mask][:, 0],
                        "bboxes": one_image_boxes[mask][:, 1:],
                    }
                    self.all_gt_boxes[cls_name].append(image_cls_boxes)

    def evaluate(self):
        """Compute per-class AP plus the mean AP over all classes."""
        from ..vision import eval_map  # lazy: see __init__
        aps = eval_map(self.all_gt_boxes, self.all_pred_boxes, self.ovthresh)
        new_aps = {f"ap_{k}": v for k, v in aps.items()}
        new_aps["map"] = np.mean(list(new_aps.values()))
        return new_aps

    def reset(self):
        """Discard all collected pred/gt boxes."""
        self.all_gt_boxes = {cls: [] for cls in self.class2idx}
        self.all_pred_boxes = {cls: [] for cls in self.class2idx}
梯度累计数,显存不够用又需要大batch的时候可以加大数值 30 | hook_only_main_gpu: bool, default True. 多卡时是否只有主进程使用hooks(非主进程是空list) 31 | accelerator: Accelerator, default None. 如果需要可以自定义accelerator然后传进Trainer,比如使用deepspeed-ZeRo3训练 32 | """ 33 | 34 | def __init__( 35 | self, 36 | model: nn.Module, 37 | optimizer: optim.Optimizer, 38 | lr_scheduler: Union[str, optim.lr_scheduler._LRScheduler], 39 | data_loader: DataLoader, 40 | max_epochs: int, 41 | work_dir: str = "work_dir", 42 | clip_grad_norm: float = 0.0, 43 | mixed_precision: str = None, 44 | warmup_method: Optional[str] = None, 45 | warmup_iters: int = 0, 46 | warmup_factor: float = 0.001, 47 | hooks: Optional[List[HookBase]] = None, 48 | use_ema: bool = False, 49 | ema_decay: float = 0.9999, 50 | gradient_accumulation_steps: int = 1, 51 | hook_only_main_gpu: bool = True, 52 | create_new_dir: Optional[str] = "time_s", 53 | accelerator: Accelerator = None, 54 | ): 55 | self.work_dir = get_workspace(work_dir, create_new_dir) 56 | if accelerator is None: 57 | accelerator_project_config = ProjectConfiguration(project_dir=self.work_dir, logging_dir=self.work_dir) 58 | self.accelerator = Accelerator( 59 | mixed_precision=mixed_precision, 60 | project_config=accelerator_project_config, 61 | gradient_accumulation_steps=gradient_accumulation_steps 62 | ) 63 | elif isinstance(accelerator, Accelerator): 64 | self.accelerator = accelerator 65 | else: 66 | raise ValueError("accelerator传入类型不对!") 67 | 68 | self.weight_dtype = torch.float32 69 | if self.accelerator.mixed_precision == "fp16": 70 | self.weight_dtype = torch.float16 71 | elif self.accelerator.mixed_precision == "bf16": 72 | self.weight_dtype = torch.bfloat16 73 | 74 | self.model = model 75 | if use_ema: 76 | warnings.warn("accelerate版暂不支持ema, 敬请期待") 77 | self.model_ema = EMA(self.model_or_module, ema_decay) 78 | else: 79 | self.model_ema = None 80 | self.optimizer = optimizer 81 | # convert epoch-based scheduler to iteration-based scheduler 82 | if isinstance(lr_scheduler, str): 83 | from 
diffusers.optimization import get_scheduler 84 | if warmup_method is not None: 85 | self.logger_print("当`lr_scheduler`输入str类型时, `warmup_method`参数失效", warnings.warn) 86 | max_train_steps = max_epochs * len(data_loader) 87 | self.lr_scheduler = get_scheduler( 88 | lr_scheduler, 89 | optimizer=optimizer, 90 | num_warmup_steps=warmup_iters * self.accelerator.num_processes, 91 | num_training_steps=max_train_steps * self.accelerator.num_processes, 92 | ) 93 | else: 94 | self.lr_scheduler = LRWarmupScheduler( 95 | lr_scheduler, len(data_loader), warmup_method, warmup_iters, warmup_factor 96 | ) 97 | self.data_loader = data_loader 98 | self.metric_storage = MetricStorage() 99 | 100 | # counters 101 | self.inner_iter: int = -1 # [0, epoch_len - 1] 102 | self.epoch: int = -1 # [0, max_epochs - 1] 103 | self.start_epoch = 0 # [0, max_epochs - 1] 104 | self.max_epochs = max_epochs 105 | 106 | self._hooks: List[HookBase] = [] 107 | self._clip_grad_norm = clip_grad_norm 108 | 109 | if not hook_only_main_gpu or self.accelerator.is_main_process: 110 | if hooks is None: 111 | hooks = self._build_default_hooks() 112 | self.register_hooks(hooks) 113 | 114 | self.info_params = [] 115 | 116 | def prepare_model(self): 117 | """如果有多个模型的话可以在这里重写方法""" 118 | self.model, self.optimizer, self.data_loader, self.lr_scheduler = self.accelerator.prepare( 119 | self.model, self.optimizer, self.data_loader, self.lr_scheduler 120 | ) 121 | 122 | def _prepare_for_training(self, 123 | console_log_level: int = 2, 124 | file_log_level: int = 2) -> None: 125 | """ 126 | 训练前的配置工作 127 | Parameters 128 | ---------- 129 | console_log_level : int, default 2 130 | 输出到屏幕的log等级, 可选范围是0-5, 它们对应的关系分别为: 131 | * 5: FATAL 132 | * 4: ERROR 133 | * 3: WARNING 134 | * 2: INFO 135 | * 1: DEBUG 136 | * 0: NOTSET 137 | file_log_level : int, default 2 138 | 输出到文件里的log等级, 其他方面同console_log_level参数 139 | """ 140 | self.prepare_model() # 此处和原版trainer不同!!! 
141 | super()._prepare_for_training(console_log_level, file_log_level) 142 | 143 | def _train_one_epoch(self) -> None: 144 | """执行模型一个epoch的全部操作""" 145 | self.accelerator.wait_for_everyone() 146 | self.model.train() 147 | self.pbar = ProgressBar(total=self.epoch_len, desc=f"epoch={self.epoch}", ascii=True) 148 | 149 | start_time_data = time.perf_counter() 150 | for self.inner_iter, batch in enumerate(self.data_loader): 151 | start_time_iter = time.perf_counter() 152 | data_time = start_time_iter - start_time_data 153 | self._call_hooks("before_iter") 154 | show_info = self.train_one_iter(batch) 155 | self._call_hooks("after_iter") 156 | self._update_iter_metrics(show_info, data_time, time.perf_counter() - start_time_data, self.lr) 157 | self.pbar.update(1) 158 | start_time_data = time.perf_counter() 159 | self.pbar.close() 160 | del self.pbar 161 | 162 | def train_one_iter(self, batch) -> dict: 163 | """ 164 | 包含了accelerate版训练的一个iter的全部操作 165 | 166 | .. Note:: 167 | 标准的学习率调节器是基于epoch的, 但torch_frame框架是基于iter的, 所以它在每次iter之后都会调用 168 | """ 169 | # 这里貌似只支持一个模型, 如果类似SD的训练任务可能需要把text_encoder和unet都包在一个模型里 170 | with self.accelerator.accumulate(self.model): 171 | ##################### 172 | # 1. 
计算loss # 173 | ##################### 174 | loss_info = self.model(batch) 175 | if isinstance(loss_info, torch.Tensor): 176 | losses = loss_info 177 | loss_info = {"total_loss": loss_info} 178 | elif isinstance(loss_info, tuple): 179 | assert len(loss_info) == 2, "loss_info需要是一个二元组,第一个是需要反向传播的,第二个是其他参考指标" 180 | backward_params, metric_params = loss_info 181 | assert isinstance(metric_params, dict), "loss_info的第二个值需要是dict类型" 182 | if isinstance(backward_params, torch.Tensor): 183 | losses = backward_params 184 | metric_params["total_loss"] = backward_params 185 | loss_info = metric_params 186 | elif isinstance(backward_params, dict): 187 | losses = sum(backward_params.values()) 188 | backward_params["total_loss"] = losses 189 | backward_params.update(metric_params) 190 | loss_info = backward_params 191 | else: 192 | assert "total_loss" not in loss_info, "当model返回是一个dict的时候不可以传出包含" 193 | losses = sum(loss_info.values()) 194 | loss_info["total_loss"] = losses 195 | 196 | ########################## 197 | # 2. 计算梯度 # 198 | ########################## 199 | self.optimizer.zero_grad() 200 | self.accelerator.backward(losses) 201 | if self._clip_grad_norm > 0 and self.accelerator.sync_gradients: 202 | self.accelerator.clip_grad_norm_(self.clip_grad_params, self._clip_grad_norm) 203 | 204 | ############################## 205 | # 3. 更新模型参数 # 206 | ############################## 207 | self.optimizer.step() 208 | 209 | ########################### 210 | # 4. 
def random_perspective(img: np.ndarray, targets: Union[np.ndarray, list, tuple], angle: float = 10, translate: float = 0.1,
                       scale: tuple = (0.7, 1.5), shear: float = 2.0, border: tuple = (0, 0),
                       fill_value: tuple = (128, 128, 128)):
    """
    Apply a random affine/perspective transform to ``img`` and remap ``targets`` accordingly.

    Parameters
    ----------
    img : np.ndarray. Input image (H, W, 3).
    targets : np.ndarray, list or tuple. Boxes as [[x1, y1, x2, y2], ...].
    angle : float or int. Maximum rotation in degrees (both signs).
    translate : float, default 0.1. Maximum translation as a fraction of width/height.
    scale : tuple, default (0.7, 1.5). (min, max) scaling range.
    shear : float, default 2.0. Shear range in degrees; 0 disables shearing.
    border : tuple, default (0, 0). Outward padding, (height, width).
    fill_value : tuple, default (128, 128, 128). Padding / warp border color.

    Returns
    -------
    img : np.ndarray. Transformed image.
    targets : np.ndarray. Transformed boxes, filtered by ``box_candidates``.
    """
    if isinstance(targets, (list, tuple)):  # BUG FIX: original tested `tuple` twice, so lists were never converted
        targets = np.asarray(targets, dtype=np.float64)  # `np.float` was removed in NumPy 1.24

    height = img.shape[0] + border[0] * 2  # shape(h,w,c)
    width = img.shape[1] + border[1] * 2

    # BUG FIX: with the default border=(0, 0) the original slice [0:-0] is empty,
    # so the padded canvas stayed gray and the input image was discarded.
    if border[0] > 0 or border[1] > 0:
        padding_image = np.full((height, width, 3), fill_value, dtype=np.uint8)
        padding_image[border[0]:border[0] + img.shape[0], border[1]:border[1] + img.shape[1]] = img
        img = padding_image

    # Center
    C = np.eye(3)
    C[0, 2] = -img.shape[1] / 2  # x translation (pixels)
    C[1, 2] = -img.shape[0] / 2  # y translation (pixels)

    # Rotation and Scale
    R = np.eye(3)
    a = random.uniform(-angle, angle)
    s = random.uniform(scale[0], scale[1])
    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    S = np.eye(3)
    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)
    # NOTE(review): a near-zero shear range selects warpPerspective below, which
    # looks inverted — confirm the intended meaning of this flag.
    perspective = abs(shear) < 1e-3

    # Translation
    T = np.eye(3)
    T[0, 2] = (
        random.uniform(0.5 - translate, 0.5 + translate) * width
    )  # x translation (pixels)
    T[1, 2] = (
        random.uniform(0.5 - translate, 0.5 + translate) * height
    )  # y translation (pixels)

    # Combined rotation matrix; order of operations (right to left) is IMPORTANT
    M = T @ S @ R @ C

    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
        if perspective:
            img = cv2.warpPerspective(img, M, dsize=(width, height), borderValue=fill_value)
        else:  # affine
            img = cv2.warpAffine(img, M[:2], dsize=(width, height), borderValue=fill_value)

    # Transform label coordinates
    n = len(targets)
    if n:
        # warp the four corner points of every box
        xy = np.ones((n * 4, 3))
        xy[:, :2] = targets[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
            n * 4, 2
        )  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ M.T  # transform
        if perspective:
            xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)  # rescale
        else:  # affine
            xy = xy[:, :2].reshape(n, 8)

        # rebuild axis-aligned boxes from the warped corners
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # clip boxes to the output canvas
        xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
        xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)

        # filter candidates.
        # BUG FIX: this module's box_candidates indexes boxes as (n, 4) via [:, k]
        # (see its use in mixup_boxes); the original passed transposed (4, n) arrays.
        i = box_candidates(box1=targets[:, :4] * s, box2=xy)
        targets = targets[i]
        targets[:, :4] = xy[i]

    return img, targets
def mosaic(data: list, ouput_dim: Union[list, tuple], fill_value=114):
    """
    Mosaic augmentation: stitch 4 (image, bboxes) pairs into one large canvas.
    Particularly effective for small-object detection.

    Parameters
    ----------
    data : List[tuple]. Exactly 4 ``(image, bboxes)`` tuples; each ``bboxes`` row is
        ``[label, x1, y1, x2, y2]``.
    ouput_dim : Union[list, tuple]. Output size as ``[height, width]``.
    fill_value : Union[int, tuple]. Color used for uncovered canvas area.

    Returns
    -------
    mosaic_img : np.ndarray. Stitched image of shape (2 * height, 2 * width, 3).
    mosaic_bboxes : np.ndarray. Stitched boxes, shape (N, 5).
    """

    def get_mosaic_coordinate(mosaic_index, xc, yc, w, h):
        # Returns (dest rect on the mosaic canvas, src rect inside the resized image).
        # index0 to top left part of image
        if mosaic_index == 0:
            x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
            small_coord = w - (x2 - x1), h - (y2 - y1), w, h
        # index1 to top right part of image
        elif mosaic_index == 1:
            x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, ouput_w * 2), yc
            small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
        # index2 to bottom left part of image
        elif mosaic_index == 2:
            x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(ouput_h * 2, yc + h)
            small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
        # index3 to bottom right part of image
        elif mosaic_index == 3:
            x1, y1, x2, y2 = xc, yc, min(xc + w, ouput_w * 2), min(ouput_h * 2, yc + h)  # noqa
            small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
        return (x1, y1, x2, y2), small_coord

    assert len(data) == 4, "`data`长度应该是4"
    # BUG FIX: the original message said `data` although `ouput_dim` is checked.
    assert len(ouput_dim) == 2, "`ouput_dim`长度应该是2"

    mosaic_bboxes = []
    ouput_h, ouput_w = ouput_dim[0], ouput_dim[1]

    # random mosaic center x, y
    yc = int(random.uniform(0.5 * ouput_h, 1.5 * ouput_h))
    xc = int(random.uniform(0.5 * ouput_w, 1.5 * ouput_w))
    mosaic_img = np.full((ouput_h * 2, ouput_w * 2, 3), fill_value, dtype=np.uint8)

    for i_mosaic, (img, bboxes) in enumerate(data):
        h0, w0 = img.shape[:2]  # orig hw
        scale = min(1. * ouput_h / h0, 1. * ouput_w / w0)
        img = cv2.resize(
            img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
        )
        h, w = img.shape[:2]

        # suffix l means large image, while s means small image in mosaic aug.
        (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(i_mosaic, xc, yc, w, h)

        mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
        padw, padh = l_x1 - s_x1, l_y1 - s_y1

        # Shift box coordinates into mosaic-canvas space (column 0 is the label).
        if bboxes.size > 0:
            scale_bboxes = bboxes.copy()
            scale_bboxes[:, 1] = scale * bboxes[:, 1] + padw
            scale_bboxes[:, 2] = scale * bboxes[:, 2] + padh
            scale_bboxes[:, 3] = scale * bboxes[:, 3] + padw
            scale_bboxes[:, 4] = scale * bboxes[:, 4] + padh
        else:
            scale_bboxes = np.zeros((0, 5))
        mosaic_bboxes.append(scale_bboxes)

    if len(mosaic_bboxes) > 0:
        mosaic_bboxes = np.concatenate(mosaic_bboxes, 0)
        # BUG FIX: rows are [label, x1, y1, x2, y2]; the original clipped columns
        # 0-3, clobbering the class label and never clipping y2.
        np.clip(mosaic_bboxes[:, 1], 0, 2 * ouput_w, out=mosaic_bboxes[:, 1])
        np.clip(mosaic_bboxes[:, 2], 0, 2 * ouput_h, out=mosaic_bboxes[:, 2])
        np.clip(mosaic_bboxes[:, 3], 0, 2 * ouput_w, out=mosaic_bboxes[:, 3])
        np.clip(mosaic_bboxes[:, 4], 0, 2 * ouput_h, out=mosaic_bboxes[:, 4])

    return mosaic_img, mosaic_bboxes
if min_scale is None: 36 | min_scale = 0.0 37 | if max_scale is None: 38 | max_scale = float('inf') 39 | w = boxes[:, 2] - boxes[:, 0] 40 | h = boxes[:, 3] - boxes[:, 1] 41 | keep = (w * h > min_scale * min_scale) & (w * h < max_scale * max_scale) 42 | return boxes[keep] 43 | 44 | 45 | def det_postprocess(prediction: torch.Tensor, 46 | num_classes: int, 47 | conf_thre: float = 0.7, 48 | nms_thre: float = 0.45, 49 | class_agnostic: bool = False, 50 | merge_score: bool = False, 51 | ret_np: bool = False 52 | ) -> List[Optional[torch.Tensor]]: 53 | """ 54 | 目标检测的后处理, 给定候选目标框, 经过置信度和mns双重筛选, 得到最终目标框 55 | Parameters 56 | ---------- 57 | prediction : torch.Tensor 58 | 检测模型给出的候选目标框, 采取了YOLO中boxes的形式: shape=(batch_size, num_boxes, 4 + 1 + num_classes) 59 | boxes的格式是x1,y1,x2,y2,confidence,class_score1, class_score2,...,class_scoreN 60 | num_classes : int 61 | conf_thre : float, default 0.7 62 | 置信度阈值, 目标框的置信度分数=confidence*class_score, 过滤掉所有置信度分数小于conf_thre的boxes 63 | nms_thre : float, default 0.45 64 | mns的阈值 65 | class_agnostic: bool, default False 66 | merge_score: bool, default False. 如果是True则会把obj_score和cls_score合并,返回的output是 67 | [x1,y1, x2, y2, obj_score*cls_score, cls] 68 | ret_np: bool, default False. 
是否返回np.ndarray类型结果 69 | 70 | Returns 71 | ------- 72 | output : List[Optional[torch.Tensor]] 73 | 做完置信度和mns过滤的boxes 74 | """ 75 | box_corner = prediction.new(prediction.shape) 76 | box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 77 | box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 78 | box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 79 | box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 80 | prediction[:, :, :4] = box_corner[:, :, :4] 81 | 82 | output = [None for _ in range(len(prediction))] 83 | for i, image_pred in enumerate(prediction): 84 | 85 | # If none are remaining => process next image 86 | if not image_pred.size(0): 87 | continue 88 | # Get score and class with highest confidence 89 | class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True) 90 | 91 | conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze() 92 | # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) 93 | detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) 94 | detections = detections[conf_mask] 95 | if not detections.size(0): 96 | continue 97 | 98 | if class_agnostic: 99 | nms_out_index = torchvision.ops.nms( 100 | detections[:, :4], 101 | detections[:, 4] * detections[:, 5], 102 | nms_thre, 103 | ) 104 | else: 105 | nms_out_index = torchvision.ops.batched_nms( 106 | detections[:, :4], 107 | detections[:, 4] * detections[:, 5], 108 | detections[:, 6], 109 | nms_thre, 110 | ) 111 | 112 | detections = detections[nms_out_index] 113 | if merge_score: 114 | detections = torch.cat( 115 | [ 116 | detections[:, :4], 117 | torch.unsqueeze(detections[:, 4] * detections[:, 5], 1), 118 | torch.unsqueeze(detections[:, 6], 1) 119 | ], 120 | dim=1 121 | ) 122 | if output[i] is None: 123 | output[i] = detections 124 | else: 125 | output[i] = torch.cat((output[i], detections)) 126 | if ret_np: 127 | output[i] = output[i].cpu().numpy() 
def bboxes_iou_torch(
    bboxes_a: torch.Tensor,
    bboxes_b: torch.Tensor,
    xyxy: bool = True,
    mode: str = "iou",
    eps: float = 1e-6
):
    """
    Compute the pairwise overlap matrix between two sets of boxes.

    Parameters
    ----------
    bboxes_a : torch.Tensor, shape=(N, 4)
    bboxes_b : torch.Tensor, shape=(M, 4)
    xyxy : bool, default True
        True for [x1, y1, x2, y2] boxes, False for [cx, cy, w, h] boxes.
    mode : str, default `iou`. One of:
        * `iou` — intersection over union
        * `iof` — intersection over the smaller box's area
    eps : float, default 1e-6. Guards against division by zero.

    Returns
    -------
    torch.Tensor, shape=(N, M)
    """
    if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4:
        raise IndexError

    if xyxy:
        top_left = torch.max(bboxes_a[:, None, :2], bboxes_b[:, :2])
        bottom_right = torch.min(bboxes_a[:, None, 2:], bboxes_b[:, 2:])
        area_a = torch.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1)
        area_b = torch.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1)
    else:
        # convert center/size form to corners on the fly
        half_a = bboxes_a[:, None, 2:] / 2
        half_b = bboxes_b[:, 2:] / 2
        top_left = torch.max(bboxes_a[:, None, :2] - half_a, bboxes_b[:, :2] - half_b)
        bottom_right = torch.min(bboxes_a[:, None, :2] + half_a, bboxes_b[:, :2] + half_b)
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    # 1 where the boxes actually overlap on both axes, 0 otherwise
    overlap_mask = (top_left < bottom_right).type(top_left.type()).prod(dim=2)
    inter = torch.prod(bottom_right - top_left, 2) * overlap_mask

    if mode == 'iou':
        return inter / torch.clip(area_a[:, None] + area_b - inter, eps)
    elif mode == 'iof':
        return inter / torch.clip(torch.minimum(area_a[:, None], area_b), eps)
    else:
        raise ValueError("mode没有这种模式")
np.ndarray | torch.Tensor, shape=(batch_size, 4) 196 | bboxes_b : np.ndarray | torch.Tensor, shape=(batch_size, 4) 197 | xyxy : bool, default True 198 | 输入格式是[x1,y1,x2,y2]还是[cx,cy,w,h] 199 | mode : str, default `iou`. 支持在以下两种种类型中选择 200 | * `iou` 201 | * `iof` 202 | eps : float, default 1e-6. 为防止除以0引入的小数 203 | 204 | Returns 205 | ------- 206 | iou : np.ndarray | torch.Tensor, shape=(batch_size, ) 207 | """ 208 | if bboxes_a.shape[1] != 4 or bboxes_b.shape[1] != 4: 209 | raise IndexError 210 | 211 | if xyxy: 212 | tl = np.maximum(bboxes_a[:, None, :2], bboxes_b[:, :2]) 213 | br = np.minimum(bboxes_a[:, None, 2:], bboxes_b[:, 2:]) 214 | area_a = np.prod(bboxes_a[:, 2:] - bboxes_a[:, :2], 1) 215 | area_b = np.prod(bboxes_b[:, 2:] - bboxes_b[:, :2], 1) 216 | else: 217 | tl = np.maximum( 218 | (bboxes_a[:, None, :2] - bboxes_a[:, None, 2:] / 2), 219 | (bboxes_b[:, :2] - bboxes_b[:, 2:] / 2), 220 | ) 221 | br = np.minimum( 222 | (bboxes_a[:, None, :2] + bboxes_a[:, None, 2:] / 2), 223 | (bboxes_b[:, :2] + bboxes_b[:, 2:] / 2), 224 | ) 225 | 226 | area_a = np.prod(bboxes_a[:, 2:], 1) 227 | area_b = np.prod(bboxes_b[:, 2:], 1) 228 | en = (tl < br).astype(float).prod(axis=2) 229 | area_i = np.prod(br - tl, 2) * en # * ((tl < br).all()) 230 | 231 | if mode == 'iou': 232 | return area_i / np.maximum(area_a[:, None] + area_b - area_i, eps) 233 | elif mode == 'iof': 234 | return area_i / np.maximum(np.minimum(area_a[:, None], area_b), eps) 235 | else: 236 | raise ValueError("mode没有这种模式") 237 | 238 | 239 | def bboxes_iou( 240 | bboxes_a: np.ndarray | torch.Tensor, 241 | bboxes_b: np.ndarray | torch.Tensor, 242 | xyxy: bool = True, 243 | mode: str = "iou", 244 | eps: float = 1e-6 245 | ): 246 | if isinstance(bboxes_a, np.ndarray) and isinstance(bboxes_b, np.ndarray): 247 | return bboxes_iou_numpy(bboxes_a, bboxes_b, xyxy, mode, eps) 248 | elif isinstance(bboxes_a, torch.Tensor) and isinstance(bboxes_b, torch.Tensor): 249 | return bboxes_iou_torch(bboxes_a, bboxes_b, xyxy, mode, 
def xyxy2cxcywh(bboxes: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
    """
    Convert boxes in place from corner format to center format, shape=[None, 4].

    Parameters
    ----------
    bboxes : np.ndarray or torch.Tensor
        Box array of shape (None, 4) in [[x1, y1, x2, y2]] format; modified in place.

    Returns
    -------
    bboxes : np.ndarray or torch.Tensor
        The same array, now holding [[cx, cy, w, h]].
    """
    # compute sizes from the original corners before overwriting anything
    widths = bboxes[:, 2] - bboxes[:, 0]
    heights = bboxes[:, 3] - bboxes[:, 1]
    bboxes[:, 0] = bboxes[:, 0] + widths * 0.5
    bboxes[:, 1] = bboxes[:, 1] + heights * 0.5
    bboxes[:, 2] = widths
    bboxes[:, 3] = heights
    return bboxes
290 | """ 291 | # box1(4,n), box2(4,n) 292 | # Compute candidate boxes which include follwing 5 things: 293 | # box1 before augment, box2 after augment, wh_thr (pixels), aspect_ratio_thr, area_ratio 294 | w1, h1 = box1[:, 2] - box1[:, 0], box1[:, 3] - box1[:, 1] 295 | w2, h2 = box2[:, 2] - box2[:, 0], box2[:, 3] - box2[:, 1] 296 | ar = np.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio 297 | return ( 298 | (w2 > wh_thr) 299 | & (h2 > wh_thr) 300 | & (w2 * h2 / (w1 * h1 + 1e-16) > area_thr) 301 | & (ar < ar_thr) 302 | ) 303 | 304 | 305 | def cal_match(pred_box, gt_box): 306 | count = 0 307 | iou = bboxes_iou(pred_box[:, 1:], gt_box[:, 1:]) 308 | for i, pbox in enumerate(pred_box): 309 | for j, gbox in enumerate(gt_box.copy()): 310 | if iou[i][j] >= 0.9 and np.abs(pbox[1:] - gbox[1:]).max() < 5 and pbox[0] == gbox[0]: 311 | count += 1 312 | gt_box = np.concatenate([gt_box[:j], gt_box[j + 1:]]) 313 | iou = np.concatenate([iou[:, :j], iou[:, j + 1:]], 1) 314 | break 315 | return count / len(pred_box) 316 | 317 | 318 | def calc_ap(recall, precision, use_07_metric=True): 319 | # 是否使用 07 年的 11 点均值方式计算 ap 320 | if use_07_metric: 321 | ap = 0. 322 | for threshold in np.arange(0., 1.1, 0.1): 323 | if np.sum(recall >= threshold) == 0: 324 | p = 0 325 | else: 326 | p = np.max(precision[recall >= threshold]) 327 | ap = ap + p / 11. 
328 | else: 329 | # 增加哨兵,然后算出准确率包络 330 | mrec = np.concatenate(([0.], recall, [1.])) 331 | mpre = np.concatenate(([0.], precision, [0.])) 332 | for i in range(mpre.size - 1, 0, -1): 333 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 334 | 335 | # 计算 pr 曲线中,召回率变化的下标 336 | idx = np.where(mrec[1:] != mrec[:-1])[0] 337 | 338 | # 计算 pr 曲线与坐标轴所围区域的面积 339 | ap = np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]) 340 | 341 | return ap 342 | 343 | 344 | def clac_voc_ap(annotation, prediction, ovthresh=0.5): 345 | # 统计标注框个数,这里需要将困难样本数量去掉 346 | positive_num = 0 347 | for anno in annotation: 348 | positive_num += len(anno.get("bboxes", [])) 349 | anno['b_det'] = [False] * len(anno.get("bboxes", [])) 350 | 351 | # 将检测结果格式进行转换,主要是为了方便排序 352 | # for item in prediction: 353 | image_ids, confidences, bboxes = [], [], [] 354 | for img_id, item in enumerate(prediction): 355 | bbox = item.get("bboxes", []) 356 | score = item.get("confidence", []) 357 | for i in range(len(score)): 358 | image_ids.append(img_id) 359 | confidences.append(score[i]) 360 | bboxes.append(bbox[i]) 361 | image_ids, confidences, bboxes = np.array(image_ids), np.array(confidences), np.array(bboxes) 362 | 363 | # 按照置信度排序 364 | sorted_ind = np.argsort(-confidences) 365 | bboxes = bboxes[sorted_ind] 366 | image_ids = image_ids[sorted_ind] 367 | 368 | # 计算 TP 和 FP,以计算出 AP 值 369 | detect_num = len(image_ids) 370 | tp = np.zeros(detect_num) 371 | fp = np.zeros(detect_num) 372 | for d in range(detect_num): 373 | gt_bboxes = annotation[image_ids[d]].get("bboxes", []) 374 | 375 | # 如果没有 ground truth,那么所有的检测都是错误的 376 | if len(gt_bboxes) > 0: 377 | b_dets = annotation[image_ids[d]]["b_det"] 378 | p_bboxes = bboxes[d, :] 379 | 380 | overlaps = bboxes_iou(p_bboxes[np.newaxis], gt_bboxes, mode="iou")[0] 381 | idxmax = np.argmax(overlaps) 382 | ovmax = overlaps[idxmax] 383 | 384 | if ovmax > ovthresh: 385 | # gt 只允许检测出一次 386 | if not b_dets[idxmax]: 387 | tp[d] = 1. 388 | b_dets[idxmax] = True 389 | else: 390 | fp[d] = 1. 
391 | else: 392 | fp[d] = 1. 393 | else: 394 | fp[d] = 1. 395 | 396 | # 计算召回率和准确率 397 | tp = np.cumsum(tp) 398 | fp = np.cumsum(fp) 399 | recall = tp / float(positive_num) 400 | precision = tp / np.maximum(tp + fp, 1.) 401 | ap = calc_ap(recall, precision) 402 | 403 | return recall, precision, ap 404 | 405 | 406 | def eval_map(annotation_result, detector_result, ovthresh=0.5): 407 | aps = {} 408 | for cls in annotation_result.keys(): 409 | recall, precision, ap = clac_voc_ap(annotation_result[cls], detector_result[cls], ovthresh=ovthresh) 410 | aps[cls] = ap 411 | return aps 412 | -------------------------------------------------------------------------------- /torch_frame/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import os.path as osp 4 | import time 5 | import weakref 6 | from typing import Any, Dict, List, Optional, Tuple 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.cuda.amp import GradScaler, autocast 12 | from torch.nn.parallel import DataParallel, DistributedDataParallel 13 | from torch.nn.utils import clip_grad_norm_ 14 | from torch.utils.data import DataLoader 15 | from .hooks import CheckpointerHook, HookBase, LoggerHook 16 | from .utils import setup_logger, ProgressBar, HistoryBuffer, EMA 17 | from .utils.misc import get_workspace 18 | from .utils.dist_utils import get_rank 19 | from .lr_scheduler import LRWarmupScheduler 20 | from ._get_logger import logger 21 | 22 | 23 | class Trainer: 24 | """ 25 | 一个基于epoch的通用训练框架(目前只支持单gpu运行), 包含: 26 | 1. 计算从dataloader中计算loss 27 | 2. 计算梯度 28 | 3. 用optimizer更新参数 29 | 4. 调整学习率 30 | 如果想完成更复杂的功能, 也可以继承该类编写子类, 重写里面的`train_one_iter`等方法 31 | 以下是代码示例 32 | .. code-block:: python 33 | model = ... # 初始化你的模型 34 | optimizer = ... # 你的优化器 35 | lr_scheduler = ... # 初始化你的调节器 36 | data_loader = ... 
# 初始化你的数据生成器 37 | # 训练100轮 38 | trainer = Trainer(model, optimizer, lr_scheduler, data_loader, max_epochs=100) 39 | trainer.train() 40 | 41 | Parameters 42 | --------- 43 | model : torch.nn.Module, 训练模型, 训练时的输出只能是以下三种: 44 | * torch.Tensor, 对于这种输出是模型backward用的loss, 在torch-frame框架中被称为total_loss 45 | * dict, 里面是模型的各路分支的loss,需要是标量, Trainer会自动将其求和得到total_loss, 再做backward 46 | * Tuple[Union[dict, torch.Tensor], dict]. 前两种的混合体, 元组第一个输出是前面两种的任意一种; 47 | 第二个输出是非需要backward类的, Trainer不会把这个dict汇总到total_loss上 48 | optimizer : torch.optim.Optimizer, 优化器 49 | lr_scheduler : optim.lr_scheduler._LRScheduler, 学习率调节器 50 | data_loader : torch.utils.data.DataLoader, 数据生成器 51 | max_epochs : int, 训练的总轮数 52 | work_dir : str, 保存模型和日志的根目录地址 53 | clip_grad_norm : float, default 0.0 54 | 梯度裁剪的设置, 如果置为小于等于0, 则不作梯度裁剪 55 | enable_amp : bool, 使用混合精度 56 | warmup_method : str, default None 57 | warmup的类型, 包含以下四种取值 58 | * constant 59 | * linear 60 | * exp 61 | * None : 不使用warmup 62 | warmup_iters : int, default 1000, warmup最后的iter数 63 | warmup_factor : float, default 0.001 64 | warmup初始学习率 = warmup_factor * initial_lr 65 | hooks : List[HookBase], default None. 66 | hooks, 保存模型、输出评估指标、loss等用 67 | use_ema : bool, default False. 是否使用EMA技术 68 | ema_decay: float = 0.9999. 
EMA模型衰减系数 69 | create_new_dir : Optional[str], default time 70 | 存在同名目录时以何种策略创建目录 71 | * None, 直接使用同名目录 72 | * `time_s`, 如果已经存在同名目录, 则以时间(精确到秒)为后缀创建新目录 73 | * `time_m`, 如果已经存在同名目录, 则以时间(精确到分)为后缀创建新目录 74 | * `time_h`, 如果已经存在同名目录, 则以时间(精确到小时)为后缀创建新目录 75 | * `time_d`, 如果已经存在同名目录, 则以时间(精确到日)为后缀创建新目录 76 | * `count`, 如果已经存在同名目录, 则以序号为后缀创建新目录 77 | """ 78 | 79 | def __init__( 80 | self, 81 | model: nn.Module, 82 | optimizer: optim.Optimizer, 83 | lr_scheduler: optim.lr_scheduler._LRScheduler, 84 | data_loader: DataLoader, 85 | max_epochs: int, 86 | work_dir: str = "work_dir", 87 | clip_grad_norm: float = 0.0, 88 | enable_amp=False, 89 | warmup_method: Optional[str] = None, 90 | warmup_iters: int = 1000, 91 | warmup_factor: float = 0.001, 92 | hooks: Optional[List[HookBase]] = None, 93 | use_ema: bool = False, 94 | ema_decay: float = 0.9999, 95 | create_new_dir: Optional[str] = "time_s" 96 | ): 97 | logger.setLevel(logging.INFO) 98 | 99 | self.work_dir = get_workspace(work_dir, create_new_dir) 100 | 101 | self.model = model 102 | if use_ema: 103 | self.model_ema = EMA(self.model_or_module, ema_decay) 104 | else: 105 | self.model_ema = None 106 | self.optimizer = optimizer 107 | # convert epoch-based scheduler to iteration-based scheduler 108 | self.lr_scheduler = LRWarmupScheduler( 109 | lr_scheduler, len(data_loader), warmup_method, warmup_iters, warmup_factor 110 | ) 111 | self.data_loader = data_loader 112 | self.metric_storage = MetricStorage() 113 | 114 | # counters 115 | self.inner_iter: int = -1 # [0, epoch_len - 1] 116 | self.epoch: int = -1 # [0, max_epochs - 1] 117 | self.start_epoch = 0 # [0, max_epochs - 1] 118 | self.max_epochs = max_epochs 119 | 120 | self._hooks: List[HookBase] = [] 121 | self._clip_grad_norm = clip_grad_norm 122 | if not torch.cuda.is_available(): 123 | enable_amp = False 124 | self.logger_print("torch环境无法使用cuda, AMP无法使用", logger.warning) 125 | self._enable_amp = enable_amp 126 | 127 | if self._enable_amp: 128 | self.logger_print("自动混合精度 
(AMP) 训练") 129 | self._grad_scaler = GradScaler() 130 | 131 | self.rank = get_rank() 132 | if hooks is None: 133 | self.register_hooks(self._build_default_hooks()) 134 | elif self.rank != 0: 135 | self.logger_print(f"{self.rank}号卡也被传入了hook, 将被自动忽略", logger.warning) 136 | else: 137 | self.register_hooks(hooks) 138 | 139 | self.info_params = [] 140 | 141 | def log_param(self, *args, **kwargs): 142 | """打印信息到logger上""" 143 | inforamtions_tuple = "\n".join(args) 144 | inforamtions_dict = "\n".join([f"{k}: {v}" for k, v in kwargs.items()]) 145 | if len(inforamtions_tuple) > 0: 146 | self.info_params.append(inforamtions_tuple) 147 | if len(inforamtions_dict) > 0: 148 | self.info_params.append(inforamtions_dict) 149 | 150 | @property 151 | def lr(self) -> float: 152 | return self.optimizer.param_groups[0]["lr"] 153 | 154 | @property 155 | def epoch_len(self) -> int: 156 | return len(self.data_loader) 157 | 158 | @property 159 | def max_iters(self) -> int: 160 | return self.max_epochs * self.epoch_len 161 | 162 | @property 163 | def cur_iter(self) -> int: 164 | """返回当前iter数, 范围在 [0, max_iters - 1].""" 165 | return self.epoch * self.epoch_len + self.inner_iter 166 | 167 | @property 168 | def start_iter(self) -> int: 169 | """从哪一个iter开始. 
最小的值是0.""" 170 | return self.start_epoch * self.epoch_len 171 | 172 | @property 173 | def clip_grad_params(self) -> list: 174 | if not hasattr(self, "_clip_grad_params"): 175 | params = [] 176 | for p in self.optimizer.param_groups: 177 | params += p["params"] 178 | setattr(self, "_clip_grad_params", params) 179 | return getattr(self, "_clip_grad_params") 180 | 181 | @property 182 | def ckpt_dir(self) -> str: 183 | return osp.join(self.work_dir, "checkpoints") 184 | 185 | @property 186 | def tb_log_dir(self) -> str: 187 | return osp.join(self.work_dir, "tb_logs") 188 | 189 | @property 190 | def log_file(self) -> str: 191 | return osp.join(self.work_dir, "log.txt") 192 | 193 | @property 194 | def model_or_module(self) -> nn.Module: 195 | if isinstance(self.model, (DistributedDataParallel, DataParallel)): 196 | return self.model.module 197 | return self.model 198 | 199 | @property 200 | def model_evaluate(self) -> nn.Module: 201 | """评估用的模型""" 202 | if self.model_ema: 203 | return self.model_ema.model 204 | return self.model_or_module 205 | 206 | @property 207 | def registered_hook_names(self) -> List[str]: 208 | """注册的所有hook名字""" 209 | return [h.__class__.__name__ for h in self._hooks] 210 | 211 | def log(self, *args, **kwargs) -> None: 212 | """更新评估指标""" 213 | self.metric_storage.update(*args, **kwargs) 214 | 215 | def _prepare_for_training(self, 216 | console_log_level: int = 2, 217 | file_log_level: int = 2) -> None: 218 | """ 219 | 训练前的配置工作 220 | Parameters 221 | ---------- 222 | console_log_level : int, default 2 223 | 输出到屏幕的log等级, 可选范围是0-5, 它们对应的关系分别为: 224 | * 5: FATAL 225 | * 4: ERROR 226 | * 3: WARNING 227 | * 2: INFO 228 | * 1: DEBUG 229 | * 0: NOTSET 230 | file_log_level : int, default 2 231 | 输出到文件里的log等级, 其他方面同console_log_level参数 232 | """ 233 | # setup the root logger of the `cpu` library to show 234 | # the log messages generated from this library 235 | assert console_log_level in (0, 1, 2, 3, 4, 5), f"console_log_level必须在0~5之间而不是{console_log_level}" 
236 | assert file_log_level in (0, 1, 2, 3, 4, 5), f"file_log_level必须在0~5之间而不是{file_log_level}" 237 | console_log_level *= 10 238 | file_log_level *= 10 239 | setup_logger("torch_frame", output=self.log_file, 240 | console_log_level=console_log_level, file_log_level=file_log_level) 241 | 242 | if self.start_epoch == 0: 243 | if self.check_main(): 244 | os.makedirs(self.ckpt_dir, exist_ok=True) 245 | split_line = "-" * 50 246 | self.logger_print( 247 | f"\n{split_line}\n" 248 | f"Work directory: {self.work_dir}\n" 249 | f"{split_line}" 250 | ) 251 | if len(self.info_params) > 0: 252 | self.logger_print("\n"+"\n".join(self.info_params)) 253 | 254 | def register_hooks(self, hooks: List[Optional[HookBase]]) -> None: 255 | """ 256 | Trainer运行时调用hook 257 | hook执行时根据它们注册的顺序来进行 258 | 259 | Parameters 260 | --------- 261 | hooks : list[HookBase] 262 | """ 263 | hooks = [h for h in hooks if h is not None] 264 | for h in hooks: 265 | assert isinstance(h, HookBase) 266 | # To avoid circular reference, hooks and trainer cannot own each other. This normally 267 | # does not matter, but will cause memory leak if the involved objects contain __del__. 268 | # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ 269 | h.trainer = weakref.proxy(self) 270 | # We always keep :class:`LoggerHook` as the last hook to avoid losing any records 271 | # that should have been logged. The order of other hooks remains the same. 
272 | if self._hooks and isinstance(self._hooks[-1], LoggerHook): 273 | self._hooks.insert(len(self._hooks) - 1, h) 274 | else: 275 | self._hooks.append(h) 276 | self.logger_print(f"Registered default hooks: {self.registered_hook_names}", logger.warning) 277 | 278 | def _call_hooks(self, stage: str) -> None: 279 | for h in self._hooks: 280 | getattr(h, stage)() 281 | 282 | def _build_default_hooks(self) -> List[HookBase]: 283 | return [ 284 | CheckpointerHook(), 285 | LoggerHook(tb_log_dir=self.tb_log_dir), 286 | ] 287 | 288 | def _update_iter_metrics(self, loss_dict: Dict[str, torch.Tensor], data_time: float, 289 | iter_time: float, lr: float) -> None: 290 | """ 291 | 每个iter评估的log 292 | Parameters 293 | ---------- 294 | loss_dict : dict, losses的标量字典 295 | data_time : float, dataloader的一个iter耗时 296 | iter_time : float, 一个iter全部耗时 297 | lr : float, 该iter的学习率 298 | """ 299 | self.log(self.cur_iter, data_time=data_time, iter_time=iter_time) 300 | self.log(self.cur_iter, lr=lr, smooth=False) 301 | 302 | loss_value = sum(loss_dict.values()) 303 | if not np.isfinite(loss_value): 304 | raise FloatingPointError( 305 | f"Loss became infinite or NaN at epoch={self.epoch}! loss_dict = {loss_dict}." 306 | ) 307 | 308 | self.log(self.cur_iter, **loss_dict) 309 | 310 | def train_one_iter(self, batch) -> dict: 311 | """ 312 | 包含了训练的一个iter的全部操作 313 | 314 | .. Note:: 315 | 标准的学习率调节器是基于epoch的, 但torch_frame框架是基于iter的, 所以它在每次iter之后都会调用 316 | """ 317 | ##################### 318 | # 1. 
计算loss # 319 | ##################### 320 | if self._enable_amp: 321 | with autocast(): 322 | loss_info = self.model(batch) 323 | else: 324 | loss_info = self.model(batch) 325 | if isinstance(loss_info, torch.Tensor): 326 | losses = loss_info 327 | loss_info = {"total_loss": loss_info} 328 | elif isinstance(loss_info, tuple): 329 | assert len(loss_info) == 2, "loss_info需要是一个二元组,第一个是需要反向传播的,第二个是其他参考指标" 330 | backward_params, metric_params = loss_info 331 | assert isinstance(metric_params, dict), "loss_info的第二个值需要是dict类型" 332 | if isinstance(backward_params, torch.Tensor): 333 | losses = backward_params 334 | metric_params["total_loss"] = backward_params 335 | loss_info = metric_params 336 | elif isinstance(backward_params, dict): 337 | losses = sum(backward_params.values()) 338 | backward_params["total_loss"] = losses 339 | backward_params.update(metric_params) 340 | loss_info = backward_params 341 | else: 342 | assert "total_loss" not in loss_info, "当model返回是一个dict的时候不可以传出包含" 343 | losses = sum(loss_info.values()) 344 | loss_info["total_loss"] = losses 345 | 346 | ########################## 347 | # 2. 计算梯度 # 348 | ########################## 349 | self.optimizer.zero_grad() 350 | if self._enable_amp: 351 | self._grad_scaler.scale(losses).backward() 352 | else: 353 | losses.backward() 354 | if self._clip_grad_norm > 0: 355 | if self._enable_amp: 356 | self._grad_scaler.unscale_(self.optimizer) 357 | clip_grad_norm_(self.clip_grad_params, self._clip_grad_norm) 358 | 359 | ############################## 360 | # 3. 更新模型参数 # 361 | ############################## 362 | if self._enable_amp: 363 | self._grad_scaler.step(self.optimizer) 364 | self._grad_scaler.update() 365 | else: 366 | self.optimizer.step() 367 | if self.model_ema: 368 | self.model_ema.update(self.model) 369 | 370 | ########################### 371 | # 4. 
调整学习率 # 372 | ########################### 373 | self.lr_scheduler.step() 374 | 375 | show_info = {k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else v for k, v in loss_info.items()} 376 | show_info = dict(sorted(show_info.items(), key=lambda x: x[0] != "total_loss")) # 保证total_loss在最后一位 377 | self.pbar.set_postfix(show_info) 378 | return show_info 379 | 380 | def _train_one_epoch(self) -> None: 381 | """执行模型一个epoch的全部操作""" 382 | self.model.train() 383 | self.pbar = ProgressBar(total=self.epoch_len, desc=f"epoch={self.epoch}", ascii=True) 384 | 385 | start_time_data = time.perf_counter() 386 | for self.inner_iter, batch in enumerate(self.data_loader): 387 | start_time_iter = time.perf_counter() 388 | data_time = start_time_iter - start_time_data 389 | self._call_hooks("before_iter") 390 | show_info = self.train_one_iter(batch) 391 | self._call_hooks("after_iter") 392 | self._update_iter_metrics(show_info, data_time, time.perf_counter() - start_time_data, self.lr) 393 | self.pbar.update(1) 394 | start_time_data = time.perf_counter() 395 | self.pbar.close() 396 | del self.pbar 397 | 398 | def train(self, 399 | console_log_level: int = 2, 400 | file_log_level: int = 2) -> None: 401 | """ 402 | 训练入口 403 | 404 | Parameters 405 | ---------- 406 | console_log_level : int, default 2. 407 | 输出到屏幕的log等级, 可选范围是0-5, 它们对应的关系分别为: 408 | * 5: FATAL 409 | * 4: ERROR 410 | * 3: WARNING 411 | * 2: INFO 412 | * 1: DEBUG 413 | * 0: NOTSET 414 | file_log_level : int, default 2. 
415 | 输出到文件里的log等级, 其他方面同console_log_level参数 416 | """ 417 | self.logger_print(f"从第{self.start_epoch + 1}个epoch训练开始") 418 | self._prepare_for_training(console_log_level, file_log_level) 419 | self._call_hooks("before_train") 420 | for self.epoch in range(self.start_epoch + 1, self.max_epochs + 1): 421 | self._call_hooks("before_epoch") 422 | self._train_one_epoch() 423 | self._call_hooks("after_epoch") 424 | self._call_hooks("after_train") 425 | 426 | def save_checkpoint(self, file_name: str, save_single_model: bool = True, 427 | print_info: bool = False) -> None: 428 | """ 429 | 保存参数, 包含: 430 | 431 | * epoch : 当前轮数 432 | * model : 当前模型参数 433 | * optimizer : 当前优化器 434 | * lr_scheduler : 当前调节器 435 | * metric_storage : 评估指标 436 | * hooks(非必须) : 一些中间量 437 | * grad_scaler(非必须) : 混合精度的参数 438 | 439 | Parameters 440 | ---------- 441 | file_name : str, 保存文件名 442 | save_single_model : bool, default True. 443 | * True, 会保存模型参数本身 444 | * False, 会保存包含模型、hook及优化器等权重 445 | print_info : bool, default True. 
如果是True, 则输出保存模型的提示信息 446 | """ 447 | 448 | data = { 449 | "epoch": self.epoch, 450 | "model": self.model_or_module.state_dict(), 451 | "optimizer": self.optimizer.state_dict(), 452 | "lr_scheduler": self.lr_scheduler.state_dict(), 453 | "metric_storage": self.metric_storage, 454 | "work_dir": self.work_dir, 455 | } 456 | hook_states = {h.class_name: h.state_dict() for h in self._hooks if h.checkpointable} 457 | if hook_states: 458 | data["hooks"] = hook_states 459 | if hasattr(self, "_enable_amp") and self._enable_amp: 460 | data["grad_scaler"] = self._grad_scaler.state_dict() 461 | 462 | file_path = osp.join(self.ckpt_dir, file_name) 463 | if print_info: 464 | self.logger_print(f"Saving checkpoint to {file_path}") 465 | if save_single_model: 466 | torch.save(self.model_evaluate.state_dict(), file_path) 467 | else: 468 | torch.save(data, file_path) 469 | 470 | def load_checkpoint(self, path: str = None, checkpoint: Dict[str, Any] = None): 471 | """ 472 | 加载参数 473 | 474 | Parameters 475 | ---------- 476 | path : str, default None. checkpoint的地址 477 | checkpoint : dict, default None. 478 | 如果path非空, 优先使用path的数据, 否则直接加载checkpoint的数据。 479 | 直接加载的时候,将只加载模型,而不带各种状态 480 | """ 481 | assert checkpoint is None or path is None 482 | if path is None: 483 | incompatible = self.model_or_module.load_state_dict(checkpoint, strict=False) 484 | if incompatible.missing_keys: 485 | self.logger_print("Encounter missing keys when loading model weights:\n" 486 | f"{incompatible.missing_keys}", 487 | logger.warning) 488 | if incompatible.unexpected_keys: 489 | self.logger_print("Encounter unexpected keys when loading model weights:\n" 490 | f"{incompatible.unexpected_keys}", 491 | logger.warning) 492 | self.logger_print("只加载模型本身...") 493 | return 494 | 495 | checkpoint = torch.load(path, map_location="cpu") 496 | 497 | # 1. 加载 epoch 498 | self.start_epoch = checkpoint["epoch"] + 1 499 | 500 | # 2. 加载 metric_storage 501 | self.metric_storage = checkpoint["metric_storage"] 502 | 503 | # 3. 
加载 optimizer 504 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 505 | 506 | # 4. 加载 lr_scheduler 507 | self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 508 | 509 | # 5. 加载 grad scaler 510 | consistent_amp = not (self._enable_amp ^ ("grad_scaler" in checkpoint)) 511 | assert consistent_amp, "Found inconsistent AMP training setting when loading checkpoint." 512 | if hasattr(self, "_enable_amp") and self._enable_amp: 513 | self._grad_scaler.load_state_dict(checkpoint["grad_scaler"]) 514 | 515 | # 6. 加载 模型 516 | incompatible = self.model_or_module.load_state_dict(checkpoint["model"], strict=False) 517 | if incompatible.missing_keys: 518 | self.logger_print("Encounter missing keys when loading model weights:\n" 519 | f"{incompatible.missing_keys}", 520 | logger.warning) 521 | if incompatible.unexpected_keys: 522 | self.logger_print("Encounter unexpected keys when loading model weights:\n" 523 | f"{incompatible.unexpected_keys}", 524 | logger.warning) 525 | 526 | # 7. 加载 hooks 527 | hook_states = checkpoint.get("hooks", {}) 528 | hook_names = [h.class_name for h in self._hooks if h.checkpointable] 529 | missing_keys = [name for name in hook_names if name not in hook_states] 530 | unexpected_keys = [key for key in hook_states if key not in hook_names] 531 | if missing_keys: 532 | self.logger_print(f"Encounter missing keys when loading hook state dict:\n{missing_keys}", 533 | logger.warning) 534 | if unexpected_keys: 535 | self.logger_print(f"Encounter unexpected keys when loading hook state dict:\n{unexpected_keys}", 536 | logger.warning) 537 | 538 | for key, value in hook_states.items(): 539 | for h in self._hooks: 540 | if h.class_name == key and h.checkpointable: 541 | h.load_state_dict(value) 542 | break 543 | 544 | # 8. 加载保存目录 545 | self.work_dir = checkpoint["work_dir"] 546 | 547 | # 9. 
加载ema 模型 548 | if self.model_ema: 549 | lambda_decay = self.model_ema.decay 550 | self.model_ema = EMA(self.model_or_module, updates=self.cur_iter) 551 | self.model_ema.decay = lambda_decay 552 | 553 | if path: 554 | self.logger_print(f"加载模型{path}成功") 555 | 556 | def check_main(self): 557 | """判断是否为主进程, 对于单卡来说永远是True, 对于多卡来说只有一个主进程""" 558 | flag = get_rank() == 0 559 | return flag 560 | 561 | def logger_print(self, string: str, function=logger.info): 562 | """打印logger信息""" 563 | if self.check_main(): 564 | function(string) 565 | 566 | 567 | 568 | class MetricStorage(dict): 569 | """The class stores the values of multiple metrics (some of them may be noisy, e.g., loss, 570 | batch time) in training process, and provides access to the smoothed values for better logging. 571 | 572 | The class is designed for automatic tensorboard logging. User should specify the ``smooth`` 573 | when calling :meth:`update`, in order to we can determine which metrics should be 574 | smoothed when performing tensorboard logging. 
575 | 576 | Example:: 577 | 578 | >>> metric_storage = MetricStorage() 579 | >>> metric_storage.update(iter=0, loss=0.2) 580 | >>> metric_storage.update(iter=0, lr=0.01, smooth=False) 581 | >>> metric_storage.update(iter=1, loss=0.1) 582 | >>> metric_storage.update(iter=1, lr=0.001, smooth=False) 583 | >>> # loss will be smoothed, but lr will not 584 | >>> metric_storage.values_maybe_smooth 585 | {"loss": (1, 0.15), "lr": (1, 0.001)} 586 | >>> # like dict, can be indexed by string 587 | >>> metric_storage["loss"].avg 588 | 0.15 589 | """ 590 | 591 | def __init__(self, default_win_size: int = 20) -> None: 592 | self._default_win_size = default_win_size 593 | self._history: Dict[str, HistoryBuffer] = self 594 | self._smooth: Dict[str, bool] = {} 595 | self._latest_iter: Dict[str, int] = {} 596 | 597 | def update(self, iter: Optional[int] = None, smooth: bool = True, window_size: int = None, **kwargs) -> None: 598 | """Add new scalar values of multiple metrics produced at a certain iteration. 599 | 600 | Args: 601 | iter (int): The iteration in which these values are produced. 602 | If None, use the built-in counter starting from 0. 603 | smooth (bool): If True, return the smoothed values of these metrics when 604 | calling :meth:`values_maybe_smooth`. Otherwise, return the latest values. 605 | The same metric must have the same ``smooth`` in different calls to :meth:`update`. 
606 | window_size : int 607 | """ 608 | for key, value in kwargs.items(): 609 | if key in self._smooth: 610 | assert self._smooth[key] == smooth 611 | else: 612 | self._smooth[key] = smooth 613 | self._history[key] = HistoryBuffer(window_size=window_size if window_size else self._default_win_size) 614 | self._latest_iter[key] = -1 615 | if iter is not None: 616 | assert iter > self._latest_iter[key], "检查total_loss是不是存在于model给出的loss_dict中" 617 | self._latest_iter[key] = iter 618 | else: 619 | self._latest_iter[key] += 1 620 | self._history[key].update(value) 621 | 622 | @property 623 | def values_maybe_smooth(self) -> Dict[str, Tuple[int, float]]: 624 | """Return the smoothed values or the latest values of multiple metrics. 625 | The specific behavior depends on the ``smooth`` when updating metrics. 626 | 627 | Returns: 628 | dict[str -> (int, float)]: Mapping from metric name to its 629 | (the latest iteration, the avg/latest value) pair. 630 | """ 631 | return { 632 | key: (self._latest_iter[key], his_buf.avg if self._smooth[key] else his_buf.latest) 633 | for key, his_buf in self._history.items() 634 | } 635 | --------------------------------------------------------------------------------