├── .gitignore
├── README.md
├── app.py
├── applications
│   └── balloon
│       └── config.py
├── doc
│   ├── demo_api.png
│   ├── head.png
│   ├── infer_flow.gif
│   ├── 基本流程.svg
│   ├── 数据流动.svg
│   └── 独占资源分配器.svg
├── pyinfer
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── build.py
│   │   ├── engine
│   │   │   ├── __init__.py
│   │   │   └── engine.py
│   │   ├── hook
│   │   │   ├── __init__.py
│   │   │   └── hooks.py
│   │   ├── infer
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── detection.py
│   │   ├── job.py
│   │   └── mono_allocator.py
│   └── utils
│       ├── __init__.py
│       ├── common
│       │   ├── __init__.py
│       │   ├── config.py
│       │   ├── logger.py
│       │   └── registry.py
│       ├── detection
│       │   ├── __init__.py
│       │   ├── bbox.py
│       │   └── nms.py
│       └── functional
│           ├── __init__.py
│           ├── coco.py
│           ├── slice.py
│           └── traits.py
├── requirments.txt
├── static
│   ├── README.md
│   ├── redoc.standalone.js
│   ├── swagger-ui-bundle.js
│   └── swagger-ui.css
├── tools
│   ├── generate_gif.py
│   ├── mmdet_export_onnx
│   │   └── balloon
│   │       ├── export_yolox_onnx.py
│   │       └── yolox_s_8x8_300e_coco.py
│   ├── slice_coco.py
│   └── visual_coco.py
└── workspace
    ├── balloon.jpg
    └── infer_balloon.jpg

/.gitignore:
--------------------------------------------------------------------------------
models
__pycache__/
.vscode
*.pth
*.onnx
doc/ori

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Overview

A high-performance machine-learning inference solution (Python edition).

**Features**

1. Non-blocking synchronous dynamic-batch inference built on the event loop: concurrent request handling and parallel inference;
2. Monopoly resource allocator: preprocessing runs in parallel with inference, and memory cannot overflow;
3. Decomposed inference: one task can be split into several sub-tasks for inference;

**Basic workflow**

1. Asynchronous service: routes registered with FastAPI (`route func`) are async coroutines, so requests are served concurrently;
2. Asynchronous inference: `route func` parses the request into an input, wraps the input into a job, and immediately gets back an unfinished computation `future`; `route func` suspends on `await` until the `future` resolves;
3. Task execution: a consumer thread fetches jobs from the task `queue`, performs the actual inference, and fills the result into the `future`;
4. Result return: once filled, the awaited `future` resumes `route func`, which collects the inference result and returns it through FastAPI, completing one inference;

**Data flow**

1. Receive data: the FastAPI route function `route func` receives the raw byte stream `raw_data`;
2. Parse data: the producer thread calls `infer.parse_raw` to parse `raw_data` into the infer's input format `infer.Input`;
3. Store the job input: the parsed input is stored in `job.input`;
4. Preprocess: `preprocess` runs on `job.input`;
5. Store the model input: a monopoly data resource `mono_data` is acquired, and the preprocessed model input is stored in `mono_data.input`;
6. Batch the model inputs: the `mono_data.input` of several jobs are combined into `batch_input`;
7. Store the model output: each job's output is taken from `batch_output` by index and recorded in `mono_data.output`;
8. Postprocess: `postprocess` converts the model output `mono_data.output` into the final inference result, stored in `job.output`;
9. Return the result: `job.output` is filled into the pending `future`, marking it done; FastAPI reads `future.result()` and responds;

# Installation

```
pip install fastapi[all]
```

The code additionally imports `onnxruntime`, `opencv-python`, `pillow`, `numpy`, `addict` and `yapf`; see `requirments.txt`.

# Key techniques

## Asynchronous dynamic-batch inference

1. Asynchrony

Python's async machinery is not true asynchrony; it is non-blocking synchronous operation built on an event loop.

- Blocking vs. non-blocking: whether the caller waits for the computation to finish;
- Synchronous vs. asynchronous: synchronous means the caller polls for completion; asynchronous means a notification channel is agreed upon, and the caller is notified when the task finishes.

The producer submits an inference job and gets back an unfinished computation `future`; it does not have to wait for the job to finish and can keep accepting HTTP requests and producing new jobs, so it is non-blocking. The producer polls whether the job has completed and returns the result over HTTP once it has, so it is synchronous. The relevant `python asyncio` source:

```python
from asyncio.futures import Future


async def func():
    future = Future()  # non-blocking
    await future  # synchronous



class Future():
    def __init__(self, *, loop=None):
        if loop is None:
            # event loop; the event is future.set_result, which runs the callbacks bound to it
            self._loop = events.get_event_loop()

    def set_result(self, result):
        # set_result updates the task state and then triggers the callbacks bound to the
        # set_result event; the loop plays no further part in the non-blocking synchronization
        self._state = _FINISHED
        self.__schedule_callbacks()

    def __await__(self):
        if not self.done():  # has the computation finished?
            self._asyncio_future_blocking = True
            yield self  # yield the CPU time slice
        if not self.done():
            raise RuntimeError("await wasn't used with future")
        return self.result()  # return the inference result

    __iter__ = __await__  # iterated in a loop, i.e. polling

```

Python offers two kinds of unfinished computation `future`: `concurrent.futures`, which implements non-blocking asynchrony via thread communication, and `asyncio.futures`, which implements non-blocking synchrony via the event loop plus coroutines. A minimal sketch of this handshake follows the list.

2. Dynamic-batch inference

Parallelism through batching: using the producer-consumer pattern, a task queue, and thread signaling, several jobs can be merged into one batch at inference time, raising inference parallelism.
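The same mechanism in miniature: a sketch (not part of the repo) in which a worker thread completes an `asyncio` future while the coroutine that committed the work awaits it. Because `asyncio` futures are not thread-safe, the worker hands `set_result` back to the loop:

```python
import asyncio
import threading
import time


def worker(loop: asyncio.AbstractEventLoop, future: asyncio.Future) -> None:
    time.sleep(0.1)  # stand-in for inference work
    # asyncio futures are not thread-safe: schedule set_result on the loop's thread
    loop.call_soon_threadsafe(future.set_result, "inference result")


async def route_func() -> str:
    loop = asyncio.get_running_loop()
    future = loop.create_future()  # non-blocking: returns immediately
    threading.Thread(target=worker, args=(loop, future)).start()
    return await future  # suspends this coroutine only; the loop keeps serving


print(asyncio.run(route_func()))  # -> "inference result"
```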
## Monopoly resource allocator

The monopoly resource allocator provides two things: preprocessing running in parallel with inference, and protection against memory overflow.

From the job structure and the data flow it follows that a job can only be placed on the task queue after preprocessing, and preprocessing must first acquire a monopoly data resource (a Python object) to hold the preprocessed output (the model input) and the postprocessing input (the model output). The monopoly resource allocator therefore gives a neat control point over the whole inference pipeline:

1. Overflow protection: the allocator holds a fixed number of monopoly resources; when all of them are occupied, the next `commit` has to wait until an earlier job finishes inference and releases its resource. Memory usage is therefore constant, and a flood of commits cannot crash the server with an out-of-memory error;
2. Decoupling preprocessing from inference: the number of monopoly data resources is usually twice the maximum batch size, so one batch's worth of jobs can be preprocessed while another batch is being inferred (one batch is always prepared ahead, i.e. a prefetch), letting preprocessing and inference run simultaneously and improving inference efficiency;

## Decomposed inference

Decomposed inference splits one inference task into several sub-tasks and merges the sub-task results into the complete result; it is implemented by the `jobset` task-set type.

**Scenario 1: object detection on remote-sensing images**
A large remote-sensing image A is split into tiles `A1, A2, ..., An` with corresponding sub-tasks `job1, job2, ..., jobn`; the detections from each tile are fused into the detections for the full image.

# Testing

1. Download the object-detection model files for the balloon dataset:

[Download link](https://pan.baidu.com/s/13HCb_N-Gc1oLp2DG1l2yJw), extraction code: lhvs


```
balloon/
    config.py
    yolox.onnx
```

2. Start the service

From the repository root:

```bash
python app.py
```

3. Test the API

Upload `balloon.jpg` from the `workspace` directory;
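A scripted version of step 3: a sketch assuming the server is running locally on port 8805 (the default in app.py); the route path `/infer/DetectionInfer` follows from how `startup_event` registers routes:

```python
import requests  # pip install requests

with open("workspace/balloon.jpg", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:8805/infer/DetectionInfer",
        files={"file": ("balloon.jpg", f, "image/jpeg")},
    )

resp.raise_for_status()
print(resp.json())  # DetectionInfer.Output: {"bbox_num": ..., "bboxes": [...]}
```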
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import uvicorn
import argparse
from enum import Enum
from fastapi import FastAPI, UploadFile, File

from pyinfer.core.build import build_infer, build_logger
from pyinfer.core.infer import Infer
from pyinfer.utils.common.config import Config


app = FastAPI()

infers = {}  # initialized infers, keyed by name


class OnlineInferName(str, Enum):
    """Inference services the app brings online"""
    DetectionInfer = "Detection object-detection inference"


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="applications/balloon/config.py")
    return parser.parse_args()


def infer_func(infer_name):

    async def wrapper(file: UploadFile = File(...)):
        """
        \f
        :param infer_name: name of the infer;
        :param file: file stream received by the API;

        1. look up the infer;
        2. parse the file stream into the infer's input;
        3. commit the input;
        4. return the inference result;

        """
        infer: Infer = infers.get(infer_name)
        input = infer.parse_raw(filename=file.filename, raw_data=await file.read())
        future = infer.commit(input)
        return future if future is None else await future  # None means the commit failed

    return wrapper


@app.on_event("startup")
async def startup_event():
    """
    Initialize every infer that is to go online

    1. look up the config section by infer name;
    2. create and initialize the infer;
    3. if the infer started successfully, register its inference route

    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    logger = build_logger(cfg.log)
    for item in OnlineInferName:
        if cfg.infer.get(item.name) is None:  # 1
            logger.error(f"{item.name} config lack, build infer failed.")
            continue

        infer = build_infer(cfg.infer.get(item.name), logger=logger)  # 2
        if infer is None:  # build_infer returns None when startup fails
            logger.error(f"{item.name} startup failed, skip route registration.")
            continue
        app.post(f'/infer/{item.name}', response_model=infer.Output,
                 tags=[item.value])(infer_func(item.name))  # 3
        infers[item.name] = infer


@app.on_event("shutdown")
def shutdown_event():
    """Destroy the infers"""
    for infer in infers.values():
        infer.destory()


@app.get('/health', tags=["Service status"], summary="")
async def health():
    """Network health check"""
    return {'START': 'UP'}


if __name__ == "__main__":
    uvicorn.run(app, port=8805)

--------------------------------------------------------------------------------
/applications/balloon/config.py:
--------------------------------------------------------------------------------
# Three environment variables are supported: {{ AppFolder }}, {{ WorkspaceFolder }}, {{ RootFolder }}

log = dict(filename=None, level="INFO")


infer = dict(
    DetectionInfer=dict(
        type="DetectionInfer",
        engine=dict(type="OnnxInferEngine",
                    onnx_file="{{ AppFolder }}/balloon/yolox.onnx"),
        hooks=[dict(type="DrawBBoxHook",
                    out_dir="{{ WorkspaceFolder }}")],
        confidence_threshold=0.7,
        nms_threshold=0.2,
        max_batch_size=16,
        max_object=100,
        width=640,
        height=640,
        workspace="{{ WorkspaceFolder }}",
        slice=False,
        subsize=640,
        rate=1,
        gap=200,
        padding=True,
        labelnames=("balloon", ),
        device="cuda"))
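How the `{{ ... }}` placeholders get their values: a sketch (not from the repo) relying on `pyinfer/__init__.py` exporting the three folders as environment variables and on `ConfigDict` substituting them when the file is loaded:

```python
import pyinfer  # importing the package exports RootFolder / WorkspaceFolder / AppFolder
from pyinfer.utils.common.config import Config

cfg = Config.fromfile("applications/balloon/config.py")
# "{{ AppFolder }}/balloon/yolox.onnx" -> "<repo root>/applications/balloon/yolox.onnx"
print(cfg.infer["DetectionInfer"]["engine"]["onnx_file"])
```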
--------------------------------------------------------------------------------
/doc/demo_api.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/doc/demo_api.png

--------------------------------------------------------------------------------
/doc/head.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/doc/head.png

--------------------------------------------------------------------------------
/doc/infer_flow.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/doc/infer_flow.gif

--------------------------------------------------------------------------------
/doc/基本流程.svg:
--------------------------------------------------------------------------------
[draw.io diagram; recoverable labels: route func creates a job, returns non-blocking, and awaits the future; the producer commits jobs to the job queue; the consumer (waits_for_jobs) fetches jobs in batches, runs InferEngine forward/work, and set_result triggers the event that returns future result over http]

--------------------------------------------------------------------------------
/doc/数据流动.svg:
--------------------------------------------------------------------------------
[draw.io diagram; recoverable labels: ① commit stores the input in job.input; ② awaiting preprocess; ③ preprocess stores its result in mono_data.input; ④ awaiting inference; ⑤ work stores the inference result in mono_data.output; ⑥ awaiting postprocess; ⑦ postprocess stores its result in job.output, which set_result pushes into the future; traits hold features extracted from the input]

--------------------------------------------------------------------------------
/doc/独占资源分配器.svg:
--------------------------------------------------------------------------------
[draw.io diagram; recoverable labels: the Monopoly Allocator holds a fixed capacity of MonopolyData slots, each available or unavailable; commit/preprocess query() to be assigned a monopoly data resource before model forward, and requests wait while all slots are unavailable]

--------------------------------------------------------------------------------
/pyinfer/__init__.py:
--------------------------------------------------------------------------------
import os
from os.path import dirname as dname

# SET ENV
os.environ["RootFolder"] = dname(dname(__file__))
os.environ["WorkspaceFolder"] = f"{os.environ['RootFolder']}/workspace"
os.environ["AppFolder"] = f"{os.environ['RootFolder']}/applications"

--------------------------------------------------------------------------------
/pyinfer/core/__init__.py:
--------------------------------------------------------------------------------
from .engine import *
from .infer import *
from .hook import *

--------------------------------------------------------------------------------
/pyinfer/core/build.py:
--------------------------------------------------------------------------------
from typing import Dict
from ..utils.common.registry import INFERS, ENGINES, HOOKS
from ..utils.common.logger import Logger

__all__ = ["build_infer", "build_engine", "build_hook"]


def build_infer(config: Dict, **kwargs):
    if INFERS.get(config.get('type')) is None:
        raise KeyError(f"Cannot find infer type {config.get('type')}.")

    infer = INFERS.get(config.pop('type'))(**kwargs)
    start_params = infer.StartParams.parse_obj(config)
    if not infer.startup(start_params):
        return
    else:
        return infer


def build_engine(config: Dict, **kwargs):
    if ENGINES.get(config.get('type')) is None:
        raise KeyError(f"Cannot find infer engine {config.get('type')}.")

    engine = ENGINES.get(config.pop('type'))(**kwargs)
    if not engine.build(**config):
        return
    else:
        return engine


def build_hook(config):
    return HOOKS.build(config)


def build_logger(config) -> Logger:
    return Logger(**dict(config))
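What the builders expect: a sketch (not from the repo) where the `type` key selects the registered class, the remaining keys become build parameters, and a failed provider check comes back as `None` rather than an exception. The onnx path assumes the downloaded balloon model:

```python
from pyinfer.core.build import build_engine
from pyinfer.utils.common.logger import Logger

engine_cfg = dict(type="OnnxInferEngine",
                  onnx_file="applications/balloon/yolox.onnx")
engine = build_engine(engine_cfg, device="cpu", logger=Logger())
print(engine)  # the engine, or None when the provider check fails
```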
--------------------------------------------------------------------------------
/pyinfer/core/engine/__init__.py:
--------------------------------------------------------------------------------
from .engine import OnnxInferEngine

--------------------------------------------------------------------------------
/pyinfer/core/engine/engine.py:
--------------------------------------------------------------------------------
from abc import abstractmethod, ABCMeta
import numpy as np
import onnxruntime
from ...utils.common.registry import ENGINES
from ...utils.common.logger import Logger

__all__ = ["OnnxInferEngine"]


class InferEngine(metaclass=ABCMeta):

    def __init__(self, device, logger=None, **kwargs) -> None:
        self.device = device
        self.logger = Logger() if logger is None else logger

    @abstractmethod
    def build(self):
        pass

    @abstractmethod
    def forward(self, job):
        pass


@ENGINES.register_module()
class OnnxInferEngine(InferEngine):
    __providers__ = {"cuda": "CUDAExecutionProvider",
                     "TensorRT": "TensorrtExecutionProvider",
                     "cpu": "CPUExecutionProvider"}

    def build(self, onnx_file):
        providers = onnxruntime.get_available_providers()

        if self.__providers__[self.device] not in providers:
            self.logger.fatal(f"Onnxruntime lacks {self.__providers__[self.device]}, build model failed.")
            return False

        self.logger.info(f"onnxruntime use [{self.__providers__[self.device]}], support [{','.join(providers)}]")
        # restrict the session to the provider selected by `device`
        self.session = onnxruntime.InferenceSession(onnx_file, providers=[self.__providers__[self.device]])
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

        return True

    def forward(self, batch_input: np.ndarray):
        """
        Assumes a single-input, single-output onnx model

        batch_input[batch_size, ...]
        """
        # layout conversion: NHWC -> NCHW
        batch_input = batch_input.transpose(0, 3, 1, 2)
        batch_output = self.session.run(self.output_names, {self.input_names[0]: batch_input})[0]
        return batch_output
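Standalone use of the engine: a sketch (not from the repo); the file path assumes the balloon model from the README, and the NHWC input shape matches the transpose in `forward`:

```python
import numpy as np
from pyinfer.core.engine import OnnxInferEngine

engine = OnnxInferEngine(device="cpu")
if engine.build(onnx_file="applications/balloon/yolox.onnx"):
    batch = np.random.rand(2, 640, 640, 3).astype(np.float32)  # NHWC; forward transposes to NCHW
    print(engine.forward(batch).shape)  # e.g. (2, num_candidates, 5 + num_classes) for YOLOX
```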
--------------------------------------------------------------------------------
/pyinfer/core/hook/__init__.py:
--------------------------------------------------------------------------------
from .hooks import DrawBBoxHook

--------------------------------------------------------------------------------
/pyinfer/core/hook/hooks.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from typing import List
from PIL import Image, ImageDraw
from ..job import Job
from ...utils.common.registry import HOOKS
from ...utils.detection.bbox import QuadrangleBBox

__all__ = ["DrawBBoxHook"]


class Hook():

    def __init__(self):
        pass

    def after_set_result(self, job):
        pass


@HOOKS.register_module()
class DrawBBoxHook(Hook):

    def __init__(self, out_dir, prefix="infer_") -> None:
        self.out_dir = out_dir
        self.prefix = prefix

    def after_set_result(self, job: Job):
        """Draw the detected boxes"""
        image_np, filename = job.input.image.astype(np.uint8), job.input.filename
        bboxes: List[QuadrangleBBox] = job.future.result().bboxes

        # filled boxes drawn on a copy, blended at 20% for a translucent highlight
        image_background = Image.fromarray(image_np)
        draw_background = ImageDraw.Draw(image_background)
        for bbox in bboxes:
            draw_background.rectangle((bbox.left, bbox.top, bbox.right, bbox.bottom), fill='#FF00FF')

        image = Image.fromarray(cv2.addWeighted(image_np, 0.8, np.array(image_background), 0.2, 0))
        draw = ImageDraw.Draw(image)
        for bbox in bboxes:
            draw.rectangle((bbox.left, bbox.top, bbox.right, bbox.bottom), outline='#FF00FF')
            draw.text((bbox.left, bbox.top), bbox.labelname)
        cv2.imwrite(f"{self.out_dir}/{self.prefix}{filename}", np.array(image))
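Hooks are deliberately tiny: a sketch (not from the repo; LogResultHook is hypothetical) of a custom hook that could be listed in a config's `hooks` section as `dict(type="LogResultHook")`:

```python
from pyinfer.core.hook.hooks import Hook
from pyinfer.utils.common.registry import HOOKS


@HOOKS.register_module()
class LogResultHook(Hook):  # hypothetical example hook
    def __init__(self, prefix="[result]"):
        self.prefix = prefix

    def after_set_result(self, job):
        # runs via the job future's done-callback, after the result has been set
        print(self.prefix, job.input.filename, job.future.result())
```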
--------------------------------------------------------------------------------
/pyinfer/core/infer/__init__.py:
--------------------------------------------------------------------------------
from .detection import DetectionInfer
from .base import Infer

--------------------------------------------------------------------------------
/pyinfer/core/infer/base.py:
--------------------------------------------------------------------------------
import asyncio
import threading
import numpy as np
from abc import abstractmethod, ABCMeta
from typing import List, Union, Any, Dict
from pydantic import BaseModel, Field

from ..mono_allocator import MonoAllocator
from ..job import Job, JobSet
from ...utils.common.logger import Logger
from ..build import build_engine, build_hook

__all__ = ["Infer"]


class Infer(metaclass=ABCMeta):
    """
    Infer base class (interface)

              parse_raw               commit        preprocess            append           await job future
    producer: raw_input----------> infer input--------->job-------------->valid job--------->job queue--------------------->

                                    engine forward                      postprocess               set_result
    consumer: job.mono_data.input---------------------> job.mono_data.output---------------->job.output-----------------> future result
    """
    class Input(BaseModel):
        """Input"""
        pass

    class Output(BaseModel):
        """Output"""
        pass

    class StartParams(BaseModel):
        """Infer start-up parameters"""
        engine: Dict = Field(description="infer engine config dict")
        hooks: List[Dict] = Field(description="hooks", default=[])
        max_batch_size: int
        device: str = Field(default="cuda", description="inference device")  # referenced by work()
        workspace: str = Field(default="workspace/")

    def __init__(self, logger=None):
        self._loop = asyncio.get_running_loop()  # bind the event loop: async futures notify through it
        self._cond = threading.Condition()
        self._job_queue = []
        self.logger = Logger() if logger is None else logger
        self._hooks = []

    def startup(self, start_params: StartParams):
        """Start the inference consumer thread"""
        self.start_params = start_params
        self._run = True
        self._mono_allocator = MonoAllocator(
            self.start_params.max_batch_size * 2)
        start_job = Job()
        t = threading.Thread(target=self.work, args=(start_job, ))
        t.start()
        # block until the model engine has finished initializing
        return start_job.future.result()

    @abstractmethod
    def parse_raw(self, filename, raw_data) -> Union[Input, List[Input]]:
        """
        Parse the network input into the infer's Input format, or a List[Input]
        """
        pass

    @abstractmethod
    def preprocess(self, job: Job):
        """Preprocess"""
        pass

    @abstractmethod
    def postprocess(self, job: Job):
        """Postprocess"""
        pass

    def commit(self, inp: Union[Input, List[Input]]):
        """
        Commit a task

        Wrap the input into a job and push it onto the task queue;
        """
        if isinstance(inp, list):
            future = self.__commits(inp)
        else:
            future = self.__commit(inp)
        return future

    def __commit(self, inp: Input):
        """Commit a single task"""
        job = Job(inp, self._loop)

        if not self.preprocess(job):
            return

        # preprocessing done: append the valid job to the task queue
        with self._cond:
            self._job_queue.append(job)
            self._cond.notify()

        # attach the hooks
        for hook in self._hooks:
            job.add_done_call_back(hook.after_set_result)

        return job.future

    def __commits(self, inps: List[Input]):
        """
        Commit a task set

        inps[0] is the original input; inps[1:] are the n sub-inputs split from it;
        """
        assert len(inps) >= 2, "inps num less than 2, check parse_raw method."
        real_inp, *sub_inps = inps

        # create the sub-jobs
        jobs = [Job(sub_inp, self._loop) for sub_inp in sub_inps]
        # if any sub-job cannot be preprocessed, the commit fails;
        if not all([self.preprocess(job) for job in jobs]):
            return

        # preprocessing done: append the sub-jobs to the task queue
        with self._cond:
            self._job_queue.extend(jobs)
            self._cond.notify()

        # create the job set
        jobset = JobSet(jobs, self.collect_fn, real_inp, self._loop)

        # attach the hooks
        for hook in self._hooks:
            jobset.add_done_call_back(hook.after_set_result)
        return jobset.future

    def collect_fn(self, jobset: JobSet) -> Any:
        """
        Fuse the job-set results: gathers every sub-job's inference result; the return value of
        collect_fn becomes the job set's result. collect_fn must be overridden whenever the input
        is split into several sub-inputs inferred separately;

        Args:
            jobset: the job set

        return: Any, stored as jobset.future.result;
        """
        raise NotImplementedError

    def wait_for_job(self) -> Job:
        """Fetch one job"""
        with self._cond:
            self._cond.wait_for(
                lambda: (len(self._job_queue) > 0 or not self._run))
            if not self._run:
                return
            job = self._job_queue.pop(0)
            return job

    def wait_for_jobs(self) -> List[Job]:
        """Fetch a batch of jobs"""
        with self._cond:
            self._cond.wait_for(
                lambda: (len(self._job_queue) > 0 or not self._run))
            if not self._run:
                return []
            max_size = min(self.start_params.max_batch_size,
                           len(self._job_queue))
            fetch_jobs = self._job_queue[:max_size]
            self._job_queue = self._job_queue[max_size:]
            return fetch_jobs

    def work(self, start_job: Job):
        """Inference loop"""

        # 1. create and initialize the inference engine
        engine = build_engine(self.start_params.engine, device=self.start_params.device, logger=self.logger)
        if engine is None:
            start_job.future.set_result(False)  # engine init failed: exit the consumer thread
            return

        for hook_cfg in self.start_params.hooks:
            hook = build_hook(hook_cfg)
            if hook is None:
                start_job.future.set_result(False)  # hook init failed: exit the consumer thread
                return
            self._hooks.append(hook)

        # started successfully (start_job.future is a concurrent future: direct set_result is safe)
        start_job.future.set_result(True)

        while self._run:
            # 2. fetch preprocessed jobs from the task queue
            fetch_jobs = self.wait_for_jobs()
            if len(fetch_jobs) == 0:  # wait_for_jobs returns empty when inference stops
                for job in self._job_queue:
                    # unfinished jobs resolve to None; asyncio futures are not thread-safe,
                    # so complete them on the event loop
                    self._loop.call_soon_threadsafe(job.future.set_result, job.output)
                break

            # 3. stack the batch
            batch_input = np.stack([job.mono_data.input for job in fetch_jobs])
            # 4. inference
            batch_output = engine.forward(batch_input)
            self.logger.info(
                f"Infer batch size({len(fetch_jobs)}), wait jobs({len(self._job_queue)})")
            for index, job in enumerate(fetch_jobs):
                # 5. take the model output: engine->job.mono_data.output
                job.mono_data.output = batch_output[index]
                # 6. postprocess: job.mono_data.output->postprocess->job.output
                self.postprocess(job)
                # 7. release the monopoly data resource
                job.mono_data.release()
                job.mono_data = None
                # 8. return the result on the event loop (asyncio futures are not thread-safe)
                self._loop.call_soon_threadsafe(job.future.set_result, job.output)

    def __del__(self):
        with self._cond:
            self._run = False
            self._cond.notify()
        self.logger.info(f"{self.__class__.__name__} is destroyed.")

    def destory(self):
        self.__del__()
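Subclassing in miniature: a sketch (not from the repo; EchoInfer is hypothetical) showing the three mandatory overrides plus the resource handshake that `preprocess` is expected to perform:

```python
import numpy as np
from pydantic import BaseModel
from pyinfer.core.infer import Infer
from pyinfer.utils.common.registry import INFERS


@INFERS.register_module()
class EchoInfer(Infer):  # hypothetical example infer
    class Input(BaseModel):
        values: list

    class Output(BaseModel):
        values: list

    def parse_raw(self, filename, raw_data):
        return self.Input(values=list(raw_data))

    def preprocess(self, job):
        job.mono_data = self._mono_allocator.query()  # must hold a mono resource to enter the queue
        if job.mono_data is None:
            return False
        job.mono_data.input = np.array(job.input.values, dtype=np.float32)
        return True

    def postprocess(self, job):
        job.output = self.Output(values=job.mono_data.output.tolist())
```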
--------------------------------------------------------------------------------
/pyinfer/core/infer/detection.py:
--------------------------------------------------------------------------------
import cv2
import os
import numpy as np
from typing import List, Union, Dict, Tuple
from pydantic import BaseModel, Field

from .base import Infer
from ..job import Job, JobSet
from ..build import INFERS
from ...utils.detection.nms import cpu_nms
from ...utils.detection.bbox import QuadrangleBBox
from ...utils.functional.slice import slice_one_image
from ...utils.functional.traits import WarpAffineTraits


@INFERS.register_module()
class DetectionInfer(Infer):
    class Input(BaseModel):
        filename: str = Field(description="image file name")
        image: Union[np.ndarray, None] = Field(
            default=None, description="image decoded into a numpy array")
        ox: float = Field(default=0, description="x of the image coordinate-system origin")
        oy: float = Field(default=0, description="y of the image coordinate-system origin")

        class Config:
            arbitrary_types_allowed = True

    class Output(BaseModel):
        bbox_num: int
        bboxes: List[QuadrangleBBox]

        def to_array(self):
            ret = []
            for bbox in self.bboxes:
                ret.append([bbox.left, bbox.right, bbox.top,
                            bbox.bottom, bbox.label])
            return np.array(ret).astype(int)  # np.int is removed in recent numpy

    class StartParams(BaseModel):
        engine: Dict
        confidence_threshold: float = Field(gt=0, description="confidence threshold")
        nms_threshold: float = Field(gt=0.0, description="NMS threshold")
        max_batch_size: int = Field(gt=0, description="maximum batch_size at inference time")
        max_object: int = Field(gt=0, description="maximum number of objects in an image")
        width: int = Field(gt=0, description="image width fed to the network")
        height: int = Field(gt=0, description="image height fed to the network")
        workspace: str = Field(default="workspace/")
        labelnames: Tuple = Field(description="label names")
        timeout: int = Field(default=10, description="maximum wait when committing a job; on timeout the commit fails.")
        hooks: List[Dict] = Field(description="hooks", default=[])
        slice: bool = Field(description="whether to enable sliced (tiled) inference", default=False)
        device: str = Field(description="inference device", default="cuda")
        # parameters used only in sliced inference mode
        subsize: int = Field(default=640, description="tile size")
        rate: int = Field(default=1, description="resize factor for the full image")
        gap: int = Field(default=200, description="overlap between tiles")
        padding: bool = Field(default=True, description="whether tiles are padded up to subsize")

    def parse_raw(self, filename, raw_data):
        """Parse raw data into Input"""
        if self.start_params.slice:
            return self.__slice_parse_raw(filename, raw_data)
        else:
            return self.__parse_raw(filename, raw_data)

    def __parse_raw(self, filename, raw_data):
        """Parse raw data into Input (non-sliced inference)"""
        image = cv2.imdecode(np.frombuffer(raw_data, np.uint8),
                             cv2.IMREAD_COLOR).astype(np.float32)
        inp = self.Input.parse_obj(dict(filename=filename, image=image))
        return inp

    def __slice_parse_raw(self, filename, raw_data):
        """Parse raw data into Input; sliced inference splits the full image into several Inputs"""
        image = cv2.imdecode(np.frombuffer(raw_data, np.uint8),
                             cv2.IMREAD_COLOR).astype(np.float32)
        image_base_name, extension = os.path.splitext(filename)
        patches = slice_one_image(image=image,
                                  image_base_name=image_base_name,
                                  subsize=self.start_params.subsize,
                                  rate=self.start_params.rate,
                                  gap=self.start_params.gap,
                                  padding=self.start_params.padding,
                                  bboxes=None)
        # sub inputs
        inps = [
            self.Input.parse_obj(
                dict(filename=f"{patch_name}{extension}", image=patch_image, ox=left, oy=top))
            for patch_name, rate, left, top, patch_image, patch_bboxes in patches
        ]
        # real input (the Input origin fields are ox/oy)
        inps.insert(0, self.Input.parse_obj(
            dict(filename=filename, image=image, ox=0, oy=0)))
        return inps

    def collect_fn(self, jobset: JobSet):
        """Fuse the sliced-inference results"""
        bboxes: List[QuadrangleBBox] = []
        for job in jobset.jobs:
            ox, oy = job.input.ox, job.input.oy
            # boxes detected on this tile
            patch_bboxes: List[QuadrangleBBox] = job.output.bboxes
            # shift the boxes back to the full-image origin (0,0)
            bboxes.extend([bbox.reset_origin(-ox, -oy)
                           for bbox in patch_bboxes])
        bboxes = cpu_nms(bboxes, self.start_params.nms_threshold)  # nms
        output = self.Output(bboxes=bboxes, bbox_num=len(bboxes))  # boxes for the full image
        return output

    def preprocess(self, job: Job):
        """Preprocess"""
        if job.input is None or job.input.image is None:
            self.logger.error(
                f"input image is empty. Please check {job.input.filename}.")
            return False

        # preprocessing requires acquiring a monopoly data resource
        mono_data = self._mono_allocator.query(self.start_params.timeout)
        if mono_data is None:
            self.logger.error("query mono data timeout.")
            return False
        else:
            job.mono_data = mono_data

        # WarpAffineTraits computes the affine matrix and its inverse, and applies the warp
        job.traits = WarpAffineTraits(job.input.image.shape[1], job.input.image.shape[0], self.start_params.width,
                                      self.start_params.height)
        # mono_data.input stores the preprocessed result: the model input
        job.mono_data.input = job.traits(job.input.image)
        return True

    def decode(self, job):
        bboxes = []
        # model output from the monopoly resource: a list of rows, each row being
        # [xc, yc, width, height, confidence, *scores]
        results = job.mono_data.output
        for result in results:
            xc, yc, width, height, confidence, *scores = result
            if confidence < self.start_params.confidence_threshold:
                continue

            max_score_index = np.argmax(scores)
            max_score = scores[max_score_index]
            if max_score*confidence < self.start_params.confidence_threshold:
                continue

            label = max_score_index
            # map the box back to the source image via the inverse affine transform
            left, top = job.traits.to_src_coord(xc - width/2, yc - height/2)
            right, bottom = job.traits.to_src_coord(xc + width/2, yc + height/2)
            labelname = self.start_params.labelnames[label]
            bbox = QuadrangleBBox(x1=left,
                                  y1=top,
                                  x2=right,
                                  y2=top,
                                  x3=right,
                                  y3=bottom,
                                  x4=left,
                                  y4=bottom,
                                  confidence=confidence,
                                  label=label,
                                  labelname=labelname,
                                  keepflag=True)
            bboxes.append(bbox)
        return bboxes

    def nms(self, bboxes: List[QuadrangleBBox]):
        return cpu_nms(bboxes, self.start_params.nms_threshold)

    def postprocess(self, job):
        """Postprocess"""
        bboxes = self.nms(self.decode(job))
        job.output = self.Output(bbox_num=len(bboxes), bboxes=bboxes)
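A worked check of decode's double threshold: a sketch (plain numpy, no repo imports) on one hypothetical output row, mirroring the [xc, yc, w, h, confidence, *scores] layout:

```python
import numpy as np

confidence_threshold = 0.7
row = np.array([320.0, 320.0, 100.0, 80.0, 0.9, 0.1, 0.85])  # xc, yc, w, h, conf, scores
xc, yc, w, h, conf, *scores = row
if conf >= confidence_threshold:
    label = int(np.argmax(scores))                    # class 1
    if scores[label] * conf >= confidence_threshold:  # 0.85 * 0.9 = 0.765 -> kept
        left, top = xc - w / 2, yc - h / 2            # corners before the inverse warp
        right, bottom = xc + w / 2, yc + h / 2
        print(label, conf, (left, top, right, bottom))
```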
--------------------------------------------------------------------------------
/pyinfer/core/job.py:
--------------------------------------------------------------------------------
from concurrent import futures
from typing import List


class Job():
    def __init__(self, inp=None, loop=None) -> None:
        """
        An inference job

        :param input: inference input
        :param output: inference output
        :param mono_data: monopoly data resource, effectively the job's workspace; a job must hold
                          one before it can be pushed onto the inference queue. mono_data.input holds
                          the preprocessed data waiting to be fed to the model; mono_data.output holds
                          the model output waiting to be postprocessed;
        :param traits: trait extraction: features of the current input that assist pre/postprocessing;
        :param future: non-blocking holder of the response; when loop is not None the job is
                       asynchronous and must be awaited inside an async function, otherwise it is
                       synchronous;

        """
        self.input = inp
        self.output = None
        self.traits = None
        self.mono_data = None
        self.future = futures.Future() if loop is None else loop.create_future()
        self.call_backs = []
        self.future.add_done_callback(self.run_call_back)

    def add_done_call_back(self, call_back_fn):
        self.call_backs.append(call_back_fn)

    def run_call_back(self, future):
        for cb in self.call_backs:
            cb(self)


class JobSet(Job):
    def __init__(self, jobs: List[Job], collect_fn, inp=None, loop=None) -> None:
        """
        A job set: the JobSet completes once every sub-job has completed


        :param jobs: the list of sub-jobs
        :param loop: when loop is not None the job is asynchronous and must be awaited inside an
                     async function, otherwise it is synchronous;
        :param collect_fn: result collection; collect_fn receives the job set, and its return value
                           becomes JobSet.future.result, i.e. the result of the whole set;
        """
        self.collect_fn = collect_fn
        self.jobs = jobs
        for job in self.jobs:
            job.add_done_call_back(self.notify)

        super().__init__(inp, loop)

    def done(self):
        # every job finished, and our future has not run set_result yet (avoids repeated set_result)
        return all([job.future.done() for job in self.jobs]) and not self.future.done()

    def notify(self, job):
        if self.done():
            self.future.set_result(self.collect_fn(self))
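JobSet semantics in isolation: a sketch (not from the repo) using the synchronous flavour (loop=None, i.e. concurrent.futures); the set's future resolves exactly when the last sub-job resolves:

```python
from pyinfer.core.job import Job, JobSet

jobs = [Job(inp=i) for i in range(3)]  # loop=None -> concurrent.futures.Future
jobset = JobSet(jobs, collect_fn=lambda js: sum(j.future.result() for j in js.jobs))

for job in jobs:
    job.output = job.input * 10
    job.future.set_result(job.output)  # completing the last job completes the set

print(jobset.future.result())  # 0 + 10 + 20 = 30
```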
--------------------------------------------------------------------------------
/pyinfer/core/mono_allocator.py:
--------------------------------------------------------------------------------
import threading
import weakref
import time
from ..utils.common.logger import Logger

__all__ = ["MonoAllocator"]


class MonoAllocator():
    """Monopoly data-resource allocator"""

    class MonoData():
        """Monopoly data resource"""

        def __init__(self, name, allocator, available=True) -> None:
            """
            input: holds the model's input data
            output: holds the model's output data
            """
            self.name = name
            self._available = available
            self._allocator = allocator
            self.input = None
            self.output = None
            self.workspace = None

        def release(self):
            self._allocator.release_one(self)

    def __init__(self, size: int) -> None:
        """
        :param
        size : number of monopoly data resources; batch_size*2 is recommended;
        self.datas : the list of monopoly data resources;
        self._num_avail : number of resources still available;
        self._num_thread_wait: number of threads waiting for a resource;
        self._run: when the program terminates, __del__ sets _run=False, allocation stops, and the
                   threads waiting for resources drop out of their waits one by one;
        self._lock: thread lock;
        self._cond: condition synchronizing the resource count;
        self._cond_exit: condition synchronizing the number of waiting threads; once allocation
                         stops (_run=False), shutdown waits until _num_thread_wait reaches 0;

        """
        self.datas = [self.MonoData(f"mono_data_{i}", weakref.proxy(self)) for i in range(size)]
        self._num_avail = size
        self._num_thread_wait = 0
        self._run = True
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)
        self._cond_exit = threading.Condition(self._lock)
        self.logger = Logger()

    def query(self, timeout=10):
        """Request a monopoly data resource"""
        with self._cond:
            if not self._run:
                # stop handing out resources once inference has terminated
                return

            # wait for a monopoly data resource
            if self._num_avail == 0:
                self._num_thread_wait += 1  # queueing threads +1
                state = self._cond.wait_for(lambda: (self._num_avail > 0 or not self._run), timeout)
                self._num_thread_wait -= 1  # queueing threads -1
                # update the shutdown state
                self._cond_exit.notify()
                # no resource obtained: either the request timed out or the allocator stopped
                if not state or self._num_avail == 0 or not self._run:
                    return

            # hand over the acquired monopoly data resource
            for mono_data in self.datas:
                if mono_data._available:
                    self.logger.debug(f"occupy {mono_data.name}")
                    mono_data._available = False
                    self._num_avail -= 1
                    return mono_data

    def release_one(self, mono_data: MonoData):
        """Release ownership of a monopoly data resource"""
        with self._cond:
            self._num_avail += 1
            mono_data._available = True
            mono_data.input = None
            mono_data.output = None
            self._cond.notify_all()
            time.sleep(1e-4)
            self.logger.debug(f"release {mono_data.name}")

    def __del__(self):
        """Called when the object is destructed"""
        with self._cond:
            self._run = False
            self._cond.notify_all()

        with self._cond_exit:
            # wait for every waiting thread to leave its wait state
            self._cond_exit.wait_for(lambda: (self._num_thread_wait == 0))

        self.logger.info("MonoAllocator is destroyed.")

    def destory(self):
        self.__del__()
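The allocator's contract in a few lines: a sketch (not from the repo) showing acquire, timeout, and release; with capacity 1, a second query only succeeds after the first slot is released:

```python
from pyinfer.core.mono_allocator import MonoAllocator

allocator = MonoAllocator(size=1)

first = allocator.query(timeout=1)
print(first.name)                     # "mono_data_0"

second = allocator.query(timeout=1)   # capacity exhausted: returns None after ~1s
print(second)                         # None

first.release()                       # input/output are cleared, slot becomes available
print(allocator.query(timeout=1).name)  # "mono_data_0" again
```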
--------------------------------------------------------------------------------
/pyinfer/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/__init__.py

--------------------------------------------------------------------------------
/pyinfer/utils/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/common/__init__.py

--------------------------------------------------------------------------------
/pyinfer/utils/common/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import ast
3 | import copy
4 | from hashlib import new
5 | import os
6 | import os.path as osp
7 | import platform
8 | import shutil
9 | import sys
10 | import tempfile
11 | import types
12 | import uuid
13 | import warnings
14 | from argparse import Action, ArgumentParser
15 | from collections import abc
16 | from importlib import import_module
17 | from pathlib import Path
18 |
19 | from addict import Dict
20 | from yapf.yapflib.yapf_api import FormatCode
21 | import re
22 |
23 |
24 | if platform.system() == 'Windows':
25 |     import regex as re  # type: ignore
26 | else:
27 |     import re  # type: ignore
28 |
29 | BASE_KEY = '_base_'
30 | DELETE_KEY = '_delete_'
31 | DEPRECATION_KEY = '_deprecation_'
32 | RESERVED_KEYS = ['filename', 'text', 'pretty_text']
{mono_data.name}") 89 | 90 | def __del__(self): 91 | """对象析构时调用""" 92 | with self._cond: 93 | self._run = False 94 | self._cond.notify_all() 95 | 96 | with self._cond_exit: 97 | # 等待所有等待线程退出wait状态 98 | self._cond_exit.wait_for(lambda: (self._num_thread_wait == 0)) 99 | 100 | self.logger.info("MonoAllocator is destoryed.") 101 | 102 | def destory(self): 103 | self.__del__() 104 | -------------------------------------------------------------------------------- /pyinfer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/__init__.py -------------------------------------------------------------------------------- /pyinfer/utils/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/common/__init__.py -------------------------------------------------------------------------------- /pyinfer/utils/common/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import ast 3 | import copy 4 | from hashlib import new 5 | import os 6 | import os.path as osp 7 | import platform 8 | import shutil 9 | import sys 10 | import tempfile 11 | import types 12 | import uuid 13 | import warnings 14 | from argparse import Action, ArgumentParser 15 | from collections import abc 16 | from importlib import import_module 17 | from pathlib import Path 18 | 19 | from addict import Dict 20 | from yapf.yapflib.yapf_api import FormatCode 21 | import re 22 | 23 | 24 | if platform.system() == 'Windows': 25 | import regex as re # type: ignore 26 | else: 27 | import re # type: ignore 28 | 29 | BASE_KEY = '_base_' 30 | DELETE_KEY = '_delete_' 31 | DEPRECATION_KEY = '_deprecation_' 32 | RESERVED_KEYS = ['filename', 'text', 'pretty_text'] 33 | 34 | 35 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 36 | if not osp.isfile(filename): 37 | raise FileNotFoundError(msg_tmpl.format(filename)) 38 | 39 | 40 | def import_modules_from_strings(imports, allow_failed_imports=False): 41 | """Import modules from the given list of strings. 42 | 43 | Args: 44 | imports (list | str | None): The given module names to be imported. 45 | allow_failed_imports (bool): If True, the failed imports will return 46 | None. Otherwise, an ImportError is raise. Default: False. 47 | 48 | Returns: 49 | list[module] | module | None: The imported modules. 50 | 51 | Examples: 52 | >>> osp, sys = import_modules_from_strings( 53 | ... 
['os.path', 'sys']) 54 | >>> import os.path as osp_ 55 | >>> import sys as sys_ 56 | >>> assert osp == osp_ 57 | >>> assert sys == sys_ 58 | """ 59 | if not imports: 60 | return 61 | single_import = False 62 | if isinstance(imports, str): 63 | single_import = True 64 | imports = [imports] 65 | if not isinstance(imports, list): 66 | raise TypeError( 67 | f'custom_imports must be a list but got type {type(imports)}') 68 | imported = [] 69 | for imp in imports: 70 | if not isinstance(imp, str): 71 | raise TypeError( 72 | f'{imp} is of type {type(imp)} and cannot be imported.') 73 | try: 74 | imported_tmp = import_module(imp) 75 | except ImportError: 76 | if allow_failed_imports: 77 | warnings.warn(f'{imp} failed to import and is ignored.', 78 | UserWarning) 79 | imported_tmp = None 80 | else: 81 | raise ImportError 82 | imported.append(imported_tmp) 83 | if single_import: 84 | imported = imported[0] 85 | return imported 86 | 87 | 88 | class ConfigDict(Dict): 89 | def __init__(self, config_dict): 90 | """支持三种环境变量""" 91 | self.pattern_dict = {re.compile('{{ AppFolder }}'): os.environ["AppFolder"], 92 | re.compile('{{ WorkspaceFolder }}'):os.environ["WorkspaceFolder"], 93 | re.compile('{{ RootFolder }}'):os.environ["RootFolder"]} 94 | 95 | config_dict = self.identify_env_variables(config_dict) 96 | self.__delattr__("pattern_dict") 97 | super().__init__(config_dict) 98 | 99 | def identify_env_variables(self, config_dict): 100 | for k, v in config_dict.items(): 101 | if isinstance(v,dict): 102 | config_dict[k] = self.identify_env_variables(v) 103 | elif isinstance(v,str): 104 | for pattern , repl in super().__getattr__('pattern_dict').items(): 105 | new_v = re.sub(pattern, repl, v) 106 | if new_v != config_dict[k]: 107 | config_dict[k] = new_v 108 | break 109 | return config_dict 110 | 111 | 112 | 113 | 114 | def __missing__(self, name): 115 | raise KeyError(name) 116 | 117 | def __getattr__(self, name): 118 | try: 119 | value = super().__getattr__(name) 120 | """识别环境变量""" 121 | if isinstance(value, str): 122 | for pattern , repl in super().__getattr__('pattern_dict').items(): 123 | value = re.sub(pattern, repl, value) 124 | except KeyError: 125 | ex = AttributeError(f"'{self.__class__.__name__}' object has no " 126 | f"attribute '{name}'") 127 | except Exception as e: 128 | ex = e 129 | else: 130 | return value 131 | raise ex 132 | 133 | 134 | def add_args(parser, cfg, prefix=''): 135 | for k, v in cfg.items(): 136 | if isinstance(v, str): 137 | parser.add_argument('--' + prefix + k) 138 | elif isinstance(v, int): 139 | parser.add_argument('--' + prefix + k, type=int) 140 | elif isinstance(v, float): 141 | parser.add_argument('--' + prefix + k, type=float) 142 | elif isinstance(v, bool): 143 | parser.add_argument('--' + prefix + k, action='store_true') 144 | elif isinstance(v, dict): 145 | add_args(parser, v, prefix + k + '.') 146 | elif isinstance(v, abc.Iterable): 147 | parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+') 148 | else: 149 | print(f'cannot parse key {prefix + k} of type {type(v)}') 150 | return parser 151 | 152 | 153 | class Config: 154 | """A facility for config and config files. 155 | 156 | It supports common file formats as configs: python/json/yaml. The interface 157 | is the same as a dict object and also allows access config values as 158 | attributes. 
159 | 160 | Example: 161 | >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) 162 | >>> cfg.a 163 | 1 164 | >>> cfg.b 165 | {'b1': [0, 1]} 166 | >>> cfg.b.b1 167 | [0, 1] 168 | >>> cfg = Config.fromfile('tests/data/config/a.py') 169 | >>> cfg.filename 170 | "/home/kchen/projects/mmcv/tests/data/config/a.py" 171 | >>> cfg.item4 172 | 'test' 173 | >>> cfg 174 | "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: " 175 | "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" 176 | """ 177 | 178 | @staticmethod 179 | def _validate_py_syntax(filename): 180 | with open(filename, encoding='utf-8') as f: 181 | # Setting encoding explicitly to resolve coding issue on windows 182 | content = f.read() 183 | try: 184 | ast.parse(content) 185 | except SyntaxError as e: 186 | raise SyntaxError('There are syntax errors in config ' 187 | f'file {filename}: {e}') 188 | 189 | @staticmethod 190 | def _substitute_predefined_vars(filename, temp_config_name): 191 | file_dirname = osp.dirname(filename) 192 | file_basename = osp.basename(filename) 193 | file_basename_no_extension = osp.splitext(file_basename)[0] 194 | file_extname = osp.splitext(filename)[1] 195 | support_templates = dict( 196 | fileDirname=file_dirname, 197 | fileBasename=file_basename, 198 | fileBasenameNoExtension=file_basename_no_extension, 199 | fileExtname=file_extname) 200 | with open(filename, encoding='utf-8') as f: 201 | # Setting encoding explicitly to resolve coding issue on windows 202 | config_file = f.read() 203 | for key, value in support_templates.items(): 204 | regexp = r'\{\{\s*' + str(key) + r'\s*\}\}' 205 | value = value.replace('\\', '/') 206 | config_file = re.sub(regexp, value, config_file) 207 | with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: 208 | tmp_config_file.write(config_file) 209 | 210 | @staticmethod 211 | def _pre_substitute_base_vars(filename, temp_config_name): 212 | """Substitute base variable placehoders to string, so that parsing 213 | would work.""" 214 | with open(filename, encoding='utf-8') as f: 215 | # Setting encoding explicitly to resolve coding issue on windows 216 | config_file = f.read() 217 | base_var_dict = {} 218 | regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}' 219 | base_vars = set(re.findall(regexp, config_file)) 220 | for base_var in base_vars: 221 | randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}' 222 | base_var_dict[randstr] = base_var 223 | regexp = r'\{\{\s*' + BASE_KEY + r'\.' 
+ base_var + r'\s*\}\}' 224 | config_file = re.sub(regexp, f'"{randstr}"', config_file) 225 | with open(temp_config_name, 'w', encoding='utf-8') as tmp_config_file: 226 | tmp_config_file.write(config_file) 227 | return base_var_dict 228 | 229 | @staticmethod 230 | def _substitute_base_vars(cfg, base_var_dict, base_cfg): 231 | """Substitute variable strings to their actual values.""" 232 | cfg = copy.deepcopy(cfg) 233 | 234 | if isinstance(cfg, dict): 235 | for k, v in cfg.items(): 236 | if isinstance(v, str) and v in base_var_dict: 237 | new_v = base_cfg 238 | for new_k in base_var_dict[v].split('.'): 239 | new_v = new_v[new_k] 240 | cfg[k] = new_v 241 | elif isinstance(v, (list, tuple, dict)): 242 | cfg[k] = Config._substitute_base_vars( 243 | v, base_var_dict, base_cfg) 244 | elif isinstance(cfg, tuple): 245 | cfg = tuple( 246 | Config._substitute_base_vars(c, base_var_dict, base_cfg) 247 | for c in cfg) 248 | elif isinstance(cfg, list): 249 | cfg = [ 250 | Config._substitute_base_vars(c, base_var_dict, base_cfg) 251 | for c in cfg 252 | ] 253 | elif isinstance(cfg, str) and cfg in base_var_dict: 254 | new_v = base_cfg 255 | for new_k in base_var_dict[cfg].split('.'): 256 | new_v = new_v[new_k] 257 | cfg = new_v 258 | 259 | return cfg 260 | 261 | @staticmethod 262 | def _file2dict(filename, use_predefined_variables=True): 263 | filename = osp.abspath(osp.expanduser(filename)) 264 | check_file_exist(filename) 265 | fileExtname = osp.splitext(filename)[1] 266 | if fileExtname not in ['.py', '.json', '.yaml', '.yml']: 267 | raise OSError('Only py/yml/yaml/json type are supported now!') 268 | 269 | with tempfile.TemporaryDirectory() as temp_config_dir: 270 | temp_config_file = tempfile.NamedTemporaryFile( 271 | dir=temp_config_dir, suffix=fileExtname) 272 | if platform.system() == 'Windows': 273 | temp_config_file.close() 274 | temp_config_name = osp.basename(temp_config_file.name) 275 | # Substitute predefined variables 276 | if use_predefined_variables: 277 | Config._substitute_predefined_vars(filename, 278 | temp_config_file.name) 279 | else: 280 | shutil.copyfile(filename, temp_config_file.name) 281 | # Substitute base variables from placeholders to strings 282 | base_var_dict = Config._pre_substitute_base_vars( 283 | temp_config_file.name, temp_config_file.name) 284 | 285 | if filename.endswith('.py'): 286 | temp_module_name = osp.splitext(temp_config_name)[0] 287 | sys.path.insert(0, temp_config_dir) 288 | Config._validate_py_syntax(filename) 289 | mod = import_module(temp_module_name) 290 | sys.path.pop(0) 291 | cfg_dict = { 292 | name: value 293 | for name, value in mod.__dict__.items() 294 | if not name.startswith('__') 295 | and not isinstance(value, types.ModuleType) 296 | and not isinstance(value, types.FunctionType) 297 | } 298 | # delete imported module 299 | del sys.modules[temp_module_name] 300 | elif filename.endswith(('.yml', '.yaml', '.json')): 301 | import mmcv 302 | cfg_dict = mmcv.load(temp_config_file.name) 303 | # close temp file 304 | temp_config_file.close() 305 | 306 | # check deprecation information 307 | if DEPRECATION_KEY in cfg_dict: 308 | deprecation_info = cfg_dict.pop(DEPRECATION_KEY) 309 | warning_msg = f'The config file {filename} will be deprecated ' \ 310 | 'in the future.' 311 | if 'expected' in deprecation_info: 312 | warning_msg += f' Please use {deprecation_info["expected"]} ' \ 313 | 'instead.' 
314 | if 'reference' in deprecation_info: 315 | warning_msg += ' More information can be found at ' \ 316 | f'{deprecation_info["reference"]}' 317 | warnings.warn(warning_msg, DeprecationWarning) 318 | 319 | cfg_text = filename + '\n' 320 | with open(filename, encoding='utf-8') as f: 321 | # Setting encoding explicitly to resolve coding issue on windows 322 | cfg_text += f.read() 323 | 324 | if BASE_KEY in cfg_dict: 325 | cfg_dir = osp.dirname(filename) 326 | base_filename = cfg_dict.pop(BASE_KEY) 327 | base_filename = base_filename if isinstance( 328 | base_filename, list) else [base_filename] 329 | 330 | cfg_dict_list = list() 331 | cfg_text_list = list() 332 | for f in base_filename: 333 | _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f)) 334 | cfg_dict_list.append(_cfg_dict) 335 | cfg_text_list.append(_cfg_text) 336 | 337 | base_cfg_dict = dict() 338 | for c in cfg_dict_list: 339 | duplicate_keys = base_cfg_dict.keys() & c.keys() 340 | if len(duplicate_keys) > 0: 341 | raise KeyError('Duplicate key is not allowed among bases. ' 342 | f'Duplicate keys: {duplicate_keys}') 343 | base_cfg_dict.update(c) 344 | 345 | # Substitute base variables from strings to their actual values 346 | cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict, 347 | base_cfg_dict) 348 | 349 | base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict) 350 | cfg_dict = base_cfg_dict 351 | 352 | # merge cfg_text 353 | cfg_text_list.append(cfg_text) 354 | cfg_text = '\n'.join(cfg_text_list) 355 | 356 | return cfg_dict, cfg_text 357 | 358 | @staticmethod 359 | def _merge_a_into_b(a, b, allow_list_keys=False): 360 | """merge dict ``a`` into dict ``b`` (non-inplace). 361 | 362 | Values in ``a`` will overwrite ``b``. ``b`` is copied first to avoid 363 | in-place modifications. 364 | 365 | Args: 366 | a (dict): The source dict to be merged into ``b``. 367 | b (dict): The origin dict to be fetch keys from ``a``. 368 | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') 369 | are allowed in source ``a`` and will replace the element of the 370 | corresponding index in b if b is a list. Default: False. 371 | 372 | Returns: 373 | dict: The modified dict of ``b`` using ``a``. 374 | 375 | Examples: 376 | # Normally merge a into b. 377 | >>> Config._merge_a_into_b( 378 | ... dict(obj=dict(a=2)), dict(obj=dict(a=1))) 379 | {'obj': {'a': 2}} 380 | 381 | # Delete b first and merge a into b. 382 | >>> Config._merge_a_into_b( 383 | ... dict(obj=dict(_delete_=True, a=2)), dict(obj=dict(a=1))) 384 | {'obj': {'a': 2}} 385 | 386 | # b is a list 387 | >>> Config._merge_a_into_b( 388 | ... {'0': dict(a=2)}, [dict(a=1), dict(b=2)], True) 389 | [{'a': 2}, {'b': 2}] 390 | """ 391 | b = b.copy() 392 | for k, v in a.items(): 393 | if allow_list_keys and k.isdigit() and isinstance(b, list): 394 | k = int(k) 395 | if len(b) <= k: 396 | raise KeyError(f'Index {k} exceeds the length of list {b}') 397 | b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) 398 | elif isinstance(v, dict): 399 | if k in b and not v.pop(DELETE_KEY, False): 400 | allowed_types = (dict, list) if allow_list_keys else dict 401 | if not isinstance(b[k], allowed_types): 402 | raise TypeError( 403 | f'{k}={v} in child config cannot inherit from ' 404 | f'base because {k} is a dict in the child config ' 405 | f'but is of type {type(b[k])} in base config. 
' 406 | f'You may set `{DELETE_KEY}=True` to ignore the ' 407 | f'base config.') 408 | b[k] = Config._merge_a_into_b(v, b[k], allow_list_keys) 409 | else: 410 | b[k] = ConfigDict(v) 411 | else: 412 | b[k] = v 413 | return b 414 | 415 | @staticmethod 416 | def fromfile(filename, 417 | use_predefined_variables=True, 418 | import_custom_modules=True): 419 | if isinstance(filename, Path): 420 | filename = str(filename) 421 | cfg_dict, cfg_text = Config._file2dict(filename, 422 | use_predefined_variables) 423 | if import_custom_modules and cfg_dict.get('custom_imports', None): 424 | import_modules_from_strings(**cfg_dict['custom_imports']) 425 | return Config(cfg_dict, cfg_text=cfg_text, filename=filename) 426 | 427 | @staticmethod 428 | def fromstring(cfg_str, file_format): 429 | """Generate config from config str. 430 | 431 | Args: 432 | cfg_str (str): Config str. 433 | file_format (str): Config file format corresponding to the 434 | config str. Only py/yml/yaml/json type are supported now! 435 | 436 | Returns: 437 | :obj:`Config`: Config obj. 438 | """ 439 | if file_format not in ['.py', '.json', '.yaml', '.yml']: 440 | raise OSError('Only py/yml/yaml/json type are supported now!') 441 | if file_format != '.py' and 'dict(' in cfg_str: 442 | # check if users specify a wrong suffix for python 443 | warnings.warn( 444 | 'Please check "file_format", the file format may be .py') 445 | with tempfile.NamedTemporaryFile( 446 | 'w', encoding='utf-8', suffix=file_format, 447 | delete=False) as temp_file: 448 | temp_file.write(cfg_str) 449 | # on windows, previous implementation cause error 450 | # see PR 1077 for details 451 | cfg = Config.fromfile(temp_file.name) 452 | os.remove(temp_file.name) 453 | return cfg 454 | 455 | @staticmethod 456 | def auto_argparser(description=None): 457 | """Generate argparser from config file automatically (experimental)""" 458 | partial_parser = ArgumentParser(description=description) 459 | partial_parser.add_argument('config', help='config file path') 460 | cfg_file = partial_parser.parse_known_args()[0].config 461 | cfg = Config.fromfile(cfg_file) 462 | parser = ArgumentParser(description=description) 463 | parser.add_argument('config', help='config file path') 464 | add_args(parser, cfg) 465 | return parser, cfg 466 | 467 | def __init__(self, cfg_dict=None, cfg_text=None, filename=None): 468 | if cfg_dict is None: 469 | cfg_dict = dict() 470 | elif not isinstance(cfg_dict, dict): 471 | raise TypeError('cfg_dict must be a dict, but ' 472 | f'got {type(cfg_dict)}') 473 | for key in cfg_dict: 474 | if key in RESERVED_KEYS: 475 | raise KeyError(f'{key} is reserved for config file') 476 | 477 | if isinstance(filename, Path): 478 | filename = str(filename) 479 | 480 | super().__setattr__('_cfg_dict', ConfigDict(cfg_dict)) 481 | super().__setattr__('_filename', filename) 482 | if cfg_text: 483 | text = cfg_text 484 | elif filename: 485 | with open(filename, encoding='utf-8') as f: 486 | text = f.read() 487 | else: 488 | text = '' 489 | super().__setattr__('_text', text) 490 | 491 | @property 492 | def filename(self): 493 | return self._filename 494 | 495 | @property 496 | def text(self): 497 | return self._text 498 | 499 | @property 500 | def pretty_text(self): 501 | 502 | indent = 4 503 | 504 | def _indent(s_, num_spaces): 505 | s = s_.split('\n') 506 | if len(s) == 1: 507 | return s_ 508 | first = s.pop(0) 509 | s = [(num_spaces * ' ') + line for line in s] 510 | s = '\n'.join(s) 511 | s = first + '\n' + s 512 | return s 513 | 514 | def _format_basic_types(k, v, 
use_mapping=False): 515 | if isinstance(v, str): 516 | v_str = f"'{v}'" 517 | else: 518 | v_str = str(v) 519 | 520 | if use_mapping: 521 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 522 | attr_str = f'{k_str}: {v_str}' 523 | else: 524 | attr_str = f'{str(k)}={v_str}' 525 | attr_str = _indent(attr_str, indent) 526 | 527 | return attr_str 528 | 529 | def _format_list(k, v, use_mapping=False): 530 | # check if all items in the list are dict 531 | if all(isinstance(_, dict) for _ in v): 532 | v_str = '[\n' 533 | v_str += '\n'.join( 534 | f'dict({_indent(_format_dict(v_), indent)}),' 535 | for v_ in v).rstrip(',') 536 | if use_mapping: 537 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 538 | attr_str = f'{k_str}: {v_str}' 539 | else: 540 | attr_str = f'{str(k)}={v_str}' 541 | attr_str = _indent(attr_str, indent) + ']' 542 | else: 543 | attr_str = _format_basic_types(k, v, use_mapping) 544 | return attr_str 545 | 546 | def _contain_invalid_identifier(dict_str): 547 | contain_invalid_identifier = False 548 | for key_name in dict_str: 549 | contain_invalid_identifier |= \ 550 | (not str(key_name).isidentifier()) 551 | return contain_invalid_identifier 552 | 553 | def _format_dict(input_dict, outest_level=False): 554 | r = '' 555 | s = [] 556 | 557 | use_mapping = _contain_invalid_identifier(input_dict) 558 | if use_mapping: 559 | r += '{' 560 | for idx, (k, v) in enumerate(input_dict.items()): 561 | is_last = idx >= len(input_dict) - 1 562 | end = '' if outest_level or is_last else ',' 563 | if isinstance(v, dict): 564 | v_str = '\n' + _format_dict(v) 565 | if use_mapping: 566 | k_str = f"'{k}'" if isinstance(k, str) else str(k) 567 | attr_str = f'{k_str}: dict({v_str}' 568 | else: 569 | attr_str = f'{str(k)}=dict({v_str}' 570 | attr_str = _indent(attr_str, indent) + ')' + end 571 | elif isinstance(v, list): 572 | attr_str = _format_list(k, v, use_mapping) + end 573 | else: 574 | attr_str = _format_basic_types(k, v, use_mapping) + end 575 | 576 | s.append(attr_str) 577 | r += '\n'.join(s) 578 | if use_mapping: 579 | r += '}' 580 | return r 581 | 582 | cfg_dict = self._cfg_dict.to_dict() 583 | text = _format_dict(cfg_dict, outest_level=True) 584 | # copied from setup.cfg 585 | yapf_style = dict( 586 | based_on_style='pep8', 587 | blank_line_before_nested_class_or_def=True, 588 | split_before_expression_after_opening_paren=True) 589 | text, _ = FormatCode(text, style_config=yapf_style, verify=True) 590 | 591 | return text 592 | 593 | def __repr__(self): 594 | return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}' 595 | 596 | def __len__(self): 597 | return len(self._cfg_dict) 598 | 599 | def __getattr__(self, name): 600 | return getattr(self._cfg_dict, name) 601 | 602 | def __getitem__(self, name): 603 | return self._cfg_dict.__getitem__(name) 604 | 605 | def __setattr__(self, name, value): 606 | if isinstance(value, dict): 607 | value = ConfigDict(value) 608 | self._cfg_dict.__setattr__(name, value) 609 | 610 | def __setitem__(self, name, value): 611 | if isinstance(value, dict): 612 | value = ConfigDict(value) 613 | self._cfg_dict.__setitem__(name, value) 614 | 615 | def __iter__(self): 616 | return iter(self._cfg_dict) 617 | 618 | def __getstate__(self): 619 | return (self._cfg_dict, self._filename, self._text) 620 | 621 | def __copy__(self): 622 | cls = self.__class__ 623 | other = cls.__new__(cls) 624 | other.__dict__.update(self.__dict__) 625 | 626 | return other 627 | 628 | def __deepcopy__(self, memo): 629 | cls = self.__class__ 630 | other = cls.__new__(cls) 631 
| memo[id(self)] = other 632 | 633 | for key, value in self.__dict__.items(): 634 | super(Config, other).__setattr__(key, copy.deepcopy(value, memo)) 635 | 636 | return other 637 | 638 | def __setstate__(self, state): 639 | _cfg_dict, _filename, _text = state 640 | super().__setattr__('_cfg_dict', _cfg_dict) 641 | super().__setattr__('_filename', _filename) 642 | super().__setattr__('_text', _text) 643 | 644 | def dump(self, file=None): 645 | """Dumps config into a file or returns a string representation of the 646 | config. 647 | 648 | If a file argument is given, saves the config to that file using the 649 | format defined by the file argument extension. 650 | 651 | Otherwise, returns a string representing the config. The formatting of 652 | this returned string is defined by the extension of `self.filename`. If 653 | `self.filename` is not defined, returns a string representation of a 654 | dict (lowercased and using ' for strings). 655 | 656 | Examples: 657 | >>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0), 658 | ... item3=True, item4='test') 659 | >>> cfg = Config(cfg_dict=cfg_dict) 660 | >>> dump_file = "a.py" 661 | >>> cfg.dump(dump_file) 662 | 663 | Args: 664 | file (str, optional): Path of the output file where the config 665 | will be dumped. Defaults to None. 666 | """ 667 | import mmcv 668 | cfg_dict = super().__getattribute__('_cfg_dict').to_dict() 669 | if file is None: 670 | if self.filename is None or self.filename.endswith('.py'): 671 | return self.pretty_text 672 | else: 673 | file_format = self.filename.split('.')[-1] 674 | return mmcv.dump(cfg_dict, file_format=file_format) 675 | elif file.endswith('.py'): 676 | with open(file, 'w', encoding='utf-8') as f: 677 | f.write(self.pretty_text) 678 | else: 679 | file_format = file.split('.')[-1] 680 | return mmcv.dump(cfg_dict, file=file, file_format=file_format) 681 | 682 | def merge_from_dict(self, options, allow_list_keys=True): 683 | """Merge list into cfg_dict. 684 | 685 | Merge the dict parsed by MultipleKVAction into this cfg. 686 | 687 | Examples: 688 | >>> options = {'model.backbone.depth': 50, 689 | ... 'model.backbone.with_cp':True} 690 | >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) 691 | >>> cfg.merge_from_dict(options) 692 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 693 | >>> assert cfg_dict == dict( 694 | ... model=dict(backbone=dict(depth=50, with_cp=True))) 695 | 696 | >>> # Merge list element 697 | >>> cfg = Config(dict(pipeline=[ 698 | ... dict(type='LoadImage'), dict(type='LoadAnnotations')])) 699 | >>> options = dict(pipeline={'0': dict(type='SelfLoadImage')}) 700 | >>> cfg.merge_from_dict(options, allow_list_keys=True) 701 | >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict') 702 | >>> assert cfg_dict == dict(pipeline=[ 703 | ... dict(type='SelfLoadImage'), dict(type='LoadAnnotations')]) 704 | 705 | Args: 706 | options (dict): dict of configs to merge from. 707 | allow_list_keys (bool): If True, int string keys (e.g. '0', '1') 708 | are allowed in ``options`` and will replace the element of the 709 | corresponding index in the config if the config is a list. 710 | Default: True. 
711 | """ 712 | option_cfg_dict = {} 713 | for full_key, v in options.items(): 714 | d = option_cfg_dict 715 | key_list = full_key.split('.') 716 | for subkey in key_list[:-1]: 717 | d.setdefault(subkey, ConfigDict()) 718 | d = d[subkey] 719 | subkey = key_list[-1] 720 | d[subkey] = v 721 | 722 | cfg_dict = super().__getattribute__('_cfg_dict') 723 | super().__setattr__( 724 | '_cfg_dict', 725 | Config._merge_a_into_b( 726 | option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) 727 | 728 | 729 | class DictAction(Action): 730 | """ 731 | argparse action to split an argument into KEY=VALUE form 732 | on the first = and append to a dictionary. List options can 733 | be passed as comma separated values, i.e 'KEY=V1,V2,V3', or with explicit 734 | brackets, i.e. 'KEY=[V1,V2,V3]'. It also support nested brackets to build 735 | list/tuple values. e.g. 'KEY=[(V1,V2),(V3,V4)]' 736 | """ 737 | 738 | @staticmethod 739 | def _parse_int_float_bool(val): 740 | try: 741 | return int(val) 742 | except ValueError: 743 | pass 744 | try: 745 | return float(val) 746 | except ValueError: 747 | pass 748 | if val.lower() in ['true', 'false']: 749 | return True if val.lower() == 'true' else False 750 | if val == 'None': 751 | return None 752 | return val 753 | 754 | @staticmethod 755 | def _parse_iterable(val): 756 | """Parse iterable values in the string. 757 | 758 | All elements inside '()' or '[]' are treated as iterable values. 759 | 760 | Args: 761 | val (str): Value string. 762 | 763 | Returns: 764 | list | tuple: The expanded list or tuple from the string. 765 | 766 | Examples: 767 | >>> DictAction._parse_iterable('1,2,3') 768 | [1, 2, 3] 769 | >>> DictAction._parse_iterable('[a, b, c]') 770 | ['a', 'b', 'c'] 771 | >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]') 772 | [(1, 2, 3), ['a', 'b'], 'c'] 773 | """ 774 | 775 | def find_next_comma(string): 776 | """Find the position of next comma in the string. 777 | 778 | If no ',' is found in the string, return the string length. All 779 | chars inside '()' and '[]' are treated as one element and thus ',' 780 | inside these brackets are ignored. 781 | """ 782 | assert (string.count('(') == string.count(')')) and ( 783 | string.count('[') == string.count(']')), \ 784 | f'Imbalanced brackets exist in {string}' 785 | end = len(string) 786 | for idx, char in enumerate(string): 787 | pre = string[:idx] 788 | # The string before this ',' is balanced 789 | if ((char == ',') and (pre.count('(') == pre.count(')')) 790 | and (pre.count('[') == pre.count(']'))): 791 | end = idx 792 | break 793 | return end 794 | 795 | # Strip ' and " characters and replace whitespace. 
796 | val = val.strip('\'\"').replace(' ', '') 797 | is_tuple = False 798 | if val.startswith('(') and val.endswith(')'): 799 | is_tuple = True 800 | val = val[1:-1] 801 | elif val.startswith('[') and val.endswith(']'): 802 | val = val[1:-1] 803 | elif ',' not in val: 804 | # val is a single value 805 | return DictAction._parse_int_float_bool(val) 806 | 807 | values = [] 808 | while len(val) > 0: 809 | comma_idx = find_next_comma(val) 810 | element = DictAction._parse_iterable(val[:comma_idx]) 811 | values.append(element) 812 | val = val[comma_idx + 1:] 813 | if is_tuple: 814 | values = tuple(values) 815 | return values 816 | 817 | def __call__(self, parser, namespace, values, option_string=None): 818 | options = {} 819 | for kv in values: 820 | key, val = kv.split('=', maxsplit=1) 821 | options[key] = self._parse_iterable(val) 822 | setattr(namespace, self.dest, options) 823 | -------------------------------------------------------------------------------- /pyinfer/utils/common/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import logging 3 | from logging.handlers import RotatingFileHandler 4 | import colorlog 5 | 6 | __all__ = ["Logger"] 7 | 8 | 9 | class Logger(): 10 | """Singleton logger with colored console output and optional rotating file output.""" 11 | instance = None 12 | 13 | def __new__(cls, filename=None, level="INFO"): 14 | """ 15 | The logger is implemented as a singleton. 16 | 17 | Note: 18 | logging.getLogger returns the same object for a given filename, so addHandler must run only once per logger object, otherwise records are printed repeatedly; the singleton guarantees this. 19 | 20 | """ 21 | 22 | level_dict = {"INFO": logging.INFO, "DEBUG": logging.DEBUG, "FATAL": logging.FATAL, "ERROR": logging.ERROR} 23 | 24 | if cls.instance is None: 25 | cls.instance = super().__new__(cls) # a bare instance that has not been initialized yet 26 | cls.instance._logger = cls.getLogger(filename, level_dict.get(level)) 27 | return cls.instance 28 | 29 | @classmethod 30 | def getLogger(cls, filename, level): 31 | logger = logging.getLogger(filename) 32 | 33 | log_colors_config = { 34 | 'DEBUG': 'cyan', 35 | 'INFO': 'green', 36 | 'WARNING': 'yellow', 37 | 'ERROR': 'red', 38 | 'FATAL': 'red', 39 | } 40 | 41 | formatter = colorlog.ColoredFormatter( 42 | '%(log_color)s [%(levelname)s] %(asctime)s %(filename)s[line:%(lineno)d] : %(message)s', 43 | log_colors=log_colors_config) 44 | 45 | # set the log level (was hardcoded to logging.INFO, which ignored the level argument) 46 | logger.setLevel(level) 47 | # stream handler for console output 48 | console_handler = logging.StreamHandler() 49 | # console output format 50 | console_handler.setFormatter(formatter) 51 | # attach the handler to the logger 52 | logger.addHandler(console_handler) 53 | 54 | # optional output to a rotating file 55 | if filename is not None: 56 | file_handler = RotatingFileHandler(filename=filename, mode='a', maxBytes=1 * 1024 * 1024, encoding='utf8') 57 | file_formatter = logging.Formatter( 58 | '[%(levelname)s] %(asctime)s %(filename)s[line:%(lineno)d]: %(message)s') 59 | file_handler.setFormatter(file_formatter) 60 | logger.addHandler(file_handler) 61 | return logger 62 | 63 | def warning(self, msg): 64 | self._logger.warning(msg) 65 | 66 | def info(self, msg): 67 | self._logger.info(msg) 68 | 69 | def debug(self, msg): 70 | self._logger.debug(msg) 71 | 72 | def error(self, msg): 73 | self._logger.error(msg) 74 | 75 | def fatal(self, msg): 76 | self._logger.fatal(msg) 77 | raise RuntimeError(msg) -------------------------------------------------------------------------------- /pyinfer/utils/common/registry.py: -------------------------------------------------------------------------------- 1 | __all__ = ["INFERS", "ENGINES", "HOOKS"] 2 | 3 | 4 | class Register(): 5 | 6 | def __init__(self, name=None) -> None: 7 | self.module_dict = {} 8 | 9 | def register_module(self): 10 | 11 | def _register(module): 12 | self.module_dict[module.__name__] = module 13 | return module 14 | 15 | return _register 16 | 17 | def build(self, cfg): 18 | module_name = cfg.pop('type') 19 | return self.get(module_name)(**dict(cfg)) 20 | 21 | def get(self, module_name): 22 | return self.module_dict.get(module_name) 23 | 24 | 25 | INFERS = Register("INFERS") 26 | ENGINES = Register("ENGINES") 27 | HOOKS = Register("HOOKS") 28 | 
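# --- usage sketch (illustrative; not part of the original file) ---
# A class registers itself through the decorator and is later instantiated from
# a config dict whose 'type' key selects the class; the class name and
# parameters below are hypothetical stand-ins.
#
#   @INFERS.register_module()
#   class DetectionInfer:
#       def __init__(self, onnx_path):
#           self.onnx_path = onnx_path
#
#   infer = INFERS.build(dict(type="DetectionInfer", onnx_path="models/yolo.onnx"))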
-------------------------------------------------------------------------------- /pyinfer/utils/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/detection/__init__.py -------------------------------------------------------------------------------- /pyinfer/utils/detection/bbox.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | import numpy as np 3 | from typing import Union 4 | 5 | 6 | class QuadrangleBBox(BaseModel): 7 | 8 | x1: float = 0 9 | y1: float = 0 10 | x2: float = 0 11 | y2: float = 0 12 | x3: float = 0 13 | y3: float = 0 14 | x4: float = 0 15 | y4: float = 0 16 | area: float = 0 17 | confidence: float = 0 18 | label: int = 0 19 | labelname: str = "" 20 | keepflag: bool = True 21 | 22 | @property 23 | def coords(self): 24 | """Coordinate values.""" 25 | return [self.x1, self.y1, self.x2, self.y2, self.x3, self.y3, self.x4, self.y4] 26 | 27 | def set_coords(self, coords): 28 | """Set coordinate values.""" 29 | assert len(coords) == 8 30 | self.x1, self.y1, self.x2, self.y2, self.x3, self.y3, self.x4, self.y4 = coords 31 | return self 32 | 33 | @property 34 | def xcoords(self): 35 | """x coordinate values.""" 36 | return [self.x1, self.x2, self.x3, self.x4] 37 | 38 | def set_xcoords(self, xcoords): 39 | assert len(xcoords) == 4 40 | self.x1, self.x2, self.x3, self.x4 = xcoords 41 | return self 42 | 43 | @property 44 | def ycoords(self): 45 | """y coordinate values.""" 46 | return [self.y1, self.y2, self.y3, self.y4] 47 | 48 | def set_ycoords(self, ycoords): 49 | assert len(ycoords) == 4 50 | self.y1, self.y2, self.y3, self.y4 = ycoords 51 | return self 52 | 53 | @property 54 | def width(self): 55 | return self.right - self.left 56 | 57 | @property 58 | def height(self): 59 | return self.bottom - self.top 60 | 61 | @property 62 | def left(self): 63 | return min(self.x1, self.x2, self.x3, self.x4) 64 | 65 | @property 66 | def right(self): 67 | return max(self.x1, self.x2, self.x3, self.x4) 68 | 69 | @property 70 | def top(self): 71 | return min(self.y1, self.y2, self.y3, self.y4) 72 | 73 | @property 74 | def bottom(self): 75 | return max(self.y1, self.y2, self.y3, self.y4) 76 | 77 | def reset_origin(self, ox, oy): 78 | """Shift the coordinate origin to (ox, oy).""" 79 | self.xcoords = [self.x1 - ox, self.x2 - ox, self.x3 - ox, self.x4 - ox] 80 | self.ycoords = [self.y1 - oy, self.y2 - oy, self.y3 - oy, self.y4 - oy] 81 | return self 82 | 83 | def clip(self, xmin, xmax, ymin, ymax): 84 | self.xcoords = list(map(lambda x: min(max(x, xmin), xmax), self.xcoords)) 85 | self.ycoords = list(map(lambda y: min(max(y, ymin), ymax), self.ycoords)) 86 | return self 87 | 88 | 89 | @property 90 | def level(self): 91 | return len(set(self.ycoords)) == len(set(self.xcoords)) == 2 92 | 93 | def __mul__(self, scalar: Union[int, float]): 94 | """Scale all coordinates by a scalar.""" 95 | self.coords = list(np.array(self.coords) * scalar) 96 | return self 97 | 98 | def __setattr__(self, key, val): 99 | # @coords.setter is not supported by pydantic BaseModel, so __setattr__ is overridden instead 100 | method = self.__config__.property_set_methods.get(key) 101 | if method is None: 102 | super().__setattr__(key, val) 103 | else: 104 | getattr(self, method)(val) 105 | 106 | class Config: 107 | property_set_methods = {"coords": "set_coords", "xcoords": "set_xcoords", "ycoords": "set_ycoords"} 108 | 
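# --- usage sketch (illustrative; not part of the original file) ---
# Assigning to the `coords` property is routed through set_coords by the
# overridden __setattr__ and Config.property_set_methods; the values below
# are made up.
#
#   box = QuadrangleBBox(confidence=0.9, label=0)
#   box.coords = [0, 0, 10, 0, 10, 5, 0, 5]  # dispatches to set_coords
#   assert box.width == 10 and box.height == 5 and box.level
#   box = box * 2  # __mul__ scales every coordinate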
-------------------------------------------------------------------------------- /pyinfer/utils/detection/nms.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import numpy as np 3 | import shapely.geometry as shgeo 4 | from .bbox import QuadrangleBBox 5 | 6 | 7 | def cpu_nms(bboxes: List[QuadrangleBBox], nms_threshold) -> List[QuadrangleBBox]: 8 | def box_iou(a: QuadrangleBBox, b: QuadrangleBBox): 9 | a_poly = shgeo.Polygon(np.array(a.coords).reshape(-1, 2).tolist()) 10 | b_poly = shgeo.Polygon(np.array(b.coords).reshape(-1, 2).tolist()) 11 | inter_poly = a_poly.intersection(b_poly) 12 | if inter_poly.area == 0: 13 | return 0 14 | return inter_poly.area / (a_poly.area + b_poly.area - inter_poly.area) 15 | 16 | for i in range(len(bboxes)): 17 | for j in range(len(bboxes)): 18 | # skip nms for a different label or for the same bbox 19 | if i == j or bboxes[i].label != bboxes[j].label: 20 | continue 21 | 22 | # on equal confidence keep the later bbox, i.e. skip when j < i 23 | if bboxes[i].confidence == bboxes[j].confidence and j < i: 24 | continue 25 | 26 | if bboxes[j].confidence >= bboxes[i].confidence: 27 | iou = box_iou(bboxes[i], bboxes[j]) 28 | if (iou > nms_threshold): 29 | bboxes[i].keepflag = 0 30 | 31 | return list(filter(lambda bbox: bbox.keepflag, bboxes)) -------------------------------------------------------------------------------- /pyinfer/utils/functional/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/pyinfer/utils/functional/__init__.py -------------------------------------------------------------------------------- /pyinfer/utils/functional/coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pycocotools.coco import COCO 3 | 4 | 5 | class COCOCreator(): 6 | def __init__(self) -> None: 7 | self.image_id_to_index = {} 8 | self.anno_id_to_index = {} 9 | self.cat_id_to_index = {} 10 | self.dataset = {"images": [], "annotations": [], "categories": []} 11 | 12 | def create_image_info(self, image_id, file_name, height, width, update=False, **kwargs): 13 | image_info = {"id": image_id, "file_name": file_name, "width": width, "height": height, **kwargs} 14 | if update: 15 | self.update_image_info(image_info) 16 | return image_info 17 | 18 | def create_anno_info(self, 19 | image_id, 20 | anno_id, 21 | category_id, 22 | bbox, 23 | area, 24 | segmentation=[], 25 | iscrowd=0, 26 | update=False, 27 | **kwargs): 28 | anno_info = { 29 | "image_id": image_id, 30 | "id": anno_id, 31 | "category_id": category_id, 32 | "bbox": bbox, 33 | "area": area, 34 | "segmentation": segmentation, 35 | "iscrowd": iscrowd, 36 | **kwargs 37 | } 38 | if update: 39 | self.update_anno_info(anno_info) 40 | return anno_info 41 | 42 | def create_cat_info(self, cat_id, cat_name, update=False, **kwargs): 43 | cat_info = {"id": cat_id, "name": cat_name, **kwargs} 44 | if update: 45 | self.update_cat_info(cat_info) 46 | return cat_info 47 | 48 | def update_image_info(self, image_info): 49 | image_id = image_info["id"] 50 | if image_id in self.image_id_to_index: 51 | self.dataset["images"][self.image_id_to_index[image_id]].update(image_info) 52 | else: 53 | self.image_id_to_index[image_id] = len(self.dataset["images"]) 54 | 
self.dataset["images"].append(image_info) 55 | 56 | def update_anno_info(self, anno_info): 57 | anno_id = anno_info["id"] 58 | if anno_id in self.anno_id_to_index: 59 | self.dataset["annotations"][self.anno_id_to_index[anno_id]].update(anno_info) 60 | else: 61 | self.anno_id_to_index[anno_id] = len(self.dataset["annotations"]) 62 | self.dataset["annotations"].append(anno_info) 63 | 64 | def update_cat_info(self, cat_info): 65 | cat_id = cat_info["id"] 66 | if cat_id in self.cat_id_to_index: 67 | self.dataset["categories"][self.cat_id_to_index[cat_id]].update(cat_info) 68 | else: 69 | self.cat_id_to_index[cat_id] = len(self.dataset["categories"]) 70 | self.dataset["categories"].append(cat_info) 71 | 72 | def build(self): 73 | coco = COCO() 74 | coco.showAnns() 75 | coco.dataset = self.dataset 76 | coco.createIndex() 77 | return coco 78 | 79 | def write(self, path): 80 | with open(path, 'w') as f: 81 | json.dump(self.dataset, f, indent=4, ensure_ascii=False) 82 | -------------------------------------------------------------------------------- /pyinfer/utils/functional/slice.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import copy 3 | import math 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | from PIL import Image 8 | import shapely.geometry as shgeo 9 | from typing import List 10 | from pycocotools.coco import COCO 11 | 12 | from ..detection.bbox import QuadrangleBBox 13 | from .coco import COCOCreator 14 | 15 | __all__ = [""] 16 | 17 | 18 | def reduce_coords(coords): 19 | """将最短边的两个点p1,p2,用p1,p2中点p替换,从而使多边形减少一个顶点""" 20 | def get_distance(point1, point2): 21 | return math.sqrt(math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1], 2)) 22 | 23 | coords.append(coords[0]) 24 | distances = [get_distance(coords[i], coords[i + 1]) for i in range(len(coords) - 1)] 25 | 26 | min_index = np.array(distances).argsort()[0] 27 | index = 0 28 | out_coords = [] 29 | while index < len(coords): 30 | if (index == min_index): # 取中点 31 | middle_x = (coords[index][0] + coords[index + 1][0]) / 2 32 | middle_y = (coords[index][1] + coords[index + 1][1]) / 2 33 | out_coords.append((middle_x, middle_y)) 34 | elif (index != (min_index + 1)): # 非最短边点保留 35 | out_coords.append((coords[index][0], coords[index][1])) 36 | index += 1 37 | return out_coords 38 | 39 | 40 | def choose_best_pointorder_fit_another(coords, gt_coords): 41 | """ 42 | :params coords: [p1,p2,p3,p4], 四边形坐标点 43 | :params gt_coords: [p1,p2,p3,p4], 标签四边形坐标点 44 | 45 | 选择最匹配gt_coords的坐标排布顺序,coords在最优排布顺序下,与gt_coords逐点点间距之和最小; 46 | """ 47 | x1, gt_x1 = coords[0][0], gt_coords[0][0] 48 | y1, gt_y1 = coords[0][1], gt_coords[0][1] 49 | x2, gt_x2 = coords[1][0], gt_coords[1][0] 50 | y2, gt_y2 = coords[1][1], gt_coords[1][1] 51 | x3, gt_x3 = coords[2][0], gt_coords[2][0] 52 | y3, gt_y3 = coords[2][1], gt_coords[2][1] 53 | x4, gt_x4 = coords[3][0], gt_coords[3][0] 54 | y4, gt_y4 = coords[3][1], gt_coords[3][1] 55 | 56 | combinate = [ 57 | np.array([x1, y1, x2, y2, x3, y3, x4, y4]), 58 | np.array([x2, y2, x3, y3, x4, y4, x1, y1]), 59 | np.array([x3, y3, x4, y4, x1, y1, x2, y2]), 60 | np.array([x4, y4, x1, y1, x2, y2, x3, y3]) 61 | ] 62 | gt = np.array([gt_x1, gt_y1, gt_x2, gt_y2, gt_x3, gt_y3, gt_x4, gt_y4]) 63 | distances = np.array([np.sum((coord - gt)**2) for coord in combinate]) 64 | sorted = distances.argsort() 65 | best_coords = combinate[sorted[0]].reshape(-1, 2).tolist() 66 | return best_coords 67 | 68 | 69 | def assign_bboxes(pleft, ptop, pright, pbottom, 
gt_bboxes: List[QuadrangleBBox], threshold): 70 | """ 71 | Assign bounding boxes to an image, filtering out bboxes whose coordinates fall outside it 72 | 73 | Args: 74 | pleft, ptop, pright, pbottom: coordinates of the top-left and bottom-right corners of the image 75 | gt_bboxes (List[QuadrangleBBox], optional): bounding boxes of the objects in the image. Defaults to None. 76 | threshold (float, optional): IOU threshold between a bbox and the tile; a bbox above the threshold is assigned to the tile, otherwise dropped. Defaults to 0.5. 77 | 78 | Returns: 79 | patches: List[bboxes] 80 | 81 | Basic steps: 82 | 83 | 1. take one gt_bbox; 84 | 2. compute the iou between the bbox and the current image; 85 | 3. iou=1: the bbox lies fully inside the image; keep it and go back to step 1; 86 | 4. iou>threshold: the bbox lies partially inside the image; the overlap polygon has n vertices 87 | ① n<4: only one vertex of the bbox is inside the image; no match; 88 | ② n=5: convert the pentagon into a quadrangle and build a new_bbox; match; 89 | ③ n=6: convert the hexagon into a quadrangle and build a new_bbox; match; 90 | ④ n>6: rarely happens in practice; no match; 91 | 5. reorder the new_bbox vertices, i.e. start from the vertex p whose ordering minimizes the summed distance to the corresponding gt_bbox vertices, 92 | 6. new_bbox coordinates are relative to the origin (0,0); update them so the image top-left corner becomes the origin; 93 | 7. repeat from step 1; 94 | """ 95 | image_poly = shgeo.Polygon([(pleft, ptop), (pright, ptop), (pright, pbottom), (pleft, pbottom)]) 96 | 97 | remain_bboxes = [] 98 | for gt_bbox in gt_bboxes: 99 | x1, y1, x2, y2, x3, y3, x4, y4 = gt_bbox.coords 100 | gt_coords = [(x1, y1), (x2, y2), (x3, y3), (x4, y4)] 101 | gt_poly = shgeo.Polygon(gt_coords) 102 | # invalid bbox 103 | if (gt_poly.area <= 0): 104 | continue 105 | 106 | inter_poly = gt_poly.intersection(image_poly) # overlap region 107 | iou = inter_poly.area / gt_poly.area # not a standard IOU: this is the ratio of the intersection to the gt_bbox area 108 | 109 | # bbox fully inside the image: keep it, using the image top-left corner as its origin 110 | if (iou == 1): 111 | bbox = copy.deepcopy(gt_bbox) # reset the origin on a copy; do not modify the original bbox in place 112 | remain_bboxes.append(bbox.reset_origin(pleft, ptop)) 113 | 114 | # bbox partially inside the image 115 | elif iou > threshold: 116 | # orient the polygon points consistently; 117 | inter_poly = shgeo.polygon.orient(inter_poly, sign=1) 118 | # vertices of the overlap polygon; the ring is closed, so the first and last points coincide and the last one is dropped; 119 | # inter_coords: [(x1, y1), (x2, y2), (x3, y3), (x4, y4), ..., (x1, y1)] 120 | inter_coords = list(inter_poly.exterior.coords)[0:-1] 121 | 122 | # when two rectangles barely overlap, the overlap region may consist of only three points; such bboxes are dropped 123 | if len(inter_coords) < 4: 124 | continue 125 | 126 | elif (len(inter_coords) > 6): 127 | """An overlap of two rectangles with more than six vertices is extremely rare in practice, so such bboxes are dropped""" 128 | continue 129 | 130 | elif (len(inter_coords) == 6): 131 | """Six vertices: merge the two shortest edges to reduce the vertex count to four""" 132 | inter_coords = reduce_coords(reduce_coords(inter_coords)) 133 | 134 | elif (len(inter_coords) == 5): 135 | """Five vertices: merge one shortest edge to reduce the vertex count to four""" 136 | inter_coords = reduce_coords(inter_coords) 137 | 138 | # elif (len(inter_coords)==4): 139 | # four vertices need no processing 140 | 141 | best_coords = choose_best_pointorder_fit_another(inter_coords, gt_coords) # best vertex ordering 142 | best_coords_area = shgeo.Polygon(best_coords).area 143 | # clip the coordinate range and reset the origin 144 | best_coords = np.array(best_coords).flatten() 145 | new_bbox = QuadrangleBBox(label=gt_bbox.label, area=best_coords_area) 146 | new_bbox.coords = best_coords 147 | new_bbox.clip(xmin=pleft, xmax=pright, ymin=ptop, ymax=pbottom) 148 | new_bbox.reset_origin(ox=pleft, oy=ptop) 149 | remain_bboxes.append(new_bbox) 150 | return remain_bboxes 151 | 152 | 153 | def slice_patch(image, left, top, subsize, padding): 154 | no_padding_patch = copy.deepcopy(image[top:(top + subsize), left:(left + subsize)]) 155 | h, w, c = no_padding_patch.shape 156 | if (padding): 157 | patch_image = np.ones((subsize, subsize, c)) * 114 158 | patch_image[0:h, 0:w, :] = no_padding_patch 159 | else: 160 | patch_image = no_padding_patch 161 | return patch_image.astype(np.uint8) 162 | 163 | 164 | def slice_one_image(image: np.ndarray, 165 | image_base_name, 166 | subsize, 167 | rate, 168 | gap, 169 | threshold=0.5, 170 | padding=True, 171 | bboxes: List[QuadrangleBBox] = None): 172 | """ 173 | Slice one image into patches 174 | 175 | 
Args: 176 | image (np.ndarray): the input image 177 | subsize (int): tile (patch) size 178 | rate (int): image resize factor; rate=1 means the image is not resized 179 | gap (int): overlap between adjacent tiles 180 | bboxes (List[QuadrangleBBox], optional): bounding boxes of the objects in the image. Defaults to None. 181 | threshold (float, optional): IOU threshold between a bbox and a tile; a bbox above the threshold is assigned to the tile, otherwise dropped. Defaults to 0.5. 182 | padding (bool, optional): whether to pad tiles smaller than subsize. Defaults to True. 183 | image_base_name (str, optional): base name of the image. Defaults to "". 184 | 185 | Returns: 186 | patches: List[Tuple[patch_name, rate, left, top, patch_image, patch_bboxes]] 187 | 188 | Basic steps: 189 | 190 | 1. select the tile region 191 | 2. crop the patch image from image 192 | 3. if the image has labels, match the bboxes against the tile; the matched subset becomes the patch bboxes 193 | 4. slide the window and repeat steps 1-3 194 | """ 195 | assert np.shape(image) != () 196 | 197 | patches = [] 198 | if (rate != 1): 199 | if bboxes is not None: 200 | bboxes = list(map(lambda bbox: bbox * rate, bboxes)) 201 | resizeimg = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC) 202 | else: 203 | resizeimg = image 204 | 205 | base_name = image_base_name + '__' + str(rate) + '__' 206 | 207 | width = np.shape(resizeimg)[1] 208 | height = np.shape(resizeimg)[0] 209 | 210 | slide = subsize - gap 211 | left, top = 0, 0 212 | while (left < width): 213 | if (left + subsize >= width): 214 | left = max(width - subsize, 0) 215 | top = 0 216 | while (top < height): 217 | if (top + subsize >= height): 218 | top = max(height - subsize, 0) 219 | patch_name = base_name + str(left) + '___' + str(top) 220 | right = min(left + subsize - 1, width - 1) 221 | bottom = min(top + subsize - 1, height - 1) 222 | 223 | # crop one tile image 224 | patch_image = slice_patch(resizeimg, left, top, subsize, padding) 225 | 226 | # collect the bboxes inside the tile 227 | if bboxes is None: 228 | patch_bboxes = [] 229 | else: 230 | patch_bboxes = assign_bboxes(left, top, right, bottom, bboxes, threshold) 231 | patch = (patch_name, rate, left, top, patch_image, patch_bboxes) 232 | patches.append(patch) 233 | 234 | # slide the window 235 | if (top + subsize >= height): 236 | break 237 | else: 238 | top = top + slide 239 | if (left + subsize >= width): 240 | break 241 | else: 242 | left = left + slide 243 | return patches 244 | 245 | 
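# --- usage sketch (illustrative; not part of the original file) ---
# Slice a large image into 640x640 tiles with a 200 px overlap; the file name
# and parameters are hypothetical.
#
#   image = np.array(Image.open("big.jpg"))
#   patches = slice_one_image(image, "big", subsize=640, rate=1, gap=200)
#   for patch_name, rate, left, top, patch_image, patch_bboxes in patches:
#       Image.fromarray(patch_image).save(f"{patch_name}.jpg")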
246 | def slice_coco_dataset(coco_image_src, 247 | coco_image_dst, 248 | coco_json_src, 249 | coco_json_dst, 250 | rate, 251 | subsize, 252 | gap, 253 | padding=True, 254 | threshold=0.5): 255 | coco = COCO(coco_json_src) 256 | coco_creator = COCOCreator() 257 | 258 | image_cnt = 0 259 | anno_cnt = 0 260 | for parent_image_id in tqdm(coco.getImgIds(), desc="slice"): 261 | # 1. load the image 262 | image_info = coco.loadImgs(ids=parent_image_id)[0] 263 | image = np.array(Image.open(os.path.join(coco_image_src, image_info['file_name']))) 264 | image_base_name, extension = os.path.splitext(image_info['file_name']) 265 | 266 | # 2. collect the image gt_bboxes 267 | anno_infos = coco.loadAnns(coco.getAnnIds(imgIds=parent_image_id)) 268 | gt_bboxes = [] 269 | for ann_info in anno_infos: 270 | left, top, width, height = ann_info['bbox'] 271 | right, bottom = left + width, top + height 272 | gt_bbox = QuadrangleBBox(x1=left, 273 | y1=top, 274 | x2=right, 275 | y2=top, 276 | x3=right, 277 | y3=bottom, 278 | x4=left, 279 | y4=bottom, 280 | label=ann_info['category_id'], 281 | area=ann_info['area']) 282 | gt_bboxes.append(gt_bbox) 283 | 284 | # 3. slice the image and assign the gt_bboxes 285 | patches = slice_one_image(image=image, 286 | image_base_name=image_base_name, 287 | subsize=subsize, 288 | gap=gap, 289 | bboxes=gt_bboxes, 290 | rate=rate, 291 | threshold=threshold, 292 | padding=padding) 293 | 294 | for patch_name, rate, left, top, patch_image, patch_bboxes in patches: 295 | if len(patch_bboxes) == 0: # drop tiles without objects 296 | continue 297 | 298 | patch_image_filename = f"{patch_name}{extension}" 299 | 300 | # 4. save the tile image 301 | Image.fromarray(patch_image.astype(np.uint8)).save(os.path.join(coco_image_dst, patch_image_filename)) 302 | 303 | # 5. add the image_info of the tile 304 | image_info = coco_creator.create_image_info( 305 | image_id=image_cnt, 306 | file_name=patch_image_filename, 307 | height=subsize, 308 | width=subsize, 309 | parent_image_id=parent_image_id, # tiles from the same large image share its id 310 | slice_params=[rate, left, top]) # top-left corner of the tile in the large image, plus the resize rate applied before slicing 311 | coco_creator.update_image_info(image_info) 312 | 313 | # 6. add the anno_info of the tile 314 | for bbox in patch_bboxes: 315 | bbox: QuadrangleBBox 316 | anno_info = coco_creator.create_anno_info(image_id=image_cnt, 317 | anno_id=anno_cnt, 318 | bbox=[bbox.left, bbox.top, bbox.width, bbox.height], 319 | category_id=bbox.label, 320 | area=bbox.area) 321 | coco_creator.update_anno_info(anno_info) 322 | anno_cnt += 1 323 | image_cnt += 1 324 | 325 | # reuse the cat_info of the large images directly 326 | for cat_info in coco.loadCats(coco.getCatIds()): 327 | coco_creator.update_cat_info(cat_info) 328 | 329 | # write the new coco_json 330 | coco_creator.write(coco_json_dst) -------------------------------------------------------------------------------- /pyinfer/utils/functional/traits.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | __all__ = ["WarpAffineTraits"] 5 | 6 | 7 | class WarpAffineTraits(): 8 | "Warp-affine (letterbox) transform" 9 | 10 | def __init__(self, sx, sy, dx, dy): 11 | self.sx = sx 12 | self.sy = sy 13 | self.dx = dx 14 | self.dy = dy 15 | # affine transform matrices 16 | self.m2x3_to_dst, self.m2x3_to_src = self.init() 17 | 18 | def init(self): 19 | scale_x = self.dx / self.sx 20 | scale_y = self.dy / self.sy 21 | self.scale = min(scale_x, scale_y) 22 | 23 | self.tx = round(self.scale * self.sx) 24 | self.ty = round(self.scale * self.sy) 25 | 26 | # keep ratio resize 27 | m2x3_to_dst = np.zeros((2, 3)) 28 | m2x3_to_dst[0][0] = self.scale 29 | m2x3_to_dst[0][1] = 0 30 | m2x3_to_dst[0][2] = -self.scale * self.sx * 0.5 + self.dx * 0.5 + self.scale * 0.5 - 0.5 31 | m2x3_to_dst[1][0] = 0 32 | m2x3_to_dst[1][1] = self.scale 33 | m2x3_to_dst[1][2] = -self.scale * self.sy * 0.5 + self.dy * 0.5 + self.scale * 0.5 - 0.5 34 | m2x3_to_dst = m2x3_to_dst.astype(np.float32) 35 | 36 | m2x3_to_src = cv2.invertAffineTransform(m2x3_to_dst).astype(np.float32) 37 | return m2x3_to_dst, m2x3_to_src 38 | 39 | def __call__(self, src_img: np.ndarray, interpolation=cv2.INTER_LINEAR, pad_value=[114, 114, 114]): 40 | "Apply the affine transform to an input image" 41 | top = int((self.dy - self.ty) * 0.5) 42 | left = int((self.dx - self.tx) * 0.5) 43 | bottom = self.dy - self.ty - top 44 | right = self.dx - self.tx - left 45 | dst_img = cv2.resize(src_img, (0, 0), fx=self.scale, fy=self.scale, interpolation=interpolation) 46 | dst_img = cv2.copyMakeBorder(dst_img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=pad_value) 47 | return dst_img 48 | 49 | def to_src_coord(self, dx, dy): 50 | "transformed coordinates -> original coordinates" 51 | sx, sy = np.matmul(self.m2x3_to_src, np.array([dx, dy, 1]).T) 52 | sx = min(max(round(sx), 0), self.sx - 1) 53 | sy = min(max(round(sy), 0), self.sy - 1) 54 | return sx, sy 55 | 56 | def to_dst_coord(self, sx, sy): 57 | "original coordinates -> transformed coordinates" 58 | dx, dy = np.matmul(self.m2x3_to_dst, np.array([sx, sy, 1]).T) 59 | dx = min(max(round(dx), 0), self.dx - 1) 60 | dy = min(max(round(dy), 0), self.dy - 1) 61 | return dx, dy 62 | 
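# --- usage sketch (illustrative; not part of the original file) ---
# Letterbox a hypothetical 1920x1080 frame into a 640x640 network input, then
# map a network-space point back to source-image coordinates.
#
#   src = np.zeros((1080, 1920, 3), dtype=np.uint8)
#   traits = WarpAffineTraits(sx=1920, sy=1080, dx=640, dy=640)
#   dst = traits(src)                      # 640x640 letterboxed image
#   x, y = traits.to_src_coord(320, 320)   # dst coord -> src coord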
-------------------------------------------------------------------------------- /requirments.txt: -------------------------------------------------------------------------------- 1 | colorlog 2 | numpy 3 | fastapi[all] 4 | opencv-python 5 | shapely 6 | uvicorn 7 | pycocotools -------------------------------------------------------------------------------- /static/README.md: -------------------------------------------------------------------------------- 1 | # About FastAPI 2 | 3 | [Docs fail to load in offline environments](https://fastapi.tiangolo.com/advanced/extending-openapi/#self-hosting-javascript-and-css-for-docs) 4 | 5 | ```python 6 | . 7 | ├── app 8 | │   ├── __init__.py 9 | │   ├── main.py 10 | └── static 11 | ├── redoc.standalone.js 12 | ├── swagger-ui-bundle.js 13 | └── swagger-ui.css 14 | 15 | ``` 16 | 17 | Based on the directory structure above, the API script is adapted as follows: 18 | 19 | ```python 20 | from fastapi import FastAPI 21 | from fastapi.openapi.docs import ( 22 | get_redoc_html, 23 | get_swagger_ui_html, 24 | get_swagger_ui_oauth2_redirect_html, 25 | ) 26 | from fastapi.staticfiles import StaticFiles 27 | 28 | app = FastAPI(docs_url=None, redoc_url=None) 29 | 30 | app.mount("/static", StaticFiles(directory="static"), name="static") 31 | 32 | 33 | @app.get("/docs", include_in_schema=False) 34 | async def custom_swagger_ui_html(): 35 | return get_swagger_ui_html( 36 | openapi_url=app.openapi_url, 37 | title=app.title + " - Swagger UI", 38 | oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url, 39 | swagger_js_url="/static/swagger-ui-bundle.js", 40 | swagger_css_url="/static/swagger-ui.css", 41 | ) 42 | 43 | 44 | @app.get(app.swagger_ui_oauth2_redirect_url, include_in_schema=False) 45 | async def swagger_ui_redirect(): 46 | return get_swagger_ui_oauth2_redirect_html() 47 | 48 | 49 | @app.get("/redoc", include_in_schema=False) 50 | async def redoc_html(): 51 | return get_redoc_html( 52 | openapi_url=app.openapi_url, 53 | title=app.title + " - ReDoc", 54 | redoc_js_url="/static/redoc.standalone.js", 55 | ) 56 | 57 | 58 | @app.get("/users/{username}") 59 | async def read_user(username: str): 60 | return {"message": f"Hello {username}"} 61 | 62 | ``` -------------------------------------------------------------------------------- /tools/generate_gif.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | 3 | 4 | def create_gif(image_list, gif_name, duration=1.0): 5 | ''' 6 | :param image_list: the list of images used to build the animated gif 7 | :param gif_name: string, the output gif filename, with the .gif suffix 8 | :param duration: time interval between frames 9 | :return: 10 | ''' 11 | frames = [] 12 | for image_name in image_list: 13 | frames.append(imageio.imread(image_name)) 14 | 15 | imageio.mimsave(gif_name, frames, 'GIF', duration=duration) 16 | return 17 | 18 | 19 | def main(): 20 | import os 21 | image_dir = "C:/Users/wzy/Desktop/NOTEBOOK/docs/algorithm/images/新建文件夹" 22 | image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)] 23 | gif_name = os.path.join(image_dir, "1.gif") 24 | duration = 1.5 25 | create_gif(image_list, gif_name, duration) 26 | 27 | 28 | if __name__ == '__main__': 29 | main() -------------------------------------------------------------------------------- /tools/mmdet_export_onnx/balloon/export_yolox_onnx.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Approach for exporting mmdetection models to onnx (works for any framework): 3 | 4 | 1. Get the official inference code running and debuggable; without the ability to debug, no analysis is possible; 5 | 2. Locate the core of the model, i.e. the key components involved in forward, such as the backbone, neck and head; discard everything else; 6 | 3. Recompose a Model(nn.Module) that wraps the core forward components; 7 | 4. 
Export the onnx and inspect it; if there are multiple outputs, add code in forward to adjust the model outputs; 8 | 9 | The official mmdetection onnx export is not flexible enough: tweaking the onnx easily raises errors, and resolving them requires thoroughly reading and understanding the framework's algorithm source code; 10 | ''' 11 | import sys 12 | sys.path.append("/volume/huaru/third_party/mmdetection") # put mmdetection on the import path 13 | 14 | import torch 15 | from mmdet.apis import init_detector, inference_detector 16 | 17 | config_file = '/volume/wzy/project/PyInfer/tools/mmdet_export_onnx/balloon/yolox_s_8x8_300e_coco.py' 18 | # download the checkpoint from the model zoo and place it under `checkpoints/` 19 | checkpoint_file = '/volume/wzy/project/PyInfer/tools/mmdet_export_onnx/balloon/model.pth' 20 | device = 'cuda:0' 21 | 22 | # initialize the detector 23 | # model = init_detector(config_file, checkpoint_file, device=device) 24 | # # run inference on a demo image 25 | # print(inference_detector(model, 'demo/demo.jpg')) 26 | 27 | class Model(torch.nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.model = init_detector(config_file, checkpoint_file, device=device) 31 | 32 | def forward(self, x): 33 | ib, ic, ih, iw = map(int, x.shape) 34 | x = self.model.backbone(x) 35 | x = self.model.neck(x) 36 | clas, bbox, objness = self.model.bbox_head(x) 37 | 38 | # map network outputs back to the network input: bboxes on the heatmap feature levels are mapped to bboxes on the input image 39 | # 1. map bboxes; 2. decode (inverse affine transform + confidence filtering); 3. nms. Steps 2 and 3 are post-processing; the bbox mapping stays inside the onnx to simplify post-processing, while post-processing is kept out of the onnx so it can be written separately in cuda for efficiency 40 | output_x = [] 41 | for class_item, bbox_item, objness_item in zip(clas, bbox, objness): 42 | hm_b, hm_c, hm_h, hm_w = map(int, class_item.shape) 43 | stride_h, stride_w = ih / hm_h, iw / hm_w 44 | strides = torch.tensor([stride_w, stride_h], device=device).view(-1, 1, 2) 45 | 46 | prior_y, prior_x = torch.meshgrid(torch.arange(hm_h), torch.arange(hm_w)) 47 | prior_x = prior_x.reshape(hm_h * hm_w, 1).to(device) 48 | prior_y = prior_y.reshape(hm_h * hm_w, 1).to(device) 49 | prior_xy = torch.cat([prior_x, prior_y], dim=-1) 50 | class_item = class_item.permute(0, 2, 3, 1).reshape(-1, hm_h * hm_w, hm_c) 51 | bbox_item = bbox_item.permute(0, 2, 3, 1).reshape(-1, hm_h * hm_w, 4) 52 | objness_item = objness_item.reshape(-1, hm_h * hm_w, 1) 53 | pred_xy = (bbox_item[..., :2] + prior_xy) * strides 54 | pred_wh = bbox_item[..., 2:4].exp() * strides 55 | pred_class = torch.cat([objness_item, class_item], dim=-1).sigmoid() 56 | output_x.append(torch.cat([pred_xy, pred_wh, pred_class], dim=-1)) 57 | 58 | return torch.cat(output_x, dim=1) 59 | 60 | m = Model().eval() 61 | 62 | image = torch.zeros(1, 3, 640, 640, device=device) 63 | torch.onnx.export( 64 | m, (image,), "yolox.onnx", 65 | opset_version=11, 66 | input_names=["images"], 67 | output_names=["output"], 68 | dynamic_axes={ 69 | "images": {0: "batch"}, 70 | "output": {0: "batch"} 71 | } 72 | ) 73 | print("Done!") -------------------------------------------------------------------------------- /tools/mmdet_export_onnx/balloon/yolox_s_8x8_300e_coco.py: -------------------------------------------------------------------------------- 1 | application_root = '' 2 | data_root = '' 3 | img_scale = (640, 640) 4 | dataset_type = 'CocoDataset' 5 | classes = ('balloon', ) 6 | num_classes = 1 7 | train_pipeline = [ 8 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 9 | dict(type='YOLOXHSVRandomAug'), 10 | dict(type='RandomFlip', flip_ratio=0.5), 11 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 12 | dict( 13 | type='Pad', 14 | pad_to_square=True, 15 | pad_val=dict(img=(114.0, 114.0, 114.0))), 16 | dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 19 | ] 20 | test_pipeline = [ 21 | 
dict(type='LoadImageFromFile'), 22 | dict( 23 | type='MultiScaleFlipAug', 24 | img_scale=(640, 640), 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict( 30 | type='Pad', 31 | pad_to_square=True, 32 | pad_val=dict(img=(114.0, 114.0, 114.0))), 33 | dict(type='DefaultFormatBundle'), 34 | dict(type='Collect', keys=['img']) 35 | ]) 36 | ] 37 | train_dataset = dict( 38 | type='MultiImageMixDataset', 39 | dataset=dict( 40 | type='CocoDataset', 41 | ann_file='', 42 | img_prefix='', 43 | classes=('balloon', ), 44 | pipeline=[ 45 | dict(type='LoadImageFromFile'), 46 | dict(type='LoadAnnotations', with_bbox=True) 47 | ], 48 | filter_empty_gt=False), 49 | pipeline=[ 50 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 51 | dict(type='YOLOXHSVRandomAug'), 52 | dict(type='RandomFlip', flip_ratio=0.5), 53 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 54 | dict( 55 | type='Pad', 56 | pad_to_square=True, 57 | pad_val=dict(img=(114.0, 114.0, 114.0))), 58 | dict( 59 | type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False), 60 | dict(type='DefaultFormatBundle'), 61 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 62 | ]) 63 | val_dataset = dict( 64 | type='CocoDataset', 65 | ann_file='', 66 | img_prefix='', 67 | classes=('balloon', ), 68 | pipeline=[ 69 | dict(type='LoadImageFromFile'), 70 | dict( 71 | type='MultiScaleFlipAug', 72 | img_scale=(640, 640), 73 | flip=False, 74 | transforms=[ 75 | dict(type='Resize', keep_ratio=True), 76 | dict(type='RandomFlip'), 77 | dict( 78 | type='Pad', 79 | pad_to_square=True, 80 | pad_val=dict(img=(114.0, 114.0, 114.0))), 81 | dict(type='DefaultFormatBundle'), 82 | dict(type='Collect', keys=['img']) 83 | ]) 84 | ]) 85 | data = dict( 86 | samples_per_gpu=4, 87 | workers_per_gpu=2, 88 | persistent_workers=True, 89 | train=dict( 90 | type='MultiImageMixDataset', 91 | dataset=dict( 92 | type='CocoDataset', 93 | ann_file='', 94 | img_prefix='', 95 | classes=('balloon', ), 96 | pipeline=[ 97 | dict(type='LoadImageFromFile'), 98 | dict(type='LoadAnnotations', with_bbox=True) 99 | ], 100 | filter_empty_gt=False), 101 | pipeline=[ 102 | dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0), 103 | dict(type='YOLOXHSVRandomAug'), 104 | dict(type='RandomFlip', flip_ratio=0.5), 105 | dict(type='Resize', img_scale=(640, 640), keep_ratio=True), 106 | dict( 107 | type='Pad', 108 | pad_to_square=True, 109 | pad_val=dict(img=(114.0, 114.0, 114.0))), 110 | dict( 111 | type='FilterAnnotations', 112 | min_gt_bbox_wh=(1, 1), 113 | keep_empty=False), 114 | dict(type='DefaultFormatBundle'), 115 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 116 | ]), 117 | val=dict( 118 | type='CocoDataset', 119 | ann_file='', 120 | img_prefix='', 121 | classes=('balloon', ), 122 | pipeline=[ 123 | dict(type='LoadImageFromFile'), 124 | dict( 125 | type='MultiScaleFlipAug', 126 | img_scale=(640, 640), 127 | flip=False, 128 | transforms=[ 129 | dict(type='Resize', keep_ratio=True), 130 | dict(type='RandomFlip'), 131 | dict( 132 | type='Pad', 133 | pad_to_square=True, 134 | pad_val=dict(img=(114.0, 114.0, 114.0))), 135 | dict(type='DefaultFormatBundle'), 136 | dict(type='Collect', keys=['img']) 137 | ]) 138 | ]), 139 | test=dict( 140 | type='CocoDataset', 141 | ann_file='', 142 | img_prefix='', 143 | classes=('balloon', ), 144 | pipeline=[ 145 | dict(type='LoadImageFromFile'), 146 | dict( 147 | type='MultiScaleFlipAug', 148 | img_scale=(640, 640), 149 | flip=False, 150 | 
transforms=[ 151 | dict(type='Resize', keep_ratio=True), 152 | dict(type='RandomFlip'), 153 | dict( 154 | type='Pad', 155 | pad_to_square=True, 156 | pad_val=dict(img=(114.0, 114.0, 114.0))), 157 | dict(type='DefaultFormatBundle'), 158 | dict(type='Collect', keys=['img']) 159 | ]) 160 | ])) 161 | model = dict( 162 | type='YOLOX', 163 | input_size=(640, 640), 164 | random_size_range=(15, 25), 165 | random_size_interval=10, 166 | backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5), 167 | neck=dict( 168 | type='YOLOXPAFPN', 169 | in_channels=[128, 256, 512], 170 | out_channels=128, 171 | num_csp_blocks=1), 172 | bbox_head=dict( 173 | type='YOLOXHead', num_classes=1, in_channels=128, feat_channels=128), 174 | train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)), 175 | test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65))) 176 | max_epochs = 50 177 | num_last_epochs = 14 178 | resume_from = None 179 | interval = 5 180 | optimizer = dict( 181 | type='SGD', 182 | lr=0.0001, 183 | momentum=0.9, 184 | weight_decay=0.0005, 185 | nesterov=True, 186 | paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0)) 187 | optimizer_config = dict(grad_clip=None) 188 | lr_config = dict( 189 | policy='YOLOX', 190 | warmup='exp', 191 | by_epoch=False, 192 | warmup_by_epoch=True, 193 | warmup_ratio=1, 194 | warmup_iters=15, 195 | num_last_epochs=14, 196 | min_lr_ratio=0.05) 197 | custom_hooks = [ 198 | dict(type='YOLOXModeSwitchHook', num_last_epochs=14, priority=48), 199 | dict(type='SyncNormHook', num_last_epochs=14, interval=5, priority=48), 200 | dict( 201 | type='ExpMomentumEMAHook', 202 | resume_from=None, 203 | momentum=0.0001, 204 | priority=49) 205 | ] 206 | checkpoint_config = dict(interval=5) 207 | evaluation = dict( 208 | save_best='auto', interval=5, dynamic_intervals=[(36, 5)], metric='bbox') 209 | log_config = dict(interval=8, hooks=[dict(type='TextLoggerHook')]) 210 | dist_params = dict(backend='nccl') 211 | log_level = 'INFO' 212 | opencv_num_threads = 0 213 | auto_scale_lr = dict(base_batch_size=8) 214 | mp_start_method = 'fork' 215 | workflow = [('train', 1)] 216 | runner = dict(type='EpochBasedRunner', max_epochs=50) 217 | load_from = '' 218 | work_dir = '' 219 | auto_resume = False 220 | gpu_ids = [0] 221 | -------------------------------------------------------------------------------- /tools/slice_coco.py: -------------------------------------------------------------------------------- 1 | from pyinfer.utils.functional.slice import slice_coco_dataset 2 | 3 | if __name__ == "__main__": 4 | coco_image_src = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/train" 5 | coco_image_dst = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/slice_train" 6 | 7 | coco_json_src = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/train.json" 8 | coco_json_dst = "C:/Users/wzy/Desktop/xxx/async_infer_python/workspace/coco/slice_train.json" 9 | 10 | slice_coco_dataset( 11 | coco_image_src, 12 | coco_image_dst, 13 | coco_json_src, 14 | coco_json_dst, 15 | rate=1, 16 | subsize=640, 17 | gap=200, 18 | padding=True, 19 | threshold=0.2) 20 | -------------------------------------------------------------------------------- /tools/visual_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | from tqdm import tqdm 5 | from pycocotools.coco import COCO 6 | import numpy as np 7 | 8 | 9 | def visual_coco(coco_image, 
coco_json): 10 | coco = COCO(coco_json) 11 | for image_id in tqdm(coco.getImgIds(), desc="visual coco"): 12 | image_info = coco.loadImgs(ids=image_id)[0] 13 | image = np.array(Image.open(os.path.join( 14 | coco_image, image_info['file_name']))) 15 | plt.imshow(image) 16 | anno_infos = coco.loadAnns(coco.getAnnIds(imgIds=image_id)) 17 | coco.showAnns(anno_infos, draw_bbox=True) 18 | plt.show() 19 | 20 | 21 | 22 | if __name__ == "__main__": 23 | coco_image = "" 24 | coco_json = "" 25 | visual_coco(coco_image, coco_json) 26 | -------------------------------------------------------------------------------- /workspace/balloon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/workspace/balloon.jpg -------------------------------------------------------------------------------- /workspace/infer_balloon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangzyon/pyInfer/bff1d9800ffd773ab6745f2ea98d4a83dfdb032a/workspace/infer_balloon.jpg --------------------------------------------------------------------------------