├── .gitignore ├── README.md ├── app.py ├── applications └── balloon │ └── config.py ├── doc ├── demo_api.png ├── head.png ├── infer_flow.gif ├── 基本流程.svg ├── 数据流动.svg └── 独占资源分配器.svg ├── pyinfer ├── __init__.py ├── core │ ├── __init__.py │ ├── build.py │ ├── engine │ │ ├── __init__.py │ │ └── engine.py │ ├── hook │ │ ├── __init__.py │ │ └── hooks.py │ ├── infer │ │ ├── __init__.py │ │ ├── base.py │ │ └── detection.py │ ├── job.py │ └── mono_allocator.py └── utils │ ├── __init__.py │ ├── common │ ├── __init__.py │ ├── config.py │ ├── logger.py │ └── registry.py │ ├── detection │ ├── __init__.py │ ├── bbox.py │ └── nms.py │ └── functional │ ├── __init__.py │ ├── coco.py │ ├── slice.py │ └── traits.py ├── requirments.txt ├── static ├── README.md ├── redoc.standalone.js ├── swagger-ui-bundle.js └── swagger-ui.css ├── tools ├── generate_gif.py ├── mmdet_export_onnx │ └── balloon │ │ ├── export_yolox_onnx.py │ │ └── yolox_s_8x8_300e_coco.py ├── slice_coco.py └── visual_coco.py └── workspace ├── balloon.jpg └── infer_balloon.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | models 2 | __pycache__/ 3 | .vscode 4 | *.pth 5 | *.onnx 6 | doc/ori -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![img](doc/head.png) 2 | 3 | # 概述 4 | 5 | 机器学习高性能推理解决方案(python 版); 6 | 7 | **功能** 8 | 9 | 1. 基于事件循环的非阻塞同步动态 Batch 推理:推理的并发和并行; 10 | 2. 独占资源分配器:预处理和推理并行,并且防止内存溢出; 11 | 3. 分解推理:分解为多个子任务推理; 12 | 13 | **基本流程** 14 | 15 | ![](./doc/基本流程.svg) 16 | 17 | 1. 异步服务:FastAPI 注册的路由 `route func`,`route func`为异步协程,实现并发访问; 18 | 2. 异步推理:`route func`解析请求参数获取输入,将输入封装成 job,并立即返回未完成计算 `future`;`route func`通过 await 挂起,直到协程 `future`返回结果; 19 | 3. 执行任务:创建消费者线程,从任务队列 `queue`中获取任务并执行具体的推理工作,并将推理结果填充至 `future`; 20 | 4. 结果返回:异步协程 `future`在结果填充后返回,程序回到 `route func`,收集推理结果通过 FastAPI 返回,完成一次推理; 21 | 22 | **数据流** 23 | 24 | ![infer_flow](./doc/infer_flow.gif) 25 | 26 | 1. 接收数据:FastAPI 路由函数 `route func`接收数据流 `raw_data`; 27 | 2. 解析数据:生产者线程通过推理器 `infer.parse_raw`方法,将 `raw_data`解析为推理器输入格式 `infer.Input`; 28 | 3. 存储推理任务输入:将解析后输入 `input`存储在 `job.input`; 29 | 4. 预处理:通过 `preprocess`对 `job.input`做预处理; 30 | 5. 存储模型输入:申请数据资源 `mono_data`,将预处理后模型输入存储至 `mono_data.input`; 31 | 6. 多个任务模型输入组成 batch:多个任务 `mono_data.input`组合为 `batch_input`; 32 | 7. 存储模型输出:根据索引,从 `batch_output`中获取任务输出记录在 `mono_data.output`中; 33 | 8. 后处理:通过 `postprocess`,将模型输出 `mono_data.output`转换为最终推理结果存储在 `job.output`中; 34 | 9. 返回推理结果:将推理结果 `job.output`填充至未完成工作 `future`,使 `future`变成完成状态,FastAPI 获取 `future.result()`后返回; 35 | 36 | # 安装 37 | 38 | ``` 39 | pip install fastapi[all] 40 | ``` 41 | 42 | # 关键技术 43 | 44 | ## 异步动态 Batch 推理 45 | 46 | 1. 
异步 47 | 48 | python 的异步模块并非真实的异步,而是基于事件循环的非阻塞同步操作; 49 | 50 | - 阻塞和非阻塞:调用方等待计算完成; 51 | - 同步和异步:同步,调用方轮询计算是否完成;异步,约定某一通信方式,当任务完成时通知调用方。 52 | 53 | 生产者提交推理任务,返回未完成计算 `future`,调用方不必等待推理任务完成,可继续接收 http 请求,生产新的推理任务,因此是非阻塞的。生产者轮询推理任务是否完成,完成则通过 http 返回推理结果,因此,是同步的。`python asyncio`源码如下: 54 | 55 | ```python 56 | from asyncio.futures import Future 57 | 58 | 59 | async def func(): 60 | future = Future() # 非阻塞 61 | await future # 同步 62 | 63 | 64 | 65 | class Future(): 66 | def __init__(self, *, loop=None): 67 | if loop is None: 68 | # 事件循环,事件指future.set_result事件, 事件发生时执行绑定事件的callback回调函数 69 | self._loop = events.get_event_loop() 70 | 71 | def set_result(self, result): 72 | # set_result事件修改任务状态,之后出发set_result事件绑定回调,事件循环loop与future的非阻塞同步无关; 73 | self._state = _FINISHED 74 | self.__schedule_callbacks() 75 | 76 | def __await__(self): 77 | if not self.done(): # 判断是否完成计算 78 | self._asyncio_future_blocking = True 79 | yield self #出让CPU时间片 80 | if not self.done(): 81 | raise RuntimeError("await wasn't used with future") 82 | return self.result() # 返回推理结果 83 | 84 | __iter__ = __await__ # 循环迭代,轮询方式; 85 | 86 | ``` 87 | 88 | python 中提供两种未完成计算 `future`,`concurrent.futures`和 `asyncio.futures`,前者基于线程通信实现非阻塞异步,后者基于事件循环+协程实现非阻塞同步。 89 | 90 | 2. 动态 Batch 推理 91 | 92 | 通过 Batch 实现并行:基于生产消费设计模式、任务队列、线程通信等机制,推理时可将多个任务合并为一个 Batch 推理,提供推理并行度。 93 | 94 | ## 独占资源分配器 95 | 96 | 独占资源分配器实现两个功能:预处理和推理并行和内存防溢出。 97 | 98 | 由 job 结构和数据流动可知,要想将 job 加入任务队列必须经过预处理 preprocess,而预处理必须申请独占数据资源(python 类对象),才能够存储预处理输出(模型输入)和后处理输入(模型输出),因此,设计独占资源分配器,可以巧妙控制整个推理的进行; 99 | 100 | 1. 内存防溢出:独占资源分配器设定指定数量的独占资源,当所有资源被分配占用时,再次 commit 必须等待之前的任务推理完成,释放资源;因此内存占用恒定,大量 commit 不会造成服务器内存溢出崩溃; 101 | 2. 预处理和推理解耦:独占数据资源数量一般为最大 batch_size 两倍,此时,满足一个 batch_size 进行预处理,一个 batch_size 进行推理,即预备一个 batch_size,实现 prefetch 的功能,从而使预处理和推理同时进行,提高推理效率; 102 | 103 | ## 分解推理 104 | 105 | 分解推理:将推理任务拆分为若干子任务,合并子任务推理结果,最终得到完整的推理结果;分解推理是通过 `jobset`任务集类型实现的。 106 | 107 | **场景一: 遥感图像目标检测** 108 | 将遥感大图 A 拆分为瓦片图 `A1、A2、...、An`,对应创建子任务 `job1、job2、...、jobn`,融合每个瓦片图检测结果,获取大图检测结果; 109 | 110 | # 测试 111 | 112 | 1. 下载balloon 数据集的目标检测模型文件; 113 | 114 | [下载地址](https://pan.baidu.com/s/13HCb_N-Gc1oLp2DG1l2yJw),提取码:lhvs 115 | 116 | 117 | ``` 118 | balloon/ 119 | config.py 120 | yolo.onnx 121 | ``` 122 | 123 | 2. 启动服务 124 | 125 | 切换路径至根目录下: 126 | 127 | ```bash 128 | python app.py 129 | ``` 130 | 131 | 3. 测试接口 132 | 133 | 上传 workspace 目录下 balloon.jpg; 134 | 135 | ![img](doc/demo_api.png) 136 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | import argparse 3 | from enum import Enum 4 | from fastapi import FastAPI, UploadFile, File 5 | 6 | from pyinfer.core.build import build_infer, build_logger 7 | from pyinfer.core.infer import Infer 8 | from pyinfer.utils.common.config import Config 9 | 10 | 11 | 12 | 13 | app = FastAPI() 14 | 15 | infers = {} # 存储初始化后的推理器 16 | 17 | 18 | class OnlineInferName(str, Enum): 19 | """app上线的推理服务""" 20 | DetectionInfer = "Detection目标检测推理" 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--config", default="/volume/wzy/project/PyInfer/applications/balloon/config.py") 26 | return parser.parse_args() 27 | 28 | 29 | def infer_func(infer_name): 30 | 31 | async def wrapper(file: UploadFile = File(...)): 32 | """ 33 | \f 34 | :param infer_name: 推理器名称; 35 | :param file: API接口输入文件流; 36 | 37 | 1. 获取推理器; 38 | 2. 解析文件流至推理器输入; 39 | 3. 提交输入; 40 | 4. 
返回推理结果; 41 | 42 | """ 43 | infer: Infer = infers.get(infer_name) 44 | input = infer.parse_raw(filename=file.filename, raw_data=await file.read()) 45 | future = infer.commit(input) 46 | return future if future is None else await future 47 | 48 | return wrapper 49 | 50 | 51 | @app.on_event("startup") 52 | async def startup_event(): 53 | """ 54 | 初始化所有待上线的推理器 55 | 56 | 1. 依据推理器名称,获取相应的配置参数; 57 | 2. 创建并初始化推理器; 58 | 3. 一个推理器初始化成功,则添加其对应的推理服务路由 59 | 60 | """ 61 | args = parse_args() 62 | cfg = Config.fromfile(args.config) 63 | 64 | logger = build_logger(cfg.log) 65 | for item in OnlineInferName: 66 | if cfg.infer.get(item.name) is None: # 1 67 | logger.error(f"{item.name} config lack, build infer failed.") 68 | continue 69 | 70 | infer = build_infer(cfg.infer.get(item.name), logger=logger) # 2 71 | app.post(f'/infer/{item.name}', response_model=infer.Output, 72 | tags=[item.value])(infer_func(item.name)) # 3 73 | infers[item.name] = infer 74 | 75 | 76 | @app.on_event("shutdown") 77 | def shutdown_event(): 78 | """销毁推理器""" 79 | for infer in infers.values(): 80 | infer.destory() 81 | 82 | 83 | @app.get('/health', tags=["服务状态"], summary="") 84 | async def health(): 85 | """服务网络状态测试""" 86 | return {'START': 'UP'} 87 | 88 | 89 | if __name__ == "__main__": 90 | uvicorn.run(app, port = 8805) 91 | -------------------------------------------------------------------------------- /applications/balloon/config.py: -------------------------------------------------------------------------------- 1 | # 支持三种环境变量: {{ AppFolder }}、{{ WorkspaceFolder }}、{{ RootFolder }} 2 | 3 | log = dict(filename=None, level="INFO") 4 | 5 | 6 | infer = dict( 7 | DetectionInfer=dict( 8 | type="DetectionInfer", 9 | engine=dict(type="OnnxInferEngine", 10 | onnx_file = "{{ AppFolder }}/balloon/yolox.onnx"), 11 | hooks=[dict(type="DrawBBoxHook", 12 | out_dir="{{ WorkspaceFolder }}")], 13 | confidence_threshold=0.7, 14 | nms_threshold=0.2, 15 | max_batch_size=16, 16 | max_object=100, 17 | width=640, 18 | height=640, 19 | workspace="{{ WorkspaceFolder }}", 20 | slice=False, 21 | subsize=640, 22 | rate=1, 23 | gap=200, 24 | padding=True, 25 | labelnames=("balloon", ), 26 | device="cuda")) 27 | -------------------------------------------------------------------------------- /doc/demo_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/doc/demo_api.png -------------------------------------------------------------------------------- /doc/head.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/doc/head.png -------------------------------------------------------------------------------- /doc/infer_flow.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/doc/infer_flow.gif -------------------------------------------------------------------------------- /doc/基本流程.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[drawio SVG: duplicated text fallback collapsed; diagram labels only. Node labels of the basic-flow (基本流程) diagram: aiohttp, route func, input, 创建job (create job), commit, 非阻塞返回 (non-blocking return), future, 异步await结果 (await the result asynchronously), set value/触发事件 (set value / fire event), future result, http, 生产者 (producer), 任务队列 job queue (task queue), 消费者 (consumer), waits_for_jobs, 批量获取job (fetch a batch of jobs), work, forward, InferEngine, 推理 (inference)]
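
The flow in this diagram condenses into a small, self-contained sketch. This is an illustrative example only, with hypothetical names (`route_func`, `consumer`, `job_queue`); it is not the repository's `Infer` implementation, but it shows the same pattern of a non-blocking commit that returns a future, a consumer thread that fills it, and an event-loop notification on completion:

```python
import asyncio
import threading
from queue import Queue

# Hypothetical stand-ins for the diagram's boxes; not the repository's API.
job_queue: "Queue[tuple[str, asyncio.Future]]" = Queue()


def consumer(loop: asyncio.AbstractEventLoop) -> None:
    """消费者 (consumer): drain the queue and fill each job's future."""
    while True:
        inp, future = job_queue.get()        # fetch a job (single job for brevity)
        result = inp.upper()                 # stand-in for InferEngine.forward
        # Completing an asyncio future from a worker thread must be scheduled
        # on the loop; future.set_result alone is not thread-safe here.
        loop.call_soon_threadsafe(future.set_result, result)


async def route_func(inp: str) -> str:
    """route func: commit the job, get a future immediately, then await it."""
    loop = asyncio.get_running_loop()
    future = loop.create_future()            # 非阻塞返回: caller is not blocked
    job_queue.put((inp, future))             # 创建job and enqueue it
    return await future                      # 异步await结果 until set_result fires


async def main() -> None:
    loop = asyncio.get_running_loop()
    threading.Thread(target=consumer, args=(loop,), daemon=True).start()
    print(await asyncio.gather(*(route_func(s) for s in ["a", "b", "c"])))


if __name__ == "__main__":
    asyncio.run(main())
```

Because each `route_func` call returns to the event loop at `await future`, many HTTP requests can be in flight while the single consumer thread works through the queue, which is exactly the "非阻塞同步" behaviour described in the README.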
-------------------------------------------------------------------------------- /doc/数据流动.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[drawio SVG: duplicated text fallback collapsed; diagram labels only. Node labels of the data-flow (数据流动) diagram: Job, input, output, mono_data, traits, future, Infer, commit, preprocess, work, postprocess, set_result, 提取数据特征 (extract input traits), ① 存储输入数据 (store input data), ② 待预处理 (awaiting preprocess), ③ 存储预处理结果 (store preprocess result), ④ 待推理 (awaiting inference), ⑤ 存储推理结果 (store inference result), ⑥ 待后处理 (awaiting postprocess), ⑦ 存储后处理结果 (store postprocess result)]
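
A minimal sketch of the numbered flow above (steps ① to ⑦), using hypothetical `Job`/`MonoData` stand-ins rather than the actual classes in `pyinfer/core/job.py`, to show how per-job `mono_data.input` buffers are gathered into one batch and the batch output is scattered back:

```python
import numpy as np
from dataclasses import dataclass, field
from typing import Any, Optional


# Hypothetical mirrors of the diagram's boxes; field names follow the labels
# (input, mono_data, output) but this is not the repository's Job class.
@dataclass
class MonoData:
    input: Optional[np.ndarray] = None    # ③ preprocess result, model input
    output: Optional[np.ndarray] = None   # ⑤ model output, awaiting postprocess


@dataclass
class Job:
    input: np.ndarray                     # ① parsed raw input
    mono_data: MonoData = field(default_factory=MonoData)
    output: Any = None                    # ⑦ final postprocess result


def run_batch(jobs: list[Job]) -> None:
    for job in jobs:                      # ②③ preprocess into mono_data.input
        job.mono_data.input = job.input / 255.0
    batch_input = np.stack([j.mono_data.input for j in jobs])   # ④ build batch
    batch_output = batch_input.mean(axis=(1, 2, 3))             # stand-in forward
    for i, job in enumerate(jobs):
        job.mono_data.output = batch_output[i]                  # ⑤ scatter output
        job.output = float(job.mono_data.output)                # ⑥⑦ postprocess


jobs = [Job(np.ones((4, 4, 3), dtype=np.float32)) for _ in range(2)]
run_batch(jobs)
print([j.output for j in jobs])
```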
-------------------------------------------------------------------------------- /doc/独占资源分配器.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[drawio SVG: duplicated text fallback collapsed; diagram labels only. Node labels of the monopoly-allocator (独占资源分配器 Monopoly Allocator) diagram: commit, preprocess, forward, model; a pool of six Monopoly Data slots sized 容量 capacity, each either available or unavailable; query; 请求独占数据资源 (request a monopoly data resource); 分配独占数据资源 (allocate a monopoly data resource)]
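
The capacity mechanism in this diagram fits in a few lines. The sketch below is a simplified illustration with a hypothetical `TinyAllocator`; the real `MonoAllocator` in `pyinfer/core/mono_allocator.py` additionally handles query timeouts and shutdown:

```python
import threading


# Minimal sketch of the diagram's idea: a fixed pool of slots, where query
# blocks when every slot is unavailable. Hypothetical class, not the repo API.
class TinyAllocator:
    def __init__(self, capacity: int) -> None:
        self._slots = list(range(capacity))      # 容量 capacity fixed up front
        self._cond = threading.Condition()

    def query(self) -> int:
        """请求独占数据资源: block until a slot becomes available."""
        with self._cond:
            self._cond.wait_for(lambda: len(self._slots) > 0)
            return self._slots.pop()             # slot is now unavailable

    def release(self, slot: int) -> None:
        with self._cond:
            self._slots.append(slot)             # slot is available again
            self._cond.notify()


# capacity = 2 * max_batch_size: one batch can preprocess while another infers.
allocator = TinyAllocator(capacity=2 * 16)
slot = allocator.query()
# ... preprocess into the slot, run inference, then hand the slot back ...
allocator.release(slot)
```

Because the number of slots is fixed, at most `capacity` preprocessed inputs exist at any moment, which is what keeps memory bounded under heavy `commit` traffic, and sizing it at twice `max_batch_size` is what lets preprocessing act as a prefetch stage for the engine.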
-------------------------------------------------------------------------------- /pyinfer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import dirname as dname 3 | 4 | # SET ENV 5 | os.environ["RootFolder"] = dname(dname(__file__)) 6 | os.environ["WorkspaceFolder"] = f"{os.environ['RootFolder']}/workspace" 7 | os.environ["AppFolder"] = f"{os.environ['RootFolder']}/applications" -------------------------------------------------------------------------------- /pyinfer/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import * 2 | from .infer import * 3 | from .hook import * -------------------------------------------------------------------------------- /pyinfer/core/build.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from ..utils.common.registry import INFERS, ENGINES, HOOKS 3 | from ..utils.common.logger import Logger 4 | 5 | __all__ = ["build_infer", "build_engine", "build_hook"] 6 | 7 | 8 | def build_infer(config: Dict, **kwargs): 9 | if INFERS.get(config.get('type')) is None: 10 | raise KeyError(f"Cannot found infer type {config.get('type')}.") 11 | 12 | infer = INFERS.get(config.pop('type'))(**kwargs) 13 | start_params = infer.StartParams.parse_obj(config) 14 | if not infer.startup(start_params): 15 | return 16 | else: 17 | return infer 18 | 19 | 20 | def build_engine(config: Dict, **kwargs): 21 | if ENGINES.get(config.get('type')) is None: 22 | raise KeyError(f"Cannot found infer engine {config.get('type')}.") 23 | 24 | engine = ENGINES.get(config.pop('type'))(**kwargs) 25 | if not engine.build(**config): 26 | return 27 | else: 28 | return engine 29 | 30 | 31 | def build_hook(config): 32 | return HOOKS.build(config) 33 | 34 | 35 | def build_logger(config) -> Logger: 36 | return Logger(**dict(config)) 37 | -------------------------------------------------------------------------------- /pyinfer/core/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .engine import OnnxInferEngine 2 | -------------------------------------------------------------------------------- /pyinfer/core/engine/engine.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABCMeta 2 | import numpy as np 3 | import onnxruntime 4 | from ...utils.common.registry import ENGINES 5 | from ...utils.common.logger import Logger 6 | 7 | __all__ = ["OnnxInferEngine"] 8 | 9 | 10 | 11 | 12 | class InferEngine(metaclass=ABCMeta): 13 | 14 | def __init__(self, device, logger=None, **kwargs) -> None: 15 | self.device = device 16 | self.logger = Logger() if logger is None else logger 17 | 18 | @abstractmethod 19 | def build(self): 20 | pass 21 | 22 | @abstractmethod 23 | def forward(self, job): 24 | pass 25 | 26 | 27 | @ENGINES.register_module() 28 | class OnnxInferEngine(InferEngine): 29 | __providers__ = {"cuda":"CUDAExecutionProvider", "TensorRT":"TensorrtExecutionProvider", "cpu":"TensorrtExecutionProvider"} 30 | 31 | def build(self, onnx_file): 32 | providers = onnxruntime.get_available_providers() 33 | 34 | if self.__providers__[self.device] not in providers: 35 | self.logger.fatal(f"Onnxruntime lack {self.__providers__[self.device]}, build model failed.") 36 | return False 37 | 38 | self.logger.info(f"onnxruntime use [{self.__providers__[self.device]}], support [{','.join(providers)}]") 39 
| self.session = onnxruntime.InferenceSession(onnx_file, providers=providers) 40 | self.input_names = [inp.name for inp in self.session.get_inputs()] 41 | self.output_names = [out.name for out in self.session.get_outputs()] 42 | 43 | return True 44 | 45 | def forward(self, batch_input: np.ndarray): 46 | """ 47 | 默认onnx单输入单输出 48 | 49 | batch_input[batch_size, ...] 50 | """ 51 | # 输入格式转换 52 | batch_input = batch_input.transpose(0,3,1,2) 53 | batch_output = self.session.run(self.output_names, {self.input_names[0]:batch_input})[0] 54 | return batch_output 55 | -------------------------------------------------------------------------------- /pyinfer/core/hook/__init__.py: -------------------------------------------------------------------------------- 1 | from .hooks import DrawBBoxHook 2 | -------------------------------------------------------------------------------- /pyinfer/core/hook/hooks.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from typing import List 4 | from PIL import Image, ImageDraw 5 | from ..job import Job 6 | from ...utils.common.registry import HOOKS 7 | from ...utils.detection.bbox import QuadrangleBBox 8 | 9 | __all__ = ["DrawBBoxHook"] 10 | 11 | 12 | class Hook(): 13 | 14 | def __init__(self): 15 | pass 16 | 17 | def after_set_result(self, job): 18 | pass 19 | 20 | 21 | @HOOKS.register_module() 22 | class DrawBBoxHook(Hook): 23 | 24 | def __init__(self, out_dir, prefix = "infer_") -> None: 25 | self.out_dir = out_dir 26 | self.prefix = prefix 27 | 28 | def after_set_result(self, job: Job): 29 | """绘制目标框""" 30 | image_np, filename = job.input.image.astype(np.uint8), job.input.filename 31 | bboxes: List[QuadrangleBBox] = job.future.result().bboxes 32 | 33 | image_background = Image.fromarray(image_np) 34 | draw_background = ImageDraw.Draw(image_background) 35 | for bbox in bboxes: 36 | draw_background.rectangle((bbox.left,bbox.top,bbox.right,bbox.bottom),fill='#FF00FF') 37 | 38 | image = Image.fromarray(cv2.addWeighted(image_np, 0.8, np.array(image_background),0.2, 0)) 39 | draw = ImageDraw.Draw(image) 40 | for bbox in bboxes: 41 | draw.rectangle((bbox.left,bbox.top,bbox.right,bbox.bottom),outline='#FF00FF') 42 | draw.text((bbox.left,bbox.top),bbox.labelname) 43 | cv2.imwrite(f"{self.out_dir}/{self.prefix}{filename}", np.array(image)) 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /pyinfer/core/infer/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection import DetectionInfer 2 | from .base import Infer -------------------------------------------------------------------------------- /pyinfer/core/infer/base.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | import numpy as np 4 | from abc import abstractmethod, ABCMeta 5 | from typing import List, Union, Any, Dict 6 | from pydantic import BaseModel, Field 7 | 8 | from ..mono_allocator import MonoAllocator 9 | from ..job import Job, JobSet 10 | from ...utils.common.logger import Logger 11 | from ..build import build_engine, build_hook 12 | 13 | __all__ = ["Infer"] 14 | 15 | 16 | class Infer(metaclass=ABCMeta): 17 | """ 18 | 推理器基类,接口类 19 | 20 | parser_raw commit preprocess append await job future 21 | 生产者:raw_input----------------> infer input--------->job-------------->valid job--------->job queue---------------------> 22 | 23 | engine forward postprocess 
set_result 24 | 消费者:job.mono_data.input---------------------> job.mono_data.output---------------->job.output-----------------> future result 25 | """ 26 | class Input(BaseModel): 27 | """输入""" 28 | pass 29 | 30 | class Output(BaseModel): 31 | """输出""" 32 | pass 33 | 34 | class StartParams(BaseModel): 35 | """推理配置参数""" 36 | engine: Dict = Field(description="infer engine config dict") 37 | hooks: List[Dict] = Field(description="钩子", default=[]) 38 | max_batch_size: int 39 | workspace: str = Field(default="workspace/") 40 | 41 | def __init__(self, logger=None): 42 | self._loop = asyncio.get_running_loop() # 绑定异步事件循环,实现基于异步future的事件通知机制; 43 | self._cond = threading.Condition() 44 | self._job_queue = [] 45 | self.logger = Logger() if logger is None else logger 46 | self._hooks = [] 47 | 48 | def startup(self, start_params: StartParams): 49 | """启动推理消费线程""" 50 | self.start_params = start_params 51 | self._run = True 52 | self._mono_allocator = MonoAllocator( 53 | self.start_params.max_batch_size * 2) 54 | start_job = Job() 55 | t = threading.Thread(target=self.work, args=(start_job, )) 56 | t.start() 57 | # 阻塞,确保模型引擎初始化完成 58 | return start_job.future.result() 59 | 60 | @abstractmethod 61 | def parse_raw(self, filename, raw_data) -> Union[Input, List[Input]]: 62 | """ 63 | 由网络输入解析为Infer的Input格式,或List[Input]格式 64 | """ 65 | pass 66 | 67 | @abstractmethod 68 | def preprocess(self, job: Job): 69 | """前处理""" 70 | pass 71 | 72 | @abstractmethod 73 | def postprocess(self, job: Job): 74 | """后处理""" 75 | pass 76 | 77 | 78 | 79 | def commit(self, inp: Union[Input, List[Input]]): 80 | """ 81 | 提交任务 82 | 83 | 将输入封装成job,提交至任务队列; 84 | """ 85 | if isinstance(inp, list): 86 | future = self.__commits(inp) 87 | else: 88 | future = self.__commit(inp) 89 | return future 90 | 91 | 92 | def __commit(self, inp: Input): 93 | "提交单个任务" 94 | job = Job(inp, self._loop) 95 | 96 | if not self.preprocess(job): 97 | return 98 | 99 | # 预处理完成,将valid job添加至任务队列 100 | with self._cond: 101 | self._job_queue.append(job) 102 | self._cond.notify() 103 | 104 | # 添加钩子 105 | for hook in self._hooks: 106 | job.add_done_call_back(hook.after_set_result) 107 | 108 | return job.future 109 | 110 | def __commits(self, inps: List[Input]): 111 | """ 112 | 提交任务集 113 | 114 | inps[0]表示原始输入,inps[1:]表示原始输入拆分生成的n个子输入; 115 | """ 116 | assert len(inps) >= 2, "inps num less than 2, check parse_raw method." 
117 | real_inp, *sub_inps = inps 118 | 119 | # 创建子任务 120 | jobs = [Job(sub_inp, self._loop) for sub_inp in sub_inps] 121 | # 若存在无法预处理的子任务,则提交失败; 122 | if not all([self.preprocess(job) for job in jobs]): 123 | return 124 | 125 | # 预处理完成,将子任务添加至任务队列 126 | with self._cond: 127 | self._job_queue.extend(jobs) 128 | self._cond.notify() 129 | 130 | # 创建任务集 131 | jobset = JobSet(jobs, self.collect_fn, real_inp, self._loop) 132 | 133 | # 添加钩子 134 | for hook in self._hooks: 135 | jobset.add_done_call_back(hook.after_set_result) 136 | return jobset.future 137 | 138 | def collect_fn(self, jobset: JobSet) -> Any: 139 | """ 140 | 任务集结果融合:汇集所有子任务推理结果,collect_fn返回值作为任务集结果;当推理输入被拆分为多个子输入分别推理时,collect_fn必须重写; 141 | 142 | Args: 143 | jobset: 任务集 144 | 145 | return: Any, 返回值保存在jobset.future.result; 146 | """ 147 | raise NotImplementedError 148 | 149 | def wait_for_job(self) -> Job: 150 | """获取任务""" 151 | with self._cond: 152 | self._cond.wait_for( 153 | lambda: (len(self._job_queue) > 0 or not self._run)) 154 | if not self._run: 155 | return 156 | job = self._job_queue.pop(0) 157 | return job 158 | 159 | def wait_for_jobs(self) -> List[Job]: 160 | """获取一批任务""" 161 | with self._cond: 162 | self._cond.wait_for( 163 | lambda: (len(self._job_queue) > 0 or not self._run)) 164 | if not self._run: 165 | return [] 166 | max_size = min(self.start_params.max_batch_size, 167 | len(self._job_queue)) 168 | fetch_jobs = self._job_queue[:max_size] 169 | self._job_queue = self._job_queue[max_size:] 170 | return fetch_jobs 171 | 172 | def work(self, start_job: Job): 173 | """推理""" 174 | 175 | # 1.推理引擎创建并初始化 176 | engine = build_engine(self.start_params.engine, device =self.start_params.device , logger=self.logger) 177 | if engine is None: 178 | start_job.future.set_result(False) # 初始化推理引擎失败,退出推理消费者线程 179 | return 180 | 181 | for hook_cfg in self.start_params.hooks: 182 | hook = build_hook(hook_cfg) 183 | if hook is None: 184 | start_job.future.set_result(False) # 初始化钩子失败,退出推理消费者线程 185 | return 186 | self._hooks.append(hook) 187 | 188 | # 启动成功 189 | start_job.future.set_result(True) 190 | 191 | while self._run: 192 | # 2.取任务:从任务队里中取出预处理完成的job 193 | fetch_jobs = self.wait_for_jobs() 194 | if len(fetch_jobs) == 0: # 当推理停止时,fetch_jobs返回为空 195 | for job in self._job_queue: 196 | job.future.set_result(job.output) # 未完成job推理结果置None 197 | break 198 | 199 | # 3.组合batch 200 | batch_input = np.stack([job.mono_data.input for job in fetch_jobs]) 201 | # 4.推理 202 | batch_output = engine.forward(batch_input) 203 | self.logger.info( 204 | f"Infer batch size({len(fetch_jobs)}), wait jobs({len(self._job_queue)})") 205 | for index, job in enumerate(fetch_jobs): 206 | # 5.取模型输出结果:engine->job.mono_data.output 207 | job.mono_data.output = batch_output[index] 208 | # 6.后处理,job.mono_data.output->postprocess->job.output 209 | self.postprocess(job) 210 | # 7.释放独占数据资源 211 | job.mono_data.release() 212 | job.mono_data = None 213 | # 8.返回 214 | job.future.set_result(job.output) 215 | 216 | def __del__(self): 217 | with self._cond: 218 | self._run = False 219 | self._cond.notify() 220 | self.logger.info(f"{self.__class__.__name__} is destoryed.") 221 | 222 | def destory(self): 223 | self.__del__() 224 | -------------------------------------------------------------------------------- /pyinfer/core/infer/detection.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | from typing import List, Union, Dict, Tuple 5 | from pydantic import BaseModel, Field 6 | from .base 
import Infer 7 | 8 | from ..job import Job, JobSet 9 | 10 | from ...utils.detection.nms import cpu_nms 11 | from ...utils.detection.bbox import QuadrangleBBox 12 | from ...utils.functional.slice import slice_one_image 13 | from ...utils.functional.traits import WarpAffineTraits 14 | from ..build import INFERS 15 | 16 | @INFERS.register_module() 17 | class DetectionInfer(Infer): 18 | class Input(BaseModel): 19 | filename: str = Field(describe="图像文件名称") 20 | image: Union[np.ndarray, None] = Field( 21 | default=None, describe="图像解析后numpy对象") 22 | ox: float = Field(default=0, describe="image坐标系原点x坐标") 23 | oy: float = Field(default=0, describe="image坐标系原点y坐标") 24 | 25 | class Config: 26 | arbitrary_types_allowed = True 27 | 28 | class Output(BaseModel): 29 | bbox_num: int 30 | bboxes: List[QuadrangleBBox] 31 | 32 | def to_array(self): 33 | ret = [] 34 | for bbox in self.bboxes: 35 | ret.append([bbox.left, bbox.right, bbox.top, 36 | bbox.bottom, bbox.label]) 37 | return np.array(ret).astype(np.int) 38 | 39 | class StartParams(BaseModel): 40 | engine: Dict 41 | confidence_threshold: float = Field(gt=0, describe="置信度阈值") 42 | nms_threshold: float = Field(gt=0.0, describe="NMS阈值") 43 | max_batch_size: int = Field(gt=0, describe="推理时最大batch_size") 44 | max_object: int = Field(gt=0, describe="图像包含的最大目标数量") 45 | width: int = Field(gt=0, describe="推理时送入网络的图像宽度") 46 | height: int = Field(gt=0, describe="推理时送入网络的图像高度") 47 | workspace: str = Field(default="workspace/") 48 | labelnames: Tuple = Field(describe="label names") 49 | timeout: int = Field(default=10, describe="提交job的最大等待时间,若超时,则提交任务失败.") 50 | hooks: List[Dict] = Field(description="钩子", default=[]) 51 | slice: bool = Field(description="是否启动分片推理模式", default=False) 52 | device: str = Field(description="推理设备", default="cuda:0") 53 | # 分片推理模式下需要配置的参数 54 | subsize: int = Field(default=640, describe="瓦片大小") 55 | rate: int = Field(default=1, describe="大图resize系数") 56 | gap: int = Field(default=200, describe="瓦片间重叠大小") 57 | padding: bool = Field(default=True, describe="瓦片是否padding至subsize大小") 58 | 59 | def parse_raw(self, filename, raw_data): 60 | """数据解析为Input""" 61 | if self.start_params.slice: 62 | return self.__slice_parse_raw(filename, raw_data) 63 | else: 64 | return self.__parse_raw(filename, raw_data) 65 | 66 | def __parse_raw(self, filename, raw_data): 67 | """数据解析为Input, 非分片推理""" 68 | image = cv2.imdecode(np.frombuffer(raw_data, np.uint8), 69 | cv2.IMREAD_COLOR).astype(np.float32) 70 | inp = self.Input.parse_obj(dict(filename=filename, image=image)) 71 | return inp 72 | 73 | def __slice_parse_raw(self, filename, raw_data): 74 | """数据解析为Input, 分片推理,将整图拆分为多个Input""" 75 | image = cv2.imdecode(np.frombuffer(raw_data, np.uint8), 76 | cv2.IMREAD_COLOR).astype(np.float32) 77 | image_base_name, extension = os.path.splitext(filename) 78 | patches = slice_one_image(image=image, 79 | image_base_name=image_base_name, 80 | subsize=self.start_params.subsize, 81 | rate=self.start_params.rate, 82 | gap=self.start_params.gap, 83 | padding=self.start_params.padding, 84 | bboxes=None) 85 | # sub input 86 | inps = [ 87 | self.Input.parse_obj( 88 | dict(filename=f"{patch_name}{extension}", image=patch_image, ox=left, oy=top)) 89 | for patch_name, rate, left, top, patch_image, patch_bboxes in patches 90 | ] 91 | # real input 92 | inps.insert(0, self.Input.parse_obj( 93 | dict(filename=filename, image=image, left=0, top=0))) 94 | return inps 95 | 96 | def collect_fn(self, jobset: JobSet): 97 | """slice分片结果融合""" 98 | bboxes: List[QuadrangleBBox] = [] 99 | for job in 
jobset.jobs: 100 | ox, oy = job.input.ox, job.input.oy 101 | # 取每个瓦片图的检测框 102 | patch_bboxes: List[QuadrangleBBox] = job.output.bboxes 103 | # 检测框更新原点至(0,0) 104 | bboxes.extend([bbox.reset_origin(-ox, -oy) 105 | for bbox in patch_bboxes]) 106 | bboxes = cpu_nms(bboxes, self.start_params.nms_threshold) # nms 107 | output = self.Output(bboxes=bboxes, bbox_num=len(bboxes)) # 返回大图检测框 108 | return output 109 | 110 | def preprocess(self, job: Job): 111 | """预处理""" 112 | if job.input is None or job.input.image is None: 113 | self.logger.error( 114 | f"input image is empty. Please check {job.input.filename}.") 115 | return False 116 | 117 | # 预处理要求获取独占数据资源 118 | mono_data = self._mono_allocator.query(self.start_params.timeout) 119 | if mono_data is None: 120 | self.logger.error("query mono data timeout.") 121 | return False 122 | else: 123 | job.mono_data = mono_data 124 | 125 | # 通过warpaffineTraits萃取器,计算仿射变换矩阵、逆矩阵、并完成仿射变换 126 | job.traits = WarpAffineTraits(job.input.image.shape[1], job.input.image.shape[0], self.start_params.width, 127 | self.start_params.height) 128 | # mono_data.input存储预处理结果,作为模型推理的输入 129 | job.mono_data.input = job.traits(job.input.image) 130 | return True 131 | 132 | def decode(self, job): 133 | bboxes = [] 134 | # 从独占数据资源中取出模型推理结果, results=List[bbox]格式,其中bbox为[left,top,right,width, height, confidence, *scores] 135 | results = job.mono_data.output 136 | for result in results: 137 | xc, yc, width, height, confidence, *scores = result 138 | if confidence < self.start_params.confidence_threshold: 139 | continue 140 | 141 | max_score_index = np.argmax(scores) 142 | max_score = scores[max_score_index] 143 | if max_score*confidence < self.start_params.confidence_threshold: 144 | continue 145 | 146 | label = max_score_index 147 | # 通过逆仿射变换将bbox坐标映射回原图 148 | left, top = job.traits.to_src_coord(xc - width/2, yc - height/2) 149 | right, bottom = job.traits.to_src_coord(xc + width/2, yc + height/2) 150 | labelname = self.start_params.labelnames[label] 151 | bbox = QuadrangleBBox(x1=left, 152 | y1=top, 153 | x2=right, 154 | y2=top, 155 | x3=right, 156 | y3=bottom, 157 | x4=left, 158 | y4=bottom, 159 | confidence=confidence, 160 | label=label, 161 | labelname=labelname, 162 | keepflag=True) 163 | bboxes.append(bbox) 164 | return bboxes 165 | 166 | def nms(self, bboxes:List[QuadrangleBBox]): 167 | return cpu_nms(bboxes, self.start_params.nms_threshold) 168 | 169 | def postprocess(self, job): 170 | """后处理""" 171 | 172 | bboxes = self.nms(self.decode(job)) 173 | job.output = self.Output(bbox_num=len(bboxes), bboxes = bboxes) -------------------------------------------------------------------------------- /pyinfer/core/job.py: -------------------------------------------------------------------------------- 1 | from concurrent import futures 2 | from typing import List 3 | 4 | 5 | class Job(): 6 | def __init__(self, inp=None, loop=None) -> None: 7 | """ 8 | 推理job 9 | 10 | :param input: 推理输入 11 | :param output: 推理输出 12 | :param mono_data: 独占数据资源,相当于workspace,推理任务必须获取独占数据资源,才能被提交至推理队列; 13 | mono_data.input记录输入预处理后数据,等待被送入模型,mono_data.output记录模型输出数据,等待被后处理; 14 | :param traits: 特征萃取,提取当前输入的相关特征,辅助预处理和后处理; 15 | :param future: 非阻塞,记录响应结果;loop非None时,为异步job,需要在异步函数配合await使用,否则为同步; 16 | 17 | """ 18 | self.input = inp 19 | self.output = None 20 | self.traits = None 21 | self.mono_data = None 22 | self.future = futures.Future() if loop is None else loop.create_future() 23 | self.call_backs = [] 24 | self.future.add_done_callback(self.run_call_back) 25 | 26 | def add_done_call_back(self, call_back_fn): 
27 | self.call_backs.append(call_back_fn) 28 | 29 | def run_call_back(self, future): 30 | for cb in self.call_backs: 31 | cb(self) 32 | 33 | 34 | class JobSet(Job): 35 | def __init__(self, jobs: List[Job], collect_fn, inp=None, loop=None) -> None: 36 | """ 37 | 任务集合:当所有子任务均完成时,JobSet完成 38 | 39 | 40 | :param jobs: 任务列表 41 | :param loop: loop非None时,为异步job,需要在异步函数配合await使用,否则为同步; 42 | :param collect_fn: 推理结果回收,collect_fn输入为所有job的future.result列表,collect_fn返回值为JobSet.future.result,即整个任务集合的结果; 43 | """ 44 | self.collect_fn = collect_fn 45 | self.jobs = jobs 46 | for job in self.jobs: 47 | job.add_done_call_back(self.notify) 48 | 49 | super().__init__(inp, loop) 50 | 51 | def done(self): 52 | # 所有job完成,且future未执行set_result时,避免多次重复set_result 53 | return all([job.future.done() for job in self.jobs]) and not self.future.done() 54 | 55 | def notify(self, job): 56 | if self.done(): 57 | self.future.set_result(self.collect_fn(self)) 58 | -------------------------------------------------------------------------------- /pyinfer/core/mono_allocator.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import weakref 3 | import time 4 | from ..utils.common.logger import Logger 5 | 6 | __all__ = ["MonoAllocator"] 7 | 8 | 9 | class MonoAllocator(): 10 | """独占数据资源分配器""" 11 | 12 | class MonoData(): 13 | """独占数据资源""" 14 | 15 | def __init__(self, name, allocator, available=True) -> None: 16 | """ 17 | input: 存储模型网络的输入数据 18 | output: 存储模型网络的输出数据 19 | """ 20 | self.name = name 21 | self._available = available 22 | self._allocator = allocator 23 | self.input = None 24 | self.output = None 25 | self.workspace = None 26 | 27 | 28 | def release(self): 29 | self._allocator.release_one(self) 30 | 31 | def __init__(self, size: int) -> None: 32 | """ 33 | :param 34 | size : 独占数据资源数量,建议设置为batch_size*2; 35 | self.datas : 独占数据资源列表; 36 | self._num_avail : 剩余资源数量; 37 | self._num_thread_wait: 等待资源的线程数量; 38 | self._run: 当程序终止时,触发析构__del__, _run=False,停止资源分配,在等待资源的线程陆续退出等待; 39 | self._lock:线程锁; 40 | self._cond:线程同步,同步资源数量; 41 | self._cond_exit:线程同步,同步等待资源的线程数量,停止资源分配_run=False时,退出时需要等待_num_thread_wait为0; 42 | 43 | """ 44 | self.datas = [self.MonoData(f"mono_data_{i}", weakref.proxy(self)) for i in range(size)] 45 | self._num_avail = size 46 | self._num_thread_wait = 0 47 | self._run = True 48 | self._lock = threading.Lock() 49 | self._cond = threading.Condition(self._lock) 50 | self._cond_exit = threading.Condition(self._lock) 51 | self.logger = Logger() 52 | 53 | def query(self, timeout=10): 54 | """请求独占数据资源""" 55 | with self._cond: 56 | if not self._run: 57 | # 推理终止时停止分配数据资源 58 | return 59 | 60 | # 等待独占数据资源 61 | if self._num_avail == 0: 62 | self._num_thread_wait += 1 # 排队线程数+1 63 | state = self._cond.wait_for(lambda: (self._num_avail > 0 or not self._run), timeout) 64 | self._num_thread_wait -= 1 # 排队线程数-1 65 | # 更新析构状态 66 | self._cond_exit.notify() 67 | # 未获取到独占数据资源,可能是请求超时或分配器停止分配 68 | if not state or self._num_avail == 0 or not self._run: 69 | return 70 | 71 | # 返回请求得到的独占数据资源 72 | for mono_data in self.datas: 73 | if mono_data._available: 74 | self.logger.debug(f"occupy {mono_data.name}") 75 | mono_data._available = False 76 | self._num_avail -= 1 77 | return mono_data 78 | 79 | def release_one(self, mono_data: MonoData): 80 | """释放独占数据资源所有权""" 81 | with self._cond: 82 | self._num_avail += 1 83 | mono_data._available = True 84 | mono_data.input = None 85 | mono_data.output = None 86 | self._cond.notify_all() 87 | time.sleep(1e-4) 88 | self.logger.debug(f"release 
{mono_data.name}") 89 | 90 | def __del__(self): 91 | """对象析构时调用""" 92 | with self._cond: 93 | self._run = False 94 | self._cond.notify_all() 95 | 96 | with self._cond_exit: 97 | # 等待所有等待线程退出wait状态 98 | self._cond_exit.wait_for(lambda: (self._num_thread_wait == 0)) 99 | 100 | self.logger.info("MonoAllocator is destoryed.") 101 | 102 | def destory(self): 103 | self.__del__() 104 | -------------------------------------------------------------------------------- /pyinfer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/__init__.py -------------------------------------------------------------------------------- /pyinfer/utils/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/common/__init__.py -------------------------------------------------------------------------------- /pyinfer/utils/common/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import ast 3 | import copy 4 | from hashlib import new 5 | import os 6 | import os.path as osp 7 | import platform 8 | import shutil 9 | import sys 10 | import tempfile 11 | import types 12 | import uuid 13 | import warnings 14 | from argparse import Action, ArgumentParser 15 | from collections import abc 16 | from importlib import import_module 17 | from pathlib import Path 18 | 19 | from addict import Dict 20 | from yapf.yapflib.yapf_api import FormatCode 21 | import re 22 | 23 | 24 | if platform.system() == 'Windows': 25 | import regex as re # type: ignore 26 | else: 27 | import re # type: ignore 28 | 29 | BASE_KEY = '_base_' 30 | DELETE_KEY = '_delete_' 31 | DEPRECATION_KEY = '_deprecation_' 32 | RESERVED_KEYS = ['filename', 'text', 'pretty_text'] 33 | 34 | 35 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 36 | if not osp.isfile(filename): 37 | raise FileNotFoundError(msg_tmpl.format(filename)) 38 | 39 | 40 | def import_modules_from_strings(imports, allow_failed_imports=False): 41 | """Import modules from the given list of strings. 42 | 43 | Args: 44 | imports (list | str | None): The given module names to be imported. 45 | allow_failed_imports (bool): If True, the failed imports will return 46 | None. Otherwise, an ImportError is raise. Default: False. 47 | 48 | Returns: 49 | list[module] | module | None: The imported modules. 50 | 51 | Examples: 52 | >>> osp, sys = import_modules_from_strings( 53 | ... 
['os.path', 'sys']) 54 | >>> import os.path as osp_ 55 | >>> import sys as sys_ 56 | >>> assert osp == osp_ 57 | >>> assert sys == sys_ 58 | """ 59 | if not imports: 60 | return 61 | single_import = False 62 | if isinstance(imports, str): 63 | single_import = True 64 | imports = [imports] 65 | if not isinstance(imports, list): 66 | raise TypeError( 67 | f'custom_imports must be a list but got type {type(imports)}') 68 | imported = [] 69 | for imp in imports: 70 | if not isinstance(imp, str): 71 | raise TypeError( 72 | f'{imp} is of type {type(imp)} and cannot be imported.') 73 | try: 74 | imported_tmp = import_module(imp) 75 | except ImportError: 76 | if allow_failed_imports: 77 | warnings.warn(f'{imp} failed to import and is ignored.', 78 | UserWarning) 79 | imported_tmp = None 80 | else: 81 | raise ImportError 82 | imported.append(imported_tmp) 83 | if single_import: 84 | imported = imported[0] 85 | return imported 86 | 87 | 88 | class ConfigDict(Dict): 89 | def __init__(self, config_dict): 90 | """支持三种环境变量""" 91 | self.pattern_dict = {re.compile('{{ AppFolder }}'): os.environ["AppFolder"], 92 | re.compile('{{ WorkspaceFolder }}'):os.environ["WorkspaceFolder"], 93 | re.compile('{{ RootFolder }}'):os.environ["RootFolder"]} 94 | 95 | config_dict = self.identify_env_variables(config_dict) 96 | self.__delattr__("pattern_dict") 97 | super().__init__(config_dict) 98 | 99 | def identify_env_variables(self, config_dict): 100 | for k, v in config_dict.items(): 101 | if isinstance(v,dict): 102 | config_dict[k] = self.identify_env_variables(v) 103 | elif isinstance(v,str): 104 | for pattern , repl in super().__getattr__('pattern_dict').items(): 105 | new_v = re.sub(pattern, repl, v) 106 | if new_v != config_dict[k]: 107 | config_dict[k] = new_v 108 | break 109 | return config_dict 110 | 111 | 112 | 113 | 114 | def __missing__(self, name): 115 | raise KeyError(name) 116 | 117 | def __getattr__(self, name): 118 | try: 119 | value = super().__getattr__(name) 120 | """识别环境变量""" 121 | if isinstance(value, str): 122 | for pattern , repl in super().__getattr__('pattern_dict').items(): 123 | value = re.sub(pattern, repl, value) 124 | except KeyError: 125 | ex = AttributeError(f"'{self.__class__.__name__}' object has no " 126 | f"attribute '{name}'") 127 | except Exception as e: 128 | ex = e 129 | else: 130 | return value 131 | raise ex 132 | 133 | 134 | def add_args(parser, cfg, prefix=''): 135 | for k, v in cfg.items(): 136 | if isinstance(v, str): 137 | parser.add_argument('--' + prefix + k) 138 | elif isinstance(v, int): 139 | parser.add_argument('--' + prefix + k, type=int) 140 | elif isinstance(v, float): 141 | parser.add_argument('--' + prefix + k, type=float) 142 | elif isinstance(v, bool): 143 | parser.add_argument('--' + prefix + k, action='store_true') 144 | elif isinstance(v, dict): 145 | add_args(parser, v, prefix + k + '.') 146 | elif isinstance(v, abc.Iterable): 147 | parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') 148 | else: 149 | print(f'cannot parse key {prefix + k} of type {type(v)}') 150 | return parser 151 | 152 | 153 | class Config: 154 | """A facility for config and config files. 155 | 156 | It supports common file formats as configs: python/json/yaml. The interface 157 | is the same as a dict object and also allows access config values as 158 | attributes. 
159 | 160 | Example: 161 | >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) 162 | >>> cfg.a 163 | 1 164 | >>> cfg.b 165 | {'b1': [0, 1]} 166 | >>> cfg.b.b1 167 | [0, 1] 168 | >>> cfg = Config.fromfile('tests/data/config/a.py') 169 | >>> cfg.filename 170 | "/home/kchen/projects/mmcv/tests/data/config/a.py" 171 | >>> cfg.item4 172 | 'test' 173 | >>> cfg 174 | "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " 175 | "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" 176 | """ 177 | 178 | @staticmethod 179 | def _validate_py_syntax(filename): 180 | with open(filename, encoding='utf-8') as f: 181 | # Setting encoding explicitly to resolve coding issue on windows 182 | content = f.read() 183 | try: 184 | ast.parse(content) 185 | except SyntaxError as e: 186 | raise SyntaxError('There are syntax errors in config ' 187 | f'file {filename}: {e}') 188 | 189 | @staticmethod 190 | def _substitute_predefined_vars(filename, temp_config_name): 191 | file_dirname = osp.dirname(filename) 192 | file_basename = osp.basename(filename) 193 | file_basename_no_extension = osp.splitext(file_basename)[0] 194 | file_extname = osp.splitext(filename)[1] 195 | support_templates = dict( 196 | fileDirname=file_dirname, 197 | fileBasename=file_basename, 198 | fileBasenameNoExtension=file_basename_no_extension, 199 | fileExtname=file_extname) 200 | with open(filename, encoding='utf-8') as f: 201 | # Setting encoding explicitly to resolve coding issue on windows 202 | config_file = f.read() 203 | for key, value in support_templates.items(): 204 | regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' 205 | value = value.replace('\\', '/') 206 | config_file = re.sub(regexp, value, config_file) 207 | with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: 208 | tmp_config_file.write(config_file) 209 | 210 | @staticmethod 211 | def _pre_substitute_base_vars(filename, temp_config_name): 212 | """Substitute base variable placehoders to string, so that parsing 213 | would work.""" 214 | with open(filename, encoding='utf-8') as f: 215 | # Setting encoding explicitly to resolve coding issue on windows 216 | config_file = f.read() 217 | base_var_dict = {} 218 | regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' 219 | base_vars = set(re.findall(regexp, config_file)) 220 | for base_var in base_vars: 221 | randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' 222 | base_var_dict[randstr] = base_var 223 | regexp = r'\{\{\s*' + BASE_KEY + r'\.' 
+ base_var + r'\s*\}\}' 224 | config_file = re.sub(regexp, f'"{randstr}"', config_file) 225 | with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: 226 | tmp_config_file.write(config_file) 227 | return base_var_dict 228 | 229 | @staticmethod 230 | def _substitute_base_vars(cfg, base_var_dict, base_cfg): 231 | """Substitute variable strings to their actual values.""" 232 | cfg = copy.deepcopy(cfg) 233 | 234 | if isinstance(cfg, dict): 235 | for k, v in cfg.items(): 236 | if isinstance(v, str) and v in base_var_dict: 237 | new_v = base_cfg 238 | for new_k in base_var_dict[v].split('.'): 239 | new_v = new_v[new_k] 240 | cfg[k] = new_v 241 | elif isinstance(v, (list, tuple, dict)): 242 | cfg[k] = Config._substitute_base_vars( 243 | v, base_var_dict, base_cfg) 244 | elif isinstance(cfg, tuple): 245 | cfg = tuple( 246 | Config._substitute_base_vars(c, base_var_dict, base_cfg) 247 | for c in cfg) 248 | elif isinstance(cfg, list): 249 | cfg = [ 250 | Config._substitute_base_vars(c, base_var_dict, base_cfg) 251 | for c in cfg 252 | ] 253 | elif isinstance(cfg, str) and cfg in base_var_dict: 254 | new_v = base_cfg 255 | for new_k in base_var_dict[cfg].split('.'): 256 | new_v = new_v[new_k] 257 | cfg = new_v 258 | 259 | return cfg 260 | 261 | @staticmethod 262 | def _file2dict(filename, use_predefined_variables=True): 263 | filename = osp.abspath(osp.expanduser(filename)) 264 | check_file_exist(filename) 265 | fileExtname = osp.splitext(filename)[1] 266 | if fileExtname not in ['.py', '.json', '.yaml', '.yml']: 267 | raise OSError('Only py/yml/yaml/json type are supported now!') 268 | 269 | with tempfile.TemporaryDirectory() as temp_config_dir: 270 | temp_config_file = tempfile.NamedTemporaryFile( 271 | dir=temp_config_dir, suffix=fileExtname) 272 | if platform.system() == 'Windows': 273 | temp_config_file.close() 274 | temp_config_name = osp.basename(temp_config_file.name) 275 | # Substitute predefined variables 276 | if use_predefined_variables: 277 | Config._substitute_predefined_vars(filename, 278 | temp_config_file.name) 279 | else: 280 | shutil.copyfile(filename, temp_config_file.name) 281 | # Substitute base variables from placeholders to strings 282 | base_var_dict = Config._pre_substitute_base_vars( 283 | temp_config_file.name, temp_config_file.name) 284 | 285 | if filename.endswith('.py'): 286 | temp_module_name = osp.splitext(temp_config_name)[0] 287 | sys.path.insert(0, temp_config_dir) 288 | Config._validate_py_syntax(filename) 289 | mod = import_module(temp_module_name) 290 | sys.path.pop(0) 291 | cfg_dict = { 292 | name: value 293 | for name, value in mod.__dict__.items() 294 | if not name.startswith('__') 295 | and not isinstance(value, types.ModuleType) 296 | and not isinstance(value, types.FunctionType) 297 | } 298 | # delete imported module 299 | del sys.modules[temp_module_name] 300 | elif filename.endswith(('.yml', '.yaml', '.json')): 301 | import mmcv 302 | cfg_dict = mmcv.load(temp_config_file.name) 303 | # close temp file 304 | temp_config_file.close() 305 | 306 | # check deprecation information 307 | if DEPRECATION_KEY in cfg_dict: 308 | deprecation_info = cfg_dict.pop(DEPRECATION_KEY) 309 | warning_msg = f'The config file {filename} will be deprecated ' \ 310 | 'in the future.' 311 | if 'expected' in deprecation_info: 312 | warning_msg += f' Please use {deprecation_info["expected"]} ' \ 313 | 'instead.' 
314 | if 'reference' in deprecation_info: 315 | warning_msg += ' More information can be found at ' \ 316 | f'{deprecation_info["reference"]}' 317 | warnings.warn(warning_msg, DeprecationWarning) 318 | 319 | cfg_text = filename + '\n' 320 | with open(filename, encoding='utf-8') as f: 321 | # Setting encoding explicitly to resolve coding issue on windows 322 | cfg_text += f.read() 323 | 324 | if BASE_KEY in cfg_dict: 325 | cfg_dir = osp.dirname(filename) 326 | base_filename = cfg_dict.pop(BASE_KEY) 327 | base_filename = base_filename if isinstance( 328 | base_filename, list) else [base_filename] 329 | 330 | cfg_dict_list = list() 331 | cfg_text_list = list() 332 | for f in base_filename: 333 | _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) 334 | cfg_dict_list.append(_cfg_dict) 335 | cfg_text_list.append(_cfg_text) 336 | 337 | base_cfg_dict = dict() 338 | for c in cfg_dict_list: 339 | duplicate_keys = base_cfg_dict.keys() & c.keys() 340 | if len(duplicate_keys) > 0: 341 | raise KeyError('Duplicate key is not allowed among bases. ' 342 | f'Duplicate keys: {duplicate_keys}') 343 | base_cfg_dict.update(c) 344 | 345 | # Substitute base variables from strings to their actual values 346 | cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, 347 | base_cfg_dict) 348 | 349 | base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) 350 | cfg_dict = base_cfg_dict 351 | 352 | # merge cfg_text 353 | cfg_text_list.append(cfg_text) 354 | cfg_text = '\n'.join(cfg_text_list) 355 | 356 | return cfg_dict, cfg_text 357 | 358 | @staticmethod 359 | def _merge_a_into_b(a, b, allow_list_keys=False): 360 | """merge dict ``a`` into dict ``b`` (non-inplace). 361 | 362 | Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid 363 | in-place modifications. 364 | 365 | Args: 366 | a (dict): The source dict to be merged into ``b``. 367 | b (dict): The origin dict to be fetch keys from ``a``. 368 | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') 369 | are allowed in source ``a`` and will replace the element of the 370 | corresponding index in b if b is a list. Default: False. 371 | 372 | Returns: 373 | dict: The modified dict of ``b`` using ``a``. 374 | 375 | Examples: 376 | # Normally merge a into b. 377 | >>> Config._merge_a_into_b( 378 | ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) 379 | {'obj': {'a': 2}} 380 | 381 | # Delete b first and merge a into b. 382 | >>> Config._merge_a_into_b( 383 | ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) 384 | {'obj': {'a': 2}} 385 | 386 | # b is a list 387 | >>> Config._merge_a_into_b( 388 | ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) 389 | [{'a': 2}, {'b': 2}] 390 | """ 391 | b = b.copy() 392 | for k, v in a.items(): 393 | if allow_list_keys and k.isdigit() and isinstance(b, list): 394 | k = int(k) 395 | if len(b) <= k: 396 | raise KeyError(f'Index {k} exceeds the length of list {b}') 397 | b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) 398 | elif isinstance(v, dict): 399 | if k in b and not v.pop(DELETE_KEY, False): 400 | allowed_types = (dict, list) if allow_list_keys else dict 401 | if not isinstance(b[k], allowed_types): 402 | raise TypeError( 403 | f'{k}={v} in child config cannot inherit from ' 404 | f'base because {k} is a dict in the child config ' 405 | f'but is of type {type(b[k])} in base config. 
' 406 | f'You may set `{DELETE_KEY}=True` to ignore the ' 407 | f'base config.') 408 | b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) 409 | else: 410 | b[k] = ConfigDict(v) 411 | else: 412 | b[k] = v 413 | return b 414 | 415 | @staticmethod 416 | def fromfile(filename, 417 | use_predefined_variables=True, 418 | import_custom_modules=True): 419 | if isinstance(filename, Path): 420 | filename = str(filename) 421 | cfg_dict, cfg_text = Config._file2dict(filename, 422 | use_predefined_variables) 423 | if import_custom_modules and cfg_dict.get('custom_imports', None): 424 | import_modules_from_strings(**cfg_dict['custom_imports']) 425 | return Config(cfg_dict, cfg_text=cfg_text, filename=filename) 426 | 427 | @staticmethod 428 | def fromstring(cfg_str, file_format): 429 | """Generate config from config str. 430 | 431 | Args: 432 | cfg_str (str): Config str. 433 | file_format (str): Config file format corresponding to the 434 | config str. Only py/yml/yaml/json type are supported now! 435 | 436 | Returns: 437 | :obj:`Config`: Config obj. 438 | """ 439 | if file_format not in ['.py', '.json', '.yaml', '.yml']: 440 | raise OSError('Only py/yml/yaml/json type are supported now!') 441 | if file_format != '.py' and 'dict(' in cfg_str: 442 | # check if users specify a wrong suffix for python 443 | warnings.warn( 444 | 'Please check "file_format", the file format may be .py') 445 | with tempfile.NamedTemporaryFile( 446 | 'w', encoding='utf-8', suffix=file_format, 447 | delete=False) as temp_file: 448 | temp_file.write(cfg_str) 449 | # on windows, previous implementation cause error 450 | # see PR 1077 for details 451 | cfg = Config.fromfile(temp_file.name) 452 | os.remove(temp_file.name) 453 | return cfg 454 | 455 | @staticmethod 456 | def auto_argparser(description=None): 457 | """Generate argparser from config file automatically (experimental)""" 458 | partial_parser = ArgumentParser(description=description) 459 | partial_parser.add_argument('config', help='config file path') 460 | cfg_file = partial_parser.parse_known_args()[0].config 461 | cfg = Config.fromfile(cfg_file) 462 | parser = ArgumentParser(description=description) 463 | parser.add_argument('config', help='config file path') 464 | add_args(parser, cfg) 465 | return parser, cfg 466 | 467 | def __init__(self, cfg_dict=None, cfg_text=None, filename=None): 468 | if cfg_dict is None: 469 | cfg_dict = dict() 470 | elif not isinstance(cfg_dict, dict): 471 | raise TypeError('cfg_dict must be a dict, but ' 472 | f'got {type(cfg_dict)}') 473 | for key in cfg_dict: 474 | if key in RESERVED_KEYS: 475 | raise KeyError(f'{key} is reserved for config file') 476 | 477 | if isinstance(filename, Path): 478 | filename = str(filename) 479 | 480 | super().__setattr__('_cfg_dict', ConfigDict(cfg_dict)) 481 | super().__setattr__('_filename', filename) 482 | if cfg_text: 483 | text = cfg_text 484 | elif filename: 485 | with open(filename, encoding='utf-8') as f: 486 | text = f.read() 487 | else: 488 | text = '' 489 | super().__setattr__('_text', text) 490 | 491 | @property 492 | def filename(self): 493 | return self._filename 494 | 495 | @property 496 | def text(self): 497 | return self._text 498 | 499 | @property 500 | def pretty_text(self): 501 | 502 | indent = 4 503 | 504 | def _indent(s_, num_spaces): 505 | s = s_.split('\n') 506 | if len(s) == 1: 507 | return s_ 508 | first = s.pop(0) 509 | s = [(num_spaces * ' ') + line for line in s] 510 | s = '\n'.join(s) 511 | s = first + '\n' + s 512 | return s 513 | 514 | def _format_basic_types(k, v, 
use_mapping=False): 515 | if isinstance(v, str): 516 | v_str = f"'{v}'" 517 | else: 518 | v_str = str(v) 519 | 520 | if use_mapping: 521 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 522 | attr_str = f'{k_str}: {v_str}' 523 | else: 524 | attr_str = f'{str(k)}={v_str}' 525 | attr_str = _indent(attr_str, indent) 526 | 527 | return attr_str 528 | 529 | def _format_list(k, v, use_mapping=False): 530 | # check if all items in the list are dict 531 | if all(isinstance(_, dict) for _ in v): 532 | v_str = '[\n' 533 | v_str += '\n'.join( 534 | f'dict({_indent(_format_dict(v_), indent)}),' 535 | for v_ in v).rstrip(',') 536 | if use_mapping: 537 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 538 | attr_str = f'{k_str}: {v_str}' 539 | else: 540 | attr_str = f'{str(k)}={v_str}' 541 | attr_str = _indent(attr_str, indent) + ']' 542 | else: 543 | attr_str = _format_basic_types(k, v, use_mapping) 544 | return attr_str 545 | 546 | def _contain_invalid_identifier(dict_str): 547 | contain_invalid_identifier = False 548 | for key_name in dict_str: 549 | contain_invalid_identifier |= \ 550 | (not str(key_name).isidentifier()) 551 | return contain_invalid_identifier 552 | 553 | def _format_dict(input_dict, outest_level=False): 554 | r = '' 555 | s = [] 556 | 557 | use_mapping = _contain_invalid_identifier(input_dict) 558 | if use_mapping: 559 | r += '{' 560 | for idx, (k, v) in enumerate(input_dict.items()): 561 | is_last = idx >= len(input_dict) - 1 562 | end = '' if outest_level or is_last else ',' 563 | if isinstance(v, dict): 564 | v_str = '\n' + _format_dict(v) 565 | if use_mapping: 566 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 567 | attr_str = f'{k_str}: dict({v_str}' 568 | else: 569 | attr_str = f'{str(k)}=dict({v_str}' 570 | attr_str = _indent(attr_str, indent) + ')' + end 571 | elif isinstance(v, list): 572 | attr_str = _format_list(k, v, use_mapping) + end 573 | else: 574 | attr_str = _format_basic_types(k, v, use_mapping) + end 575 | 576 | s.append(attr_str) 577 | r += '\n'.join(s) 578 | if use_mapping: 579 | r += '}' 580 | return r 581 | 582 | cfg_dict = self._cfg_dict.to_dict() 583 | text = _format_dict(cfg_dict, outest_level=True) 584 | # copied from setup.cfg 585 | yapf_style = dict( 586 | based_on_style='pep8', 587 | blank_line_before_nested_class_or_def=True, 588 | split_before_expression_after_opening_paren=True) 589 | text, _ = FormatCode(text, style_config=yapf_style, verify=True) 590 | 591 | return text 592 | 593 | def __repr__(self): 594 | return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' 595 | 596 | def __len__(self): 597 | return len(self._cfg_dict) 598 | 599 | def __getattr__(self, name): 600 | return getattr(self._cfg_dict, name) 601 | 602 | def __getitem__(self, name): 603 | return self._cfg_dict.__getitem__(name) 604 | 605 | def __setattr__(self, name, value): 606 | if isinstance(value, dict): 607 | value = ConfigDict(value) 608 | self._cfg_dict.__setattr__(name, value) 609 | 610 | def __setitem__(self, name, value): 611 | if isinstance(value, dict): 612 | value = ConfigDict(value) 613 | self._cfg_dict.__setitem__(name, value) 614 | 615 | def __iter__(self): 616 | return iter(self._cfg_dict) 617 | 618 | def __getstate__(self): 619 | return (self._cfg_dict, self._filename, self._text) 620 | 621 | def __copy__(self): 622 | cls = self.__class__ 623 | other = cls.__new__(cls) 624 | other.__dict__.update(self.__dict__) 625 | 626 | return other 627 | 628 | def __deepcopy__(self, memo): 629 | cls = self.__class__ 630 | other = cls.__new__(cls) 631 
| memo[id(self)] = other 632 | 633 | for key, value in self.__dict__.items(): 634 | super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) 635 | 636 | return other 637 | 638 | def __setstate__(self, state): 639 | _cfg_dict, _filename, _text = state 640 | super().__setattr__('_cfg_dict', _cfg_dict) 641 | super().__setattr__('_filename', _filename) 642 | super().__setattr__('_text', _text) 643 | 644 | def dump(self, file=None): 645 | """Dumps config into a file or returns a string representation of the 646 | config. 647 | 648 | If a file argument is given, saves the config to that file using the 649 | format defined by the file argument extension. 650 | 651 | Otherwise, returns a string representing the config. The formatting of 652 | this returned string is defined by the extension of `self.filename`. If 653 | `self.filename` is not defined, returns a string representation of a 654 | dict (lowercased and using ' for strings). 655 | 656 | Examples: 657 | >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0), 658 | ... item3=True, item4='test') 659 | >>> cfg = Config(cfg_dict=cfg_dict) 660 | >>> dump_file = "a.py" 661 | >>> cfg.dump(dump_file) 662 | 663 | Args: 664 | file (str, optional): Path of the output file where the config 665 | will be dumped. Defaults to None. 666 | """ 667 | import mmcv 668 | cfg_dict = super().__getattribute__('_cfg_dict').to_dict() 669 | if file is None: 670 | if self.filename is None or self.filename.endswith('.py'): 671 | return self.pretty_text 672 | else: 673 | file_format = self.filename.split('.')[-1] 674 | return mmcv.dump(cfg_dict, file_format=file_format) 675 | elif file.endswith('.py'): 676 | with open(file, 'w', encoding='utf-8') as f: 677 | f.write(self.pretty_text) 678 | else: 679 | file_format = file.split('.')[-1] 680 | return mmcv.dump(cfg_dict, file=file, file_format=file_format) 681 | 682 | def merge_from_dict(self, options, allow_list_keys=True): 683 | """Merge list into cfg_dict. 684 | 685 | Merge the dict parsed by MultipleKVAction into this cfg. 686 | 687 | Examples: 688 | >>> options = {'model.backbone.depth': 50, 689 | ... 'model.backbone.with_cp':True} 690 | >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) 691 | >>> cfg.merge_from_dict(options) 692 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 693 | >>> assert cfg_dict == dict( 694 | ... model=dict(backbone=dict(depth=50, with_cp=True))) 695 | 696 | >>> # Merge list element 697 | >>> cfg = Config(dict(pipeline=[ 698 | ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) 699 | >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) 700 | >>> cfg.merge_from_dict(options, allow_list_keys=True) 701 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 702 | >>> assert cfg_dict == dict(pipeline=[ 703 | ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) 704 | 705 | Args: 706 | options (dict): dict of configs to merge from. 707 | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') 708 | are allowed in ``options`` and will replace the element of the 709 | corresponding index in the config if the config is a list. 710 | Default: True. 
711 | """ 712 | option_cfg_dict = {} 713 | for full_key, v in options.items(): 714 | d = option_cfg_dict 715 | key_list = full_key.split('.') 716 | for subkey in key_list[:-1]: 717 | d.setdefault(subkey, ConfigDict()) 718 | d = d[subkey] 719 | subkey = key_list[-1] 720 | d[subkey] = v 721 | 722 | cfg_dict = super().__getattribute__('_cfg_dict') 723 | super().__setattr__( 724 | '_cfg_dict', 725 | Config._merge_a_into_b( 726 | option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) 727 | 728 | 729 | class DictAction(Action): 730 | """ 731 | argparse action to split an argument into KEY=VALUE form 732 | on the first = and append to a dictionary. List options can 733 | be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit 734 | brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build 735 | list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' 736 | """ 737 | 738 | @staticmethod 739 | def _parse_int_float_bool(val): 740 | try: 741 | return int(val) 742 | except ValueError: 743 | pass 744 | try: 745 | return float(val) 746 | except ValueError: 747 | pass 748 | if val.lower() in ['true', 'false']: 749 | return True if val.lower() == 'true' else False 750 | if val == 'None': 751 | return None 752 | return val 753 | 754 | @staticmethod 755 | def _parse_iterable(val): 756 | """Parse iterable values in the string. 757 | 758 | All elements inside '()' or '[]' are treated as iterable values. 759 | 760 | Args: 761 | val (str): Value string. 762 | 763 | Returns: 764 | list | tuple: The expanded list or tuple from the string. 765 | 766 | Examples: 767 | >>> DictAction._parse_iterable('1,2,3') 768 | [1, 2, 3] 769 | >>> DictAction._parse_iterable('[a, b, c]') 770 | ['a', 'b', 'c'] 771 | >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') 772 | [(1, 2, 3), ['a', 'b'], 'c'] 773 | """ 774 | 775 | def find_next_comma(string): 776 | """Find the position of next comma in the string. 777 | 778 | If no ',' is found in the string, return the string length. All 779 | chars inside '()' and '[]' are treated as one element and thus ',' 780 | inside these brackets are ignored. 781 | """ 782 | assert (string.count('(') == string.count(')')) and ( 783 | string.count('[') == string.count(']')), \ 784 | f'Imbalanced brackets exist in {string}' 785 | end = len(string) 786 | for idx, char in enumerate(string): 787 | pre = string[:idx] 788 | # The string before this ',' is balanced 789 | if ((char == ',') and (pre.count('(') == pre.count(')')) 790 | and (pre.count('[') == pre.count(']'))): 791 | end = idx 792 | break 793 | return end 794 | 795 | # Strip ' and " characters and replace whitespace. 
796 |         val = val.strip('\'\"').replace(' ', '')
797 |         is_tuple = False
798 |         if val.startswith('(') and val.endswith(')'):
799 |             is_tuple = True
800 |             val = val[1:-1]
801 |         elif val.startswith('[') and val.endswith(']'):
802 |             val = val[1:-1]
803 |         elif ',' not in val:
804 |             # val is a single value
805 |             return DictAction._parse_int_float_bool(val)
806 | 
807 |         values = []
808 |         while len(val) > 0:
809 |             comma_idx = find_next_comma(val)
810 |             element = DictAction._parse_iterable(val[:comma_idx])
811 |             values.append(element)
812 |             val = val[comma_idx + 1:]
813 |         if is_tuple:
814 |             values = tuple(values)
815 |         return values
816 | 
817 |     def __call__(self, parser, namespace, values, option_string=None):
818 |         options = {}
819 |         for kv in values:
820 |             key, val = kv.split('=', maxsplit=1)
821 |             options[key] = self._parse_iterable(val)
822 |         setattr(namespace, self.dest, options)
823 | 
-------------------------------------------------------------------------------- /pyinfer/utils/common/logger.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import logging
3 | from logging.handlers import RotatingFileHandler
4 | import colorlog
5 | 
6 | __all__ = ["Logger"]
7 | 
8 | 
9 | class Logger():
10 |     """Logger."""
11 |     instance = None
12 | 
13 |     def __new__(cls, filename=None, level="INFO"):
14 |         """
15 |         The logger is designed as a singleton.
16 | 
17 |         Note:
18 |             getLogger returns the same object for the same filename, so addHandler must be called only once per logger object, otherwise records are printed repeatedly; the singleton guarantees this.
19 | 
20 |         """
21 | 
22 |         level_dict = {"INFO": logging.INFO, "DEBUG": logging.DEBUG, "FATAL": logging.FATAL, "ERROR": logging.ERROR}
23 | 
24 |         if cls.instance is None:
25 |             cls.instance = super().__new__(cls)  # a bare instance that has not been initialized yet
26 |             cls.instance._logger = cls.getLogger(filename, level_dict.get(level, logging.INFO))
27 |         return cls.instance
28 | 
29 |     @classmethod
30 |     def getLogger(cls, filename, level):
31 |         logger = logging.getLogger(filename)
32 | 
33 |         log_colors_config = {
34 |             'DEBUG': 'cyan',
35 |             'INFO': 'green',
36 |             'WARNING': 'yellow',
37 |             'ERROR': 'red',
38 |             'FATAL': 'red',
39 |         }
40 | 
41 |         formatter = colorlog.ColoredFormatter(
42 |             '%(log_color)s [%(levelname)s] %(asctime)s %(filename)s[line:%(lineno)d] : %(message)s',
43 |             log_colors=log_colors_config)
44 | 
45 |         # set the log level
46 |         logger.setLevel(level)
47 |         # output to the console
48 |         console_handler = logging.StreamHandler()
49 |         # format for console output
50 |         console_handler.setFormatter(formatter)
51 |         # attach the handler to the logger
52 |         logger.addHandler(console_handler)
53 | 
54 |         # output to a file
55 |         if filename is not None:
56 |             file_handler = RotatingFileHandler(filename=filename, mode='a', maxBytes=1 * 1024 * 1024, encoding='utf8')
57 |             file_formatter = logging.Formatter(
58 |                 '[%(levelname)s] %(asctime)s %(filename)s[line:%(lineno)d]: %(message)s')
59 |             file_handler.setFormatter(file_formatter)
60 |             logger.addHandler(file_handler)
61 |         return logger
62 | 
63 |     def warning(self, msg):
64 |         self._logger.warning(msg)
65 | 
66 |     def info(self, msg):
67 |         self._logger.info(msg)
68 | 
69 |     def debug(self, msg):
70 |         self._logger.debug(msg)
71 | 
72 |     def error(self, msg):
73 |         self._logger.error(msg)
74 | 
75 |     def fatal(self, msg):
76 |         self._logger.fatal(msg)
77 |         raise RuntimeError(msg)
-------------------------------------------------------------------------------- /pyinfer/utils/common/registry.py: --------------------------------------------------------------------------------
1 | __all__ = ["INFERS", "ENGINES", "HOOKS"]
2 | 
3 | 
4 | class Register():
5 | 
6 |     def __init__(self, name=None) -> None:
7 |         self.module_dict = {}
8 | 
9 |     def register_module(self):
10 | 
11 |         def _register(module):
12 |             self.module_dict[module.__name__] = module
13 |             return module
14 | 
15 |         return _register
16 | 
17 |     def build(self, cfg):
18 |         module_name = cfg.pop('type')
19 |         return self.get(module_name)(**dict(cfg))
20 | 
21 |     def get(self, module_name):
22 |         return self.module_dict.get(module_name)
23 | 
24 | 
25 | INFERS = Register("INFERS")
26 | ENGINES = Register("ENGINES")
27 | HOOKS = Register("HOOKS")
28 | 
-------------------------------------------------------------------------------- /pyinfer/utils/detection/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/detection/__init__.py
-------------------------------------------------------------------------------- /pyinfer/utils/detection/bbox.py: --------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | import numpy as np
3 | from typing import Union
4 | 
5 | 
6 | class QuadrangleBBox(BaseModel):
7 | 
8 |     x1: float = 0
9 |     y1: float = 0
10 |     x2: float = 0
11 |     y2: float = 0
12 |     x3: float = 0
13 |     y3: float = 0
14 |     x4: float = 0
15 |     y4: float = 0
16 |     area: float = 0
17 |     confidence: float = 0
18 |     label: int = 0
19 |     labelname: str = ""
20 |     keepflag: bool = True
21 | 
22 |     @property
23 |     def coords(self):
24 |         """All eight coordinate values."""
25 |         return [self.x1, self.y1, self.x2, self.y2, self.x3, self.y3, self.x4, self.y4]
26 | 
27 |     def set_coords(self, coords):
28 |         """Set all eight coordinate values."""
29 |         assert len(coords) == 8
30 |         self.x1, self.y1, self.x2, self.y2, self.x3, self.y3, self.x4, self.y4 = coords
31 |         return self
32 | 
33 |     @property
34 |     def xcoords(self):
35 |         """The four x coordinate values."""
36 |         return [self.x1, self.x2, self.x3, self.x4]
37 | 
38 |     def set_xcoords(self, xcoords):
39 |         assert len(xcoords) == 4
40 |         self.x1, self.x2, self.x3, self.x4 = xcoords
41 |         return self
42 | 
43 |     @property
44 |     def ycoords(self):
45 |         """The four y coordinate values."""
46 |         return [self.y1, self.y2, self.y3, self.y4]
47 | 
48 |     def set_ycoords(self, ycoords):
49 |         assert len(ycoords) == 4
50 |         self.y1, self.y2, self.y3, self.y4 = ycoords
51 |         return self
52 | 
53 |     @property
54 |     def width(self):
55 |         return self.right - self.left
56 | 
57 |     @property
58 |     def height(self):
59 |         return self.bottom - self.top
60 | 
61 |     @property
62 |     def left(self):
63 |         return min(self.x1, self.x2, self.x3, self.x4)
64 | 
65 |     @property
66 |     def right(self):
67 |         return max(self.x1, self.x2, self.x3, self.x4)
68 | 
69 |     @property
70 |     def top(self):
71 |         return min(self.y1, self.y2, self.y3, self.y4)
72 | 
73 |     @property
74 |     def bottom(self):
75 |         return max(self.y1, self.y2, self.y3, self.y4)
76 | 
77 |     def reset_origin(self, ox, oy):
78 |         """Shift the coordinate origin to (ox, oy)."""
79 |         self.xcoords = [self.x1 - ox, self.x2 - ox, self.x3 - ox, self.x4 - ox]
80 |         self.ycoords = [self.y1 - oy, self.y2 - oy, self.y3 - oy, self.y4 - oy]
81 |         return self
82 | 
83 |     def clip(self, xmin, xmax, ymin, ymax):
84 |         self.xcoords = list(map(lambda x: min(max(x, xmin), xmax), self.xcoords))
85 |         self.ycoords = list(map(lambda y: min(max(y, ymin), ymax), self.ycoords))
86 |         return self
87 | 
88 | 
89 |     @property
90 |     def level(self):
91 |         return len(set(self.ycoords)) == len(set(self.xcoords)) == 2
92 | 
93 |     def __mul__(self, scalar: Union[int, float]):
94 |         """Multiply all coordinates by a scalar."""
95 |         self.coords = list(np.array(self.coords) * scalar)
96 |         return self
97 | 
98 |     def __setattr__(self, key, val):
99 |         # @coords.setter is not supported on pydantic BaseModel, so __setattr__ is overridden instead
100 |         method = self.__config__.property_set_methods.get(key)
101 |         if method is None:
102 |             super().__setattr__(key, val)
103 |         else:
104 |             getattr(self, method)(val)
105 | 
106 |     class Config:
107 |         property_set_methods = {"coords": "set_coords", "xcoords": "set_xcoords", "ycoords": "set_ycoords"}
108 | 
-------------------------------------------------------------------------------- /pyinfer/utils/detection/nms.py: --------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | import shapely.geometry as shgeo
4 | from .bbox import QuadrangleBBox
5 | 
6 | 
7 | def cpu_nms(bboxes: List[QuadrangleBBox], nms_threshold) -> List[QuadrangleBBox]:
8 |     def box_iou(a: QuadrangleBBox, b: QuadrangleBBox):
9 |         a_poly = shgeo.Polygon(np.array(a.coords).reshape(-1, 2).tolist())
10 |         b_poly = shgeo.Polygon(np.array(b.coords).reshape(-1, 2).tolist())
11 |         inter_poly = a_poly.intersection(b_poly)
12 |         if inter_poly.area == 0:
13 |             return 0
14 |         return inter_poly.area / (a_poly.area + b_poly.area - inter_poly.area)
15 | 
16 |     for i in range(len(bboxes)):
17 |         for j in range(len(bboxes)):
18 |             # skip nms for the same bbox, or for bboxes with different labels
19 |             if i == j or bboxes[i].label != bboxes[j].label:
20 |                 continue
21 | 
22 |             # on equal confidence keep the later bbox, i.e. when j < i, i survives this pair
23 |             if bboxes[j].confidence == bboxes[i].confidence and j < i:
24 |                 continue
25 | 
26 |             if bboxes[j].confidence >= bboxes[i].confidence:
27 |                 iou = box_iou(bboxes[i], bboxes[j])
28 |                 if (iou > nms_threshold):
29 |                     bboxes[i].keepflag = 0
30 | 
31 |     return list(filter(lambda bbox: bbox.keepflag, bboxes))
-------------------------------------------------------------------------------- /pyinfer/utils/functional/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/functional/__init__.py
-------------------------------------------------------------------------------- /pyinfer/utils/functional/coco.py: --------------------------------------------------------------------------------
1 | import json
2 | from pycocotools.coco import COCO
3 | 
4 | 
5 | class COCOCreator():
6 |     def __init__(self) -> None:
7 |         self.image_id_to_index = {}
8 |         self.anno_id_to_index = {}
9 |         self.cat_id_to_index = {}
10 |         self.dataset = {"images": [], "annotations": [], "categories": []}
11 | 
12 |     def create_image_info(self, image_id, file_name, height, width, update=False, **kwargs):
13 |         image_info = {"id": image_id, "file_name": file_name, "width": width, "height": height, **kwargs}
14 |         if update:
15 |             self.update_image_info(image_info)
16 |         return image_info
17 | 
18 |     def create_anno_info(self,
19 |                          image_id,
20 |                          anno_id,
21 |                          category_id,
22 |                          bbox,
23 |                          area,
24 |                          segmentation=[],
25 |                          iscrowd=0,
26 |                          update=False,
27 |                          **kwargs):
28 |         anno_info = {
29 |             "image_id": image_id,
30 |             "id": anno_id,
31 |             "category_id": category_id,
32 |             "bbox": bbox,
33 |             "area": area,
34 |             "segmentation": segmentation,
35 |             "iscrowd": iscrowd,
36 |             **kwargs
37 |         }
38 |         if update:
39 |             self.update_anno_info(anno_info)
40 |         return anno_info
41 | 
42 |     def create_cat_info(self, cat_id, cat_name, update=False, **kwargs):
43 |         cat_info = {"id": cat_id, "name": cat_name, **kwargs}
44 |         if update:
45 |             self.update_cat_info(cat_info)
46 |         return cat_info
47 | 
48 |     def update_image_info(self, image_info):
49 |         image_id = image_info["id"]
50 |         if image_id in self.image_id_to_index:
51 |             self.dataset["images"][self.image_id_to_index[image_id]].update(image_info)
52 |         else:
53 |             self.image_id_to_index[image_id] = len(self.dataset["images"])
54 |             self.dataset["images"].append(image_info)
55 | 
56 |     def update_anno_info(self, anno_info):
57 |         anno_id = anno_info["id"]
58 |         if anno_id in self.anno_id_to_index:
59 |             self.dataset["annotations"][self.anno_id_to_index[anno_id]].update(anno_info)
60 |         else:
61 |             self.anno_id_to_index[anno_id] = len(self.dataset["annotations"])
62 |             self.dataset["annotations"].append(anno_info)
63 | 
64 |     def update_cat_info(self, cat_info):
65 |         cat_id = cat_info["id"]
66 |         if cat_id in self.cat_id_to_index:
67 |             self.dataset["categories"][self.cat_id_to_index[cat_id]].update(cat_info)
68 |         else:
69 |             self.cat_id_to_index[cat_id] = len(self.dataset["categories"])
70 |             self.dataset["categories"].append(cat_info)
71 | 
72 |     def build(self):
73 |         coco = COCO()
74 |         coco.dataset = self.dataset
75 |         coco.createIndex()
76 |         return coco
77 | 
78 |     def write(self, path):
79 |         with open(path, 'w') as f:
80 |             json.dump(self.dataset, f, indent=4, ensure_ascii=False)
81 | 
-------------------------------------------------------------------------------- /pyinfer/utils/functional/slice.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import copy
3 | import math
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | from PIL import Image
8 | import shapely.geometry as shgeo
9 | from typing import List
10 | from pycocotools.coco import COCO
11 | 
12 | from ..detection.bbox import QuadrangleBBox
13 | from .coco import COCOCreator
14 | 
15 | __all__ = ["slice_one_image", "slice_coco_dataset"]
16 | 
17 | 
18 | def reduce_coords(coords):
19 |     """将最短边的两个点p1,p2,用p1,p2中点p替换,从而使多边形减少一个顶点"""
20 |     def get_distance(point1, point2):
21 |         return math.sqrt(math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1], 2))
22 | 
23 |     coords.append(coords[0])
24 |     distances = [get_distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)]
25 | 
26 |     min_index = np.array(distances).argsort()[0]
27 |     index = 0
28 |     out_coords = []
29 |     while index < len(coords):
30 |         if (index == min_index):  # 取中点
31 |             middle_x = (coords[index][0] + coords[index + 1][0]) / 2
32 |             middle_y = (coords[index][1] + coords[index + 1][1]) / 2
33 |             out_coords.append((middle_x, middle_y))
34 |         elif (index != (min_index + 1)):  # 非最短边点保留
35 |             out_coords.append((coords[index][0], coords[index][1]))
36 |         index += 1
37 |     return out_coords
38 | 
39 | 
40 | def choose_best_pointorder_fit_another(coords, gt_coords):
41 |     """
42 |     :params coords: [p1,p2,p3,p4], 四边形坐标点
43 |     :params gt_coords: [p1,p2,p3,p4], 标签四边形坐标点
44 | 
45 |     选择最匹配gt_coords的坐标排布顺序,coords在最优排布顺序下,与gt_coords逐点点间距之和最小;
46 |     """
47 |     x1, gt_x1 = coords[0][0], gt_coords[0][0]
48 |     y1, gt_y1 = coords[0][1], gt_coords[0][1]
49 |     x2, gt_x2 = coords[1][0], gt_coords[1][0]
50 |     y2, gt_y2 = coords[1][1], gt_coords[1][1]
51 |     x3, gt_x3 = coords[2][0], gt_coords[2][0]
52 |     y3, gt_y3 = coords[2][1], gt_coords[2][1]
53 |     x4, gt_x4 = coords[3][0], gt_coords[3][0]
54 |     y4, gt_y4 = coords[3][1], gt_coords[3][1]
55 | 
56 |     combinate = [
57 |         np.array([x1, y1, x2, y2, x3, y3, x4, y4]),
58 |         np.array([x2, y2, x3, y3, x4, y4, x1, y1]),
59 |         np.array([x3, y3, x4, y4, x1, y1, x2, y2]),
60 |         np.array([x4, y4, x1, y1, x2, y2, x3, y3])
61 |     ]
62 |     gt = np.array([gt_x1, gt_y1, gt_x2, gt_y2, gt_x3, gt_y3, gt_x4, gt_y4])
63 |     distances = np.array([np.sum((coord - gt)**2) for coord in combinate])
64 |     sorted = distances.argsort()
65 |     best_coords = combinate[sorted[0]].reshape(-1, 2).tolist()
66 |     return best_coords
67 | 
68 | 
69 | def assign_bboxes(pleft, ptop, pright, pbottom,
gt_bboxes: List[QuadrangleBBox], threshold): 70 | """ 71 | 分配图像bounding box,过滤掉坐标不在当前图像内的bbox 72 | 73 | Args: 74 | pleft, ptop, pright, pbottom: 图像左上角和右下角坐标值 75 | bboxes (List[QuadrangleBBox], optional): 图像上目标的bounding box. Defaults to None. 76 | threshold (float, optional): 目标bbox和瓦片重叠IOU阈值,大于阈值将bbox分配给该瓦片,否则丢弃. Defaults to 0.5. 77 | 78 | Returns: 79 | patches: List[bboxes] 80 | 81 | 基本步骤: 82 | 83 | 1. 遍历一个gt_bbox; 84 | 2. 计算bbox和当前图像的iou; 85 | 3. iou=1, bbox完全在图像内,保留,跳至第一步; 86 | 4. iou>threshold, bbox部分在图像内,重叠多边形顶点数为n 87 | ① n<4, bbox只有一个顶点在图像内,bbox不匹配; 88 | ② n=5,将5边形变换为4边形,生成新的new_bbox,匹配; 89 | ③ n=6,将6边形变换为4边形,生成新的new_bbox,匹配; 90 | ④ n>6,实际使用中,该情况较少发生,不匹配; 91 | 5. 调整new_bbox顶点顺序,即new_bbox以第p个顶点作为起始点,该顶点排布顺序,与gt_bbox对应顶点距离和最小, 92 | 6. new_bbox坐标值是相对原点(0,0)的值,以图像左上角点为原点,更新坐标值; 93 | 7. 重复第一步; 94 | """ 95 | image_poly = shgeo.Polygon([(pleft, ptop), (pright, ptop), (pright, pbottom), (pleft, pbottom)]) 96 | 97 | remain_bboxes = [] 98 | for gt_bbox in gt_bboxes: 99 | x1, y1, x2, y2, x3, y3, x4, y4 = gt_bbox.coords 100 | gt_coords = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] 101 | gt_poly = shgeo.Polygon(gt_coords) 102 | # 错误的bbox 103 | if (gt_poly.area <= 0): 104 | continue 105 | 106 | inter_poly = gt_poly.intersection(image_poly) # 重叠区域 107 | iou = inter_poly.area / gt_poly.area # 非常规IOU,此IOU时交集与gt_box的比值 108 | 109 | # 若bbox在图像内部,则保留,同时以图像的左上角点作为bbox原点 110 | if (iou == 1): 111 | bbox = copy.deepcopy(gt_bbox) # 更新原点,不要在原始bbox上直接修改 112 | remain_bboxes.append(bbox.reset_origin(pleft, ptop)) 113 | 114 | # 若bbox部分在图像内部 115 | elif iou > threshold: 116 | # 有序排布多边形点; 117 | inter_poly = shgeo.polygon.orient(inter_poly, sign=1) 118 | # 重叠区域多边形顶点,由于收尾相接,第一个点和最后一个点重复,因此去除最后一个点; 119 | # inter_coords: [(x1, y1), (x2, y2), (x3, y3), (x4, y4), ..., (x1, y1)] 120 | inter_coords = list(inter_poly.exterior.coords)[0:-1] 121 | 122 | # 两个矩形重叠区域较小时,重叠区可能由三个点组成,重叠较小,bbox不保留 123 | if len(inter_coords) < 4: 124 | continue 125 | 126 | elif (len(inter_coords) > 6): 127 | """两个矩形重叠区域存在6个以上顶点,在实际情况中极少出现,因此此类bbox做丢弃处理""" 128 | continue 129 | 130 | elif (len(inter_coords) == 6): 131 | """重叠区域6个顶点时,合并两个最短边,减少顶点数至4""" 132 | inter_coords = reduce_coords(reduce_coords(inter_coords)) 133 | 134 | elif (len(inter_coords) == 5): 135 | """重叠区域5个顶点时,合并一个最短边,减少顶点数至4""" 136 | inter_coords = reduce_coords(inter_coords) 137 | 138 | # elif (len(inter_coords)==4): 139 | # 4个顶点不作处理 140 | 141 | best_coords = choose_best_pointorder_fit_another(inter_coords, gt_coords) # 最优顶点顺序 142 | best_coords_area = shgeo.Polygon(best_coords).area 143 | # 限制坐标范围, 并更新原点 144 | best_coords = np.array(best_coords).flatten() 145 | new_bbox = QuadrangleBBox(label=gt_bbox.label, area=best_coords_area) 146 | new_bbox.coords = best_coords 147 | new_bbox.clip(xmin=pleft, xmax=pright, ymin=ptop, ymax=pbottom) 148 | new_bbox.reset_origin(ox=pleft, oy=ptop) 149 | remain_bboxes.append(new_bbox) 150 | return remain_bboxes 151 | 152 | 153 | def slice_patch(image, left, top, subsize, padding): 154 | no_padding_patch = copy.deepcopy(image[top:(top + subsize), left:(left + subsize)]) 155 | h, w, c = no_padding_patch.shape 156 | if (padding): 157 | patch_image = np.ones((subsize, subsize, c)) * 114 158 | patch_image[0:h, 0:w, :] = no_padding_patch 159 | else: 160 | patch_image = no_padding_patch 161 | return patch_image.astype(np.uint8) 162 | 163 | 164 | def slice_one_image(image: np.ndarray, 165 | image_base_name, 166 | subsize, 167 | rate, 168 | gap, 169 | threshold=0.5, 170 | padding=True, 171 | bboxes: List[QuadrangleBBox] = None): 172 | """ 173 | 拆分一张图像 174 | 175 | 
Args: 176 | image (np.ndarray): 图像 177 | subsize (int): 瓦片大小 178 | rate (int): 图像resize倍数, rate=1表示不对图像做resize处理 179 | gap (int): 瓦片重叠大小 180 | bboxes (List[QuadrangleBBox], optional): 图像上目标的bounding box. Defaults to None. 181 | threshold (float, optional): 目标bbox和瓦片重叠IOU阈值,大于阈值将bbox分配给该瓦片,否则丢弃. Defaults to 0.5. 182 | padding (bool, optional): 瓦片大小不足subsize,是否padding. Defaults to True. 183 | name (str, optional): 图像名称. Defaults to "". 184 | 185 | Returns: 186 | patches: List[Tuple[patch_name, patch_image, bboxes]] 187 | 188 | 基本步骤: 189 | 190 | 1. 选取瓦片图像范围 191 | 2. 从image切分瓦片patch image 192 | 3. 若image有标签,将标签bboxes与瓦片进行匹配,匹配成功的部分bboxes作为patch bboxes 193 | 4. 滑动,重复步骤1~3 194 | """ 195 | assert np.shape(image) != () 196 | 197 | patches = [] 198 | if (rate != 1): 199 | if bboxes is not None: 200 | bboxes = list(map(lambda bbox: bbox * rate, bboxes)) 201 | resizeimg = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC) 202 | else: 203 | resizeimg = image 204 | 205 | base_name = image_base_name + '__' + str(rate) + '__' 206 | 207 | width = np.shape(resizeimg)[1] 208 | height = np.shape(resizeimg)[0] 209 | 210 | slide = subsize - gap 211 | left, top = 0, 0 212 | while (left < width): 213 | if (left + subsize >= width): 214 | left = max(width - subsize, 0) 215 | top = 0 216 | while (top < height): 217 | if (top + subsize >= height): 218 | top = max(height - subsize, 0) 219 | patch_name = base_name + str(left) + '___' + str(top) 220 | right = min(left + subsize - 1, width - 1) 221 | bottom = min(top + subsize - 1, height - 1) 222 | 223 | # 获取一张瓦片图 224 | patch_image = slice_patch(resizeimg, left, top, subsize, padding) 225 | 226 | # 获取瓦片图内的bbox 227 | if bboxes is None: 228 | patch_bboxes = [] 229 | else: 230 | patch_bboxes = assign_bboxes(left, top, right, bottom, bboxes, threshold) 231 | patch = (patch_name, rate, left, top, patch_image, patch_bboxes) 232 | patches.append(patch) 233 | 234 | # 滑动 235 | if (top + subsize >= height): 236 | break 237 | else: 238 | top = top + slide 239 | if (left + subsize >= width): 240 | break 241 | else: 242 | left = left + slide 243 | return patches 244 | 245 | 246 | def slice_coco_dataset(coco_image_src, 247 | coco_image_dst, 248 | coco_json_src, 249 | coco_json_dst, 250 | rate, 251 | subsize, 252 | gap, 253 | padding=True, 254 | threshold=0.5): 255 | coco = COCO(coco_json_src) 256 | coco_creator = COCOCreator() 257 | 258 | image_cnt = 0 259 | anno_cnt = 0 260 | for parent_image_id in tqdm(coco.getImgIds(), desc="slice"): 261 | # 1.读取图像 262 | image_info = coco.loadImgs(ids=parent_image_id)[0] 263 | image = np.array(Image.open(os.path.join(coco_image_src, image_info['file_name']))) 264 | image_base_name, extension = os.path.splitext(image_info['file_name']) 265 | 266 | # 2.获取图像gt_bboxes 267 | anno_infos = coco.loadAnns(coco.getAnnIds(imgIds=parent_image_id)) 268 | gt_bboxes = [] 269 | for ann_info in anno_infos: 270 | left, top, width, height = ann_info['bbox'] 271 | right, bottom = left + width, top + height 272 | gt_bbox = QuadrangleBBox(x1=left, 273 | y1=top, 274 | x2=right, 275 | y2=top, 276 | x3=right, 277 | y3=bottom, 278 | x4=left, 279 | y4=bottom, 280 | label=ann_info['category_id'], 281 | area=ann_info['area']) 282 | gt_bboxes.append(gt_bbox) 283 | 284 | # 3.图像切片,gt_bboxes分配 285 | patches = slice_one_image(image=image, 286 | image_base_name=image_base_name, 287 | subsize=subsize, 288 | gap=gap, 289 | bboxes=gt_bboxes, 290 | rate=rate, 291 | threshold=threshold, 292 | padding=padding) 293 | 294 | for patch_name, rate, left, top, 
patch_image, patch_bboxes in patches: 295 | if len(patch_bboxes) == 0: # 无目标的切片丢弃 296 | continue 297 | 298 | patch_image_filename = f"{patch_name}{extension}" 299 | 300 | # 4 保存切片图像 301 | Image.fromarray(patch_image.astype(np.uint8)).save(os.path.join(coco_image_dst, patch_image_filename)) 302 | 303 | # 5 新增切片的image_info 304 | image_info = coco_creator.create_image_info( 305 | image_id=image_cnt, 306 | file_name=patch_image_filename, 307 | height=subsize, 308 | width=subsize, 309 | parent_image_id=parent_image_id, # 来自同一大图的切片具有相同的大图id 310 | slice_params=[rate, left, top]) # 切片左上角在大图中的坐标,大图切片前的resize参数rate 311 | coco_creator.update_image_info(image_info) 312 | 313 | # 6 新增切片的anno_info 314 | for bbox in patch_bboxes: 315 | bbox: QuadrangleBBox 316 | anno_info = coco_creator.create_anno_info(image_id=image_cnt, 317 | anno_id=anno_cnt, 318 | bbox=[bbox.left, bbox.top, bbox.width, bbox.height], 319 | category_id=bbox.label, 320 | area=bbox.area) 321 | coco_creator.update_anno_info(anno_info) 322 | anno_cnt += 1 323 | image_cnt += 1 324 | 325 | # 直接引用大图的cat_info 326 | for cat_info in coco.loadCats(coco.getCatIds()): 327 | coco_creator.update_cat_info(cat_info) 328 | 329 | # 生成新的coco_json 330 | coco_creator.write(coco_json_dst) -------------------------------------------------------------------------------- /pyinfer/utils/functional/traits.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | __all__ = ["WarpAffineTraits"] 5 | 6 | 7 | class WarpAffineTraits(): 8 | "仿射变换" 9 | 10 | def __init__(self, sx, sy, dx, dy): 11 | self.sx = sx 12 | self.sy = sy 13 | self.dx = dx 14 | self.dy = dy 15 | # 仿射变换矩阵 16 | self.m2x3_to_dst, self.m2x3_to_src = self.init() 17 | 18 | def init(self): 19 | scale_x = self.dx / self.sx 20 | scale_y = self.dy / self.sy 21 | self.scale = min(scale_x, scale_y) 22 | 23 | self.tx = round(self.scale * self.sx) 24 | self.ty = round(self.scale * self.sy) 25 | 26 | # keep ratio resize 27 | m2x3_to_dst = np.zeros((2, 3)) 28 | m2x3_to_dst[0][0] = self.scale 29 | m2x3_to_dst[0][1] = 0 30 | m2x3_to_dst[0][2] = -self.scale * self.sx * 0.5 + self.dx * 0.5 + self.scale * 0.5 - 0.5 31 | m2x3_to_dst[1][0] = 0 32 | m2x3_to_dst[1][1] = self.scale 33 | m2x3_to_dst[1][2] = -self.scale * self.sy * 0.5 + self.dy * 0.5 + self.scale * 0.5 - 0.5 34 | m2x3_to_dst = m2x3_to_dst.astype(np.float32) 35 | 36 | m2x3_to_src = cv2.invertAffineTransform(m2x3_to_dst).astype(np.float32) 37 | return m2x3_to_dst, m2x3_to_src 38 | 39 | def __call__(self, src_img: np.ndarray, interpolation=cv2.INTER_LINEAR, pad_value=[114, 114, 114]): 40 | "对输入图像进行仿射变换" 41 | top = int((self.dy - self.ty) * 0.5) 42 | left = int((self.dx - self.tx) * 0.5) 43 | bottom = self.dy - self.ty - top 44 | right = self.dx - self.tx - left 45 | dst_img = cv2.resize(src_img, (0, 0), fx=self.scale, fy=self.scale, interpolation=interpolation) 46 | dst_img = cv2.copyMakeBorder(dst_img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_value) 47 | return dst_img 48 | 49 | def to_src_coord(self, dx, dy): 50 | "变换后坐标->变换前坐标" 51 | sx, sy = np.matmul(self.m2x3_to_src, np.array([dx, dy, 1]).T) 52 | sx = min(max(round(sx), 0), self.sx - 1) 53 | sy = min(max(round(sy), 0), self.sy - 1) 54 | return sx, sy 55 | 56 | def to_dst_coord(self, sx, sy): 57 | "变换前坐标->变换后坐标" 58 | dx, dy = np.matmul(self.m2x3_to_dst, np.array([sx, sy, 1]).T) 59 | dx = min(max(round(dx), 0), self.dx - 1) 60 | dy = min(max(round(dy), 0), self.dy - 1) 61 | return dx, dy 62 | 
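63 | 
64 | # Usage sketch (illustrative only, not part of the original file): assuming a
65 | # 1280x720 BGR source image letterboxed into a 640x640 model input, the traits
66 | # object both builds the padded input and maps predictions back to the source.
67 | if __name__ == "__main__":
68 |     src_img = np.full((720, 1280, 3), 255, dtype=np.uint8)  # placeholder image
69 |     traits = WarpAffineTraits(sx=1280, sy=720, dx=640, dy=640)
70 |     dst_img = traits(src_img)  # keep-ratio resize to 640x360, padded to 640x640 with 114
71 |     sx, sy = traits.to_src_coord(320, 240)  # model-input coords -> source-image coords
72 |     print(dst_img.shape, (sx, sy))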
-------------------------------------------------------------------------------- /requirments.txt: --------------------------------------------------------------------------------
1 | colorlog
2 | numpy
3 | fastapi[all]
4 | opencv-python
5 | shapely
6 | uvicorn
7 | pycocotools
-------------------------------------------------------------------------------- /static/README.md: --------------------------------------------------------------------------------
1 | # About FastAPI
2 | 
3 | [Docs pages failing to load in offline environments](https://fastapi.tiangolo.com/advanced/extending-openapi/#self-hosting-javascript-and-css-for-docs)
4 | 
5 | ```
6 | .
7 | ├── app
8 | │   ├── __init__.py
9 | │   ├── main.py
10 | └── static
11 |     ├── redoc.standalone.js
12 |     ├── swagger-ui-bundle.js
13 |     └── swagger-ui.css
14 | 
15 | ```
16 | 
17 | Based on the directory structure above, the API script is adapted as follows:
18 | 
19 | ```python
20 | from fastapi import FastAPI
21 | from fastapi.openapi.docs import (
22 |     get_redoc_html,
23 |     get_swagger_ui_html,
24 |     get_swagger_ui_oauth2_redirect_html,
25 | )
26 | from fastapi.staticfiles import StaticFiles
27 | 
28 | app = FastAPI(docs_url=None, redoc_url=None)
29 | 
30 | app.mount("/static", StaticFiles(directory="static"), name="static")
31 | 
32 | 
33 | @app.get("/docs", include_in_schema=False)
34 | async def custom_swagger_ui_html():
35 |     return get_swagger_ui_html(
36 |         openapi_url=app.openapi_url,
37 |         title=app.title + " - Swagger UI",
38 |         oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
39 |         swagger_js_url="/static/swagger-ui-bundle.js",
40 |         swagger_css_url="/static/swagger-ui.css",
41 |     )
42 | 
43 | 
44 | @app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False)
45 | async def swagger_ui_redirect():
46 |     return get_swagger_ui_oauth2_redirect_html()
47 | 
48 | 
49 | @app.get("/redoc", include_in_schema=False)
50 | async def redoc_html():
51 |     return get_redoc_html(
52 |         openapi_url=app.openapi_url,
53 |         title=app.title + " - ReDoc",
54 |         redoc_js_url="/static/redoc.standalone.js",
55 |     )
56 | 
57 | 
58 | @app.get("/users/{username}")
59 | async def read_user(username: str):
60 |     return {"message": f"Hello {username}"}
61 | 
62 | ```
-------------------------------------------------------------------------------- /tools/generate_gif.py: --------------------------------------------------------------------------------
1 | import imageio
2 | 
3 | 
4 | def create_gif(image_list, gif_name, duration=1.0):
5 |     '''
6 |     :param image_list: list of image paths used to build the animated gif
7 |     :param gif_name: output gif file name (string, with .gif suffix)
8 |     :param duration: interval between frames in seconds
9 |     :return:
10 |     '''
11 |     frames = []
12 |     for image_name in image_list:
13 |         frames.append(imageio.imread(image_name))
14 | 
15 |     imageio.mimsave(gif_name, frames, 'GIF', duration=duration)
16 |     return
17 | 
18 | 
19 | def main():
20 |     import os
21 |     image_dir = "C:/Users/wzy/Desktop/NOTEBOOK/docs/algorithm/images/新建文件夹"
22 |     image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)]
23 |     gif_name = os.path.join(image_dir, "1.gif")
24 |     duration = 1.5
25 |     create_gif(image_list, gif_name, duration)
26 | 
27 | 
28 | if __name__ == '__main__':
29 |     main()
-------------------------------------------------------------------------------- /tools/mmdet_export_onnx/balloon/export_yolox_onnx.py: --------------------------------------------------------------------------------
1 | '''
2 | 导出mmdetection onnx思路:(或者任意框架)
3 | 
4 | 1. 跑通官方推理代码,能够调试官方推理代码,不能调试则无法分析;
5 | 2. 找到模型核心部分,即forward涉及的关键组件,例如backbone, neck, head等,其他部分丢弃;
6 | 3. 重新组合Model(nn.Module)封装核心forward组件;
7 | 4. 
导出onnx观察,若有多个输出,在forward中补充代码,调整模型输出; 8 | 9 | mmdetection官方onnx导出不够灵活,调整onnx很容易报错,错误解决需要充分阅读理解框架算法源码实现; 10 | ''' 11 | import sys 12 | sys.path.append("/volume/huaru/third_party/mmdetection") # 配置环境变量 13 | 14 | import torch 15 | from mmdet.apis import init_detector, inference_detector 16 | 17 | config_file = '/volume/wzy/project/PyInfer/tools/mmdet_export_onnx/balloon/yolox_s_8x8_300e_coco.py' 18 | # 从 model zoo 下载 checkpoint 并放在 `checkpoints/` 文件下 19 | checkpoint_file = '/volume/wzy/project/PyInfer/tools/mmdet_export_onnx/balloon/model.pth' 20 | device = 'cuda:0' 21 | 22 | #初始化检测器 23 | # model = init_detector(config_file, checkpoint_file, device=device) 24 | # # 推理演示图像 25 | # print(inference_detector(model, 'demo/demo.jpg')) 26 | 27 | class Model(torch.nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.model = init_detector(config_file, checkpoint_file, device=device) 31 | 32 | def forward(self, x): 33 | ib, ic, ih, iw = map(int, x.shape) 34 | x = self.model.backbone(x) 35 | x = self.model.neck(x) 36 | clas, bbox, objness = self.model.bbox_head(x) 37 | 38 | # 网络输出映射到网络输入,heatmap特征层bbox映射到输入图像bbox 39 | # 1.映射bbox、2.decode(逆仿射变换+confidence过滤)、3.nms,其中2,3属于后处理,映射放在onnx中是为了简化后处理, 后处理不放在onnx中是为单独编写cuda提高效率 40 | output_x = [] 41 | for class_item, bbox_item, objness_item in zip(clas, bbox, objness): 42 | hm_b, hm_c, hm_h, hm_w = map(int, class_item.shape) 43 | stride_h, stride_w = ih / hm_h, iw / hm_w 44 | strides = torch.tensor([stride_w, stride_h], device=device).view(-1, 1, 2) 45 | 46 | prior_y, prior_x = torch.meshgrid(torch.arange(hm_h), torch.arange(hm_w)) 47 | prior_x = prior_x.reshape(hm_h * hm_w, 1).to(device) 48 | prior_y = prior_y.reshape(hm_h * hm_w, 1).to(device) 49 | prior_xy = torch.cat([prior_x, prior_y], dim=-1) 50 | class_item = class_item.permute(0, 2, 3, 1).reshape(-1, hm_h * hm_w, hm_c) 51 | bbox_item = bbox_item.permute(0, 2, 3, 1).reshape(-1, hm_h * hm_w, 4) 52 | objness_item = objness_item.reshape(-1, hm_h * hm_w, 1) 53 | pred_xy = (bbox_item[..., :2] + prior_xy) * strides 54 | pred_wh = bbox_item[..., 2:4].exp() * strides 55 | pred_class = torch.cat([objness_item, class_item], dim=-1).sigmoid() 56 | output_x.append(torch.cat([pred_xy, pred_wh, pred_class], dim=-1)) 57 | 58 | return torch.cat(output_x, dim=1) 59 | 60 | m = Model().eval() 61 | 62 | image = torch.zeros(1, 3, 640, 640, device=device) 63 | torch.onnx.export( 64 | m, (image,), "yolox.onnx", 65 | opset_version=11, 66 | input_names=["images"], 67 | output_names=["output"], 68 | dynamic_axes={ 69 | "images": {0: "batch"}, 70 | "output": {0: "batch"} 71 | } 72 | ) 73 | print("Done.!") -------------------------------------------------------------------------------- /tools/mmdet_export_onnx/balloon/yolox_s_8x8_300e_coco.py: -------------------------------------------------------------------------------- 1 | application_root = '' 2 | data_root = '' 3 | img_scale = (640, 640) 4 | dataset_type = 'CocoDataset' 5 | classes = ('balloon', ) 6 | num_classes = 1 7 | train_pipeline = [ 8 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 9 | dict(type='YOLOXHSVRandomAug'), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 12 | dict( 13 | type='Pad', 14 | pad_to_square=True, 15 | pad_val=dict(img=(114.0, 114.0, 114.0))), 16 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 19 | ] 20 | test_pipeline = [ 21 | 
dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiScaleFlipAug', 24 | img_scale=(640, 640), 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict( 30 | type='Pad', 31 | pad_to_square=True, 32 | pad_val=dict(img=(114.0, 114.0, 114.0))), 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img']) 35 | ]) 36 | ] 37 | train_dataset = dict( 38 | type='MultiImageMixDataset', 39 | dataset=dict( 40 | type='CocoDataset', 41 | ann_file='', 42 | img_prefix='', 43 | classes=('balloon', ), 44 | pipeline=[ 45 | dict(type='LoadImageFromFile'), 46 | dict(type='LoadAnnotations', with_bbox=True) 47 | ], 48 | filter_empty_gt=False), 49 | pipeline=[ 50 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 51 | dict(type='YOLOXHSVRandomAug'), 52 | dict(type='RandomFlip', flip_ratio=0.5), 53 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 54 | dict( 55 | type='Pad', 56 | pad_to_square=True, 57 | pad_val=dict(img=(114.0, 114.0, 114.0))), 58 | dict( 59 | type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 60 | dict(type='DefaultFormatBundle'), 61 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 62 | ]) 63 | val_dataset = dict( 64 | type='CocoDataset', 65 | ann_file='', 66 | img_prefix='', 67 | classes=('balloon', ), 68 | pipeline=[ 69 | dict(type='LoadImageFromFile'), 70 | dict( 71 | type='MultiScaleFlipAug', 72 | img_scale=(640, 640), 73 | flip=False, 74 | transforms=[ 75 | dict(type='Resize', keep_ratio=True), 76 | dict(type='RandomFlip'), 77 | dict( 78 | type='Pad', 79 | pad_to_square=True, 80 | pad_val=dict(img=(114.0, 114.0, 114.0))), 81 | dict(type='DefaultFormatBundle'), 82 | dict(type='Collect', keys=['img']) 83 | ]) 84 | ]) 85 | data = dict( 86 | samples_per_gpu=4, 87 | workers_per_gpu=2, 88 | persistent_workers=True, 89 | train=dict( 90 | type='MultiImageMixDataset', 91 | dataset=dict( 92 | type='CocoDataset', 93 | ann_file='', 94 | img_prefix='', 95 | classes=('balloon', ), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile'), 98 | dict(type='LoadAnnotations', with_bbox=True) 99 | ], 100 | filter_empty_gt=False), 101 | pipeline=[ 102 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 103 | dict(type='YOLOXHSVRandomAug'), 104 | dict(type='RandomFlip', flip_ratio=0.5), 105 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 106 | dict( 107 | type='Pad', 108 | pad_to_square=True, 109 | pad_val=dict(img=(114.0, 114.0, 114.0))), 110 | dict( 111 | type='FilterAnnotations', 112 | min_gt_bbox_wh=(1, 1), 113 | keep_empty=False), 114 | dict(type='DefaultFormatBundle'), 115 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 116 | ]), 117 | val=dict( 118 | type='CocoDataset', 119 | ann_file='', 120 | img_prefix='', 121 | classes=('balloon', ), 122 | pipeline=[ 123 | dict(type='LoadImageFromFile'), 124 | dict( 125 | type='MultiScaleFlipAug', 126 | img_scale=(640, 640), 127 | flip=False, 128 | transforms=[ 129 | dict(type='Resize', keep_ratio=True), 130 | dict(type='RandomFlip'), 131 | dict( 132 | type='Pad', 133 | pad_to_square=True, 134 | pad_val=dict(img=(114.0, 114.0, 114.0))), 135 | dict(type='DefaultFormatBundle'), 136 | dict(type='Collect', keys=['img']) 137 | ]) 138 | ]), 139 | test=dict( 140 | type='CocoDataset', 141 | ann_file='', 142 | img_prefix='', 143 | classes=('balloon', ), 144 | pipeline=[ 145 | dict(type='LoadImageFromFile'), 146 | dict( 147 | type='MultiScaleFlipAug', 148 | img_scale=(640, 640), 149 | flip=False, 150 | 
                transforms=[
151 |                 dict(type='Resize', keep_ratio=True),
152 |                 dict(type='RandomFlip'),
153 |                 dict(
154 |                     type='Pad',
155 |                     pad_to_square=True,
156 |                     pad_val=dict(img=(114.0, 114.0, 114.0))),
157 |                 dict(type='DefaultFormatBundle'),
158 |                 dict(type='Collect', keys=['img'])
159 |             ])
160 |         ]))
161 | model = dict(
162 |     type='YOLOX',
163 |     input_size=(640, 640),
164 |     random_size_range=(15, 25),
165 |     random_size_interval=10,
166 |     backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
167 |     neck=dict(
168 |         type='YOLOXPAFPN',
169 |         in_channels=[128, 256, 512],
170 |         out_channels=128,
171 |         num_csp_blocks=1),
172 |     bbox_head=dict(
173 |         type='YOLOXHead', num_classes=1, in_channels=128, feat_channels=128),
174 |     train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
175 |     test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
176 | max_epochs = 50
177 | num_last_epochs = 14
178 | resume_from = None
179 | interval = 5
180 | optimizer = dict(
181 |     type='SGD',
182 |     lr=0.0001,
183 |     momentum=0.9,
184 |     weight_decay=0.0005,
185 |     nesterov=True,
186 |     paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0))
187 | optimizer_config = dict(grad_clip=None)
188 | lr_config = dict(
189 |     policy='YOLOX',
190 |     warmup='exp',
191 |     by_epoch=False,
192 |     warmup_by_epoch=True,
193 |     warmup_ratio=1,
194 |     warmup_iters=15,
195 |     num_last_epochs=14,
196 |     min_lr_ratio=0.05)
197 | custom_hooks = [
198 |     dict(type='YOLOXModeSwitchHook', num_last_epochs=14, priority=48),
199 |     dict(type='SyncNormHook', num_last_epochs=14, interval=5, priority=48),
200 |     dict(
201 |         type='ExpMomentumEMAHook',
202 |         resume_from=None,
203 |         momentum=0.0001,
204 |         priority=49)
205 | ]
206 | checkpoint_config = dict(interval=5)
207 | evaluation = dict(
208 |     save_best='auto', interval=5, dynamic_intervals=[(36, 5)], metric='bbox')
209 | log_config = dict(interval=8, hooks=[dict(type='TextLoggerHook')])
210 | dist_params = dict(backend='nccl')
211 | log_level = 'INFO'
212 | opencv_num_threads = 0
213 | auto_scale_lr = dict(base_batch_size=8)
214 | mp_start_method = 'fork'
215 | workflow = [('train', 1)]
216 | runner = dict(type='EpochBasedRunner', max_epochs=50)
217 | load_from = ''
218 | work_dir = ''
219 | auto_resume = False
220 | gpu_ids = [0]
221 | 
-------------------------------------------------------------------------------- /tools/slice_coco.py: --------------------------------------------------------------------------------
1 | from pyinfer.utils.functional.slice import slice_coco_dataset
2 | 
3 | if __name__ == "__main__":
4 |     coco_image_src = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/train"
5 |     coco_image_dst = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/slice_train"
6 | 
7 |     coco_json_src = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/train.json"
8 |     coco_json_dst = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/slice_train.json"
9 | 
10 |     slice_coco_dataset(
11 |         coco_image_src,
12 |         coco_image_dst,
13 |         coco_json_src,
14 |         coco_json_dst,
15 |         rate=1,
16 |         subsize=640,
17 |         gap=200,
18 |         padding=True,
19 |         threshold=0.2)
-------------------------------------------------------------------------------- /tools/visual_coco.py: --------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import matplotlib.pyplot as plt
4 | from tqdm import tqdm
5 | from pycocotools.coco import COCO
6 | import numpy as np
7 | 
8 | 
9 | def visual_coco(coco_image, coco_json):
10 |     coco = COCO(coco_json)
11 |     for image_id in tqdm(coco.getImgIds(), desc="visual coco"):
12 |         image_info = coco.loadImgs(ids=image_id)[0]
13 |         image = np.array(Image.open(os.path.join(
14 |             coco_image, image_info['file_name'])))
15 |         plt.imshow(image)
16 |         anno_infos = coco.loadAnns(coco.getAnnIds(imgIds=image_id))
17 |         coco.showAnns(anno_infos, draw_bbox=True)
18 |         plt.show()
19 | 
20 | 
21 | if __name__ == "__main__":
22 |     coco_image = ""
23 |     coco_json = ""
24 |     visual_coco(coco_image, coco_json)
-------------------------------------------------------------------------------- /workspace/balloon.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/workspace/balloon.jpg
-------------------------------------------------------------------------------- /workspace/infer_balloon.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/workspace/infer_balloon.jpg
--------------------------------------------------------------------------------