├── jittordet ├── version.py ├── models │ ├── necks │ │ └── __init__.py │ ├── preprocessors │ │ └── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ └── MSFA.py │ ├── layers │ │ ├── linear.py │ │ ├── __init__.py │ │ └── scale.py │ ├── task_utils │ │ ├── prior_generators │ │ │ ├── __init__.py │ │ │ └── utils.py │ │ ├── bbox_coders │ │ │ ├── __init__.py │ │ │ ├── base_bbox_coder.py │ │ │ └── distance_point_bbox_coder.py │ │ ├── samplers │ │ │ ├── __init__.py │ │ │ ├── pseudo_sampler.py │ │ │ ├── sampling_result.py │ │ │ ├── random_sampler.py │ │ │ └── base_sampler.py │ │ ├── assigners │ │ │ ├── __init__.py │ │ │ ├── base_assigner.py │ │ │ ├── iou2d_calculator.py │ │ │ └── assign_result.py │ │ └── __init__.py │ ├── roi_heads │ │ ├── roi_extractors │ │ │ ├── __init__.py │ │ │ ├── base_roi_extractor.py │ │ │ └── single_level_roi_extractor.py │ │ ├── bbox_heads │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── base_roi_head.py │ ├── dense_heads │ │ ├── __init__.py │ │ └── retina_head.py │ ├── frameworks │ │ ├── __init__.py │ │ ├── kd_stage.py │ │ ├── rpn.py │ │ └── single_stage.py │ ├── __init__.py │ ├── losses │ │ ├── __init__.py │ │ ├── pkd_loss.py │ │ ├── accuracy.py │ │ ├── utils.py │ │ ├── smooth_l1_loss.py │ │ ├── focal_loss.py │ │ └── kd_loss.py │ └── utils │ │ ├── __init__.py │ │ ├── transforms.py │ │ └── initialize.py ├── ops │ ├── __init__.py │ ├── nms.py │ └── bbox_geometry.py ├── __init__.py ├── engine │ ├── hooks │ │ ├── __init__.py │ │ ├── checkpoint_hook.py │ │ └── base_hook.py │ ├── evaluator │ │ ├── __init__.py │ │ └── base_evaluator.py │ ├── loops │ │ ├── __init__.py │ │ ├── base_loop.py │ │ ├── val_loop.py │ │ ├── test_loop.py │ │ └── train_loop.py │ ├── optim │ │ ├── __init__.py │ │ ├── optimizers.py │ │ └── schedulers.py │ ├── __init__.py │ ├── register │ │ ├── __init__.py │ │ ├── fields.py │ │ └── register.py │ ├── config │ │ ├── dumpers.py │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── readers.py │ │ ├── parsers.py │ │ └── config.py │ └── logger.py ├── datasets │ ├── samplers │ │ ├── __init__.py │ │ ├── base_batch_sampler.py │ │ ├── pad_batch_sampler.py │ │ └── aspect_ratio_batch_sampler.py │ ├── transforms │ │ ├── __init__.py │ │ ├── loading.py │ │ └── formatting.py │ ├── sardet100k.py │ ├── __init__.py │ ├── wrappers.py │ ├── voc.py │ └── coco.py ├── utils │ ├── dist.py │ ├── __init__.py │ ├── util_random.py │ ├── types.py │ └── bbox_transforms.py └── structures │ ├── __init__.py │ └── det_data_sample.py ├── configs ├── _common_ │ ├── sgd_0_01.yml │ ├── sgd_0_02.yml │ ├── adamw_0_0001.yml │ ├── default_setting.yml │ ├── loop_1x.yml │ └── loop_2x.yml ├── gfl │ ├── gfl_r50_fpn_coco_1x.yml │ └── _model_ │ │ └── gfl_r50_fpn.yml ├── retinanet │ ├── retinanet_r50_fpn_voc0712_1x.yml │ ├── retinanet_r50_fpn_coco_1x.yml │ └── _model_ │ │ └── retinanet_r50_fpn.yml ├── faster_rcnn │ ├── faster_rcnn_r50_fpn_voc0712_1x.yml │ ├── faster_rcnn_r50_fpn_coco_1x.yml │ ├── _model_ │ │ └── faster_rcnn_r50_fpn.yml │ └── MSFA_faster_rcnn_r50_fpn_sardet_1x.yml ├── crosskd │ ├── crosskd_gfl_r50_r18_fpn_coco_1x.yml │ └── _model_ │ │ └── crosskd_gfl_r50_r18_fpn.yml ├── rpn │ ├── rpn_r50_fpn_coco_1x.yml │ └── _model_ │ │ └── rpn_r50_fpn.yml └── _dataset_ │ ├── coco_detection.yml │ ├── voc0712.yml │ └── sardet100k.yml ├── requirements.txt ├── tools ├── dist_train.sh ├── dist_test.sh ├── test.py └── train.py ├── setup.cfg ├── .github └── workflows │ └── lint.yml ├── .pre-commit-config.yaml ├── setup.py ├── .gitignore └── README.md /jittordet/version.py: 
-------------------------------------------------------------------------------- 1 | __version__ = '0.0.1' 2 | -------------------------------------------------------------------------------- /jittordet/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPN 2 | 3 | __all__ = ['FPN'] 4 | -------------------------------------------------------------------------------- /configs/_common_/sgd_0_01.yml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: SGD 3 | lr: 0.01 4 | momentum: 0.9 5 | weight_decay: 0.0001 6 | -------------------------------------------------------------------------------- /configs/_common_/sgd_0_02.yml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: SGD 3 | lr: 0.02 4 | momentum: 0.9 5 | weight_decay: 0.0001 6 | -------------------------------------------------------------------------------- /jittordet/models/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessor import Preprocessor 2 | 3 | __all__ = ['Preprocessor'] 4 | -------------------------------------------------------------------------------- /configs/_common_/adamw_0_0001.yml: -------------------------------------------------------------------------------- 1 | optimizer: 2 | type: AdamW 3 | lr: 0.0001 4 | betas: (0.9, 0.999) 5 | weight_decay: 0.05 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jittor > 1.3.6 2 | matplotlib 3 | numpy 4 | opencv-python 5 | Pillow 6 | pycocotools 7 | terminaltables 8 | -------------------------------------------------------------------------------- /jittordet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .MSFA import MSFA 2 | from .resnet import ResNet, ResNetV1d 3 | 4 | __all__ = ['ResNet', 'ResNetV1d', 'MSFA'] 5 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | 6 | mpirun -np $GPUS python tools/train.py \ 7 | $CONFIG \ 8 | ${@:3} 9 | -------------------------------------------------------------------------------- /jittordet/models/layers/linear.py: -------------------------------------------------------------------------------- 1 | import jittor.nn as nn 2 | 3 | from jittordet.engine import MODELS 4 | 5 | MODELS.register_module('Linear', module=nn.Linear) 6 | -------------------------------------------------------------------------------- /jittordet/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_align import ROIAlign, roi_align 2 | from .roi_pool import ROIPool, roi_pool 3 | 4 | __all__ = ['ROIAlign', 'roi_align', 'ROIPool', 'roi_pool'] 5 | -------------------------------------------------------------------------------- /jittordet/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_module import ConvModule 2 | from .linear import * # noqa: F403, F401 3 | from .scale import Scale 4 | 5 | __all__ = ['ConvModule', 'Scale'] 6 |
-------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | 7 | mpirun -np $GPUS python tools/test.py \ 8 | $CONFIG \ 9 | $CHECKPOINT \ 10 | ${@:4} 11 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/prior_generators/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_generator import AnchorGenerator 2 | from .utils import anchor_inside_flags 3 | 4 | __all__ = ['AnchorGenerator', 'anchor_inside_flags'] 5 | -------------------------------------------------------------------------------- /jittordet/__init__.py: -------------------------------------------------------------------------------- 1 | import jittordet.datasets # noqa: F401 2 | import jittordet.engine # noqa: F401 3 | import jittordet.models # noqa: F401 4 | from .version import __version__ 5 | 6 | __all__ = ['__version__'] 7 | -------------------------------------------------------------------------------- /jittordet/engine/hooks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_hook import BaseHook 2 | from .checkpoint_hook import CheckpointHook 3 | from .logger_hook import LoggerHook 4 | 5 | __all__ = ['BaseHook', 'CheckpointHook', 'LoggerHook'] 6 | -------------------------------------------------------------------------------- /configs/_common_/default_setting.yml: -------------------------------------------------------------------------------- 1 | hooks: 2 | - type: LoggerHook 3 | interval: 50 4 | interval_exp_name: 1000 5 | - type: CheckpointHook 6 | interval: 1 7 | 8 | disable_cuda: false 9 | seed: null 10 | -------------------------------------------------------------------------------- /configs/gfl/gfl_r50_fpn_coco_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/coco_detection.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_01.yml 6 | - ./_model_/gfl_r50_fpn.yml 7 | -------------------------------------------------------------------------------- /jittordet/models/roi_heads/roi_extractors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_roi_extractor import BaseRoIExtractor 2 | from .single_level_roi_extractor import SingleRoIExtractor 3 | 4 | __all__ = ['BaseRoIExtractor', 'SingleRoIExtractor'] 5 | -------------------------------------------------------------------------------- /configs/retinanet/retinanet_r50_fpn_voc0712_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/voc0712.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_01.yml 6 | - ./_model_/retinanet_r50_fpn.yml 7 | -------------------------------------------------------------------------------- /configs/faster_rcnn/faster_rcnn_r50_fpn_voc0712_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/voc0712.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_02.yml 6 | - ./_model_/faster_rcnn_r50_fpn.yml 7 | --------------------------------------------------------------------------------
/configs/retinanet/retinanet_r50_fpn_coco_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/coco_detection.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_01.yml 6 | - ./_model_/retinanet_r50_fpn.yml 7 | -------------------------------------------------------------------------------- /jittordet/engine/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_evaluator import BaseEvaluator 2 | from .coco_evaluator import CocoEvaluator 3 | from .voc_evaluator import VocEvaluator 4 | 5 | __all__ = ['BaseEvaluator', 'CocoEvaluator', 'VocEvaluator'] 6 | -------------------------------------------------------------------------------- /configs/faster_rcnn/faster_rcnn_r50_fpn_coco_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/coco_detection.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_02.yml 6 | - ./_model_/faster_rcnn_r50_fpn.yml 7 | -------------------------------------------------------------------------------- /configs/crosskd/crosskd_gfl_r50_r18_fpn_coco_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/coco_detection.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_01.yml 6 | - ./_model_/crosskd_gfl_r50_r18_fpn.yml 7 | -------------------------------------------------------------------------------- /jittordet/engine/loops/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_loop import BaseLoop 2 | from .test_loop import TestLoop 3 | from .train_loop import EpochTrainLoop 4 | from .val_loop import ValLoop 5 | 6 | __all__ = ['BaseLoop', 'EpochTrainLoop', 'ValLoop', 'TestLoop'] 7 | -------------------------------------------------------------------------------- /jittordet/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .aspect_ratio_batch_sampler import AspectRatioBatchSampler 2 | from .base_batch_sampler import BaseBatchSampler 3 | from .pad_batch_sampler import PadBatchSampler 4 | 5 | __all__ = ['BaseBatchSampler', 'PadBatchSampler', 'AspectRatioBatchSampler'] 6 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/bbox_coders/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_bbox_coder import BaseBBoxCoder 2 | from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder 3 | from .distance_point_bbox_coder import DistancePointBBoxCoder 4 | 5 | __all__ = ['BaseBBoxCoder', 'DeltaXYWHBBoxCoder', 'DistancePointBBoxCoder'] 6 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import BaseSampler 2 | from .pseudo_sampler import PseudoSampler 3 | from .random_sampler import RandomSampler 4 | from .sampling_result import SamplingResult 5 | 6 | __all__ = ['BaseSampler', 'PseudoSampler', 'RandomSampler', 'SamplingResult'] 7 | -------------------------------------------------------------------------------- /jittordet/models/dense_heads/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .anchor_head import AnchorHead 2 | from .base_dense_head import BaseDenseHead 3 | from .gfl_head import GFLHead 4 | from .retina_head import RetinaHead 5 | from .rpn_head import RPNHead 6 | 7 | __all__ = ['BaseDenseHead', 'AnchorHead', 'RetinaHead', 'RPNHead', 'GFLHead'] 8 | -------------------------------------------------------------------------------- /jittordet/models/roi_heads/bbox_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_head import BBoxHead 2 | from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead, 3 | Shared4Conv1FCBBoxHead) 4 | 5 | __all__ = [ 6 | 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead' 7 | ] 8 | -------------------------------------------------------------------------------- /jittordet/utils/dist.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | 4 | def reduce_mean(var): 5 | """Obtain the mean of a tensor across different GPUs.""" 6 | if jt.world_size == 1: 7 | return var 8 | var = var.clone() 9 | var = var.mpi_all_reduce('mean') 10 | var.sync(device_sync=True) 11 | return var 12 | -------------------------------------------------------------------------------- /configs/rpn/rpn_r50_fpn_coco_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/coco_detection.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/sgd_0_02.yml 6 | - ./_model_/rpn_r50_fpn.yml 7 | 8 | val_evaluator: &val_evaluator 9 | metric: proposal_fast 10 | test_evaluator: *val_evaluator 11 | -------------------------------------------------------------------------------- /jittordet/engine/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .optimizers import register_jittor_optim 2 | from .schedulers import (BaseScheduler, CosineAnnealingLR, ExponentialLR, 3 | StepLR, WarmUpLR) 4 | 5 | __all__ = [ 6 | 'register_jittor_optim', 'BaseScheduler', 'WarmUpLR', 'CosineAnnealingLR', 7 | 'ExponentialLR', 'StepLR' 8 | ] 9 | -------------------------------------------------------------------------------- /jittordet/models/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_roi_head import BaseRoIHead 2 | from .bbox_heads import BBoxHead 3 | from .roi_extractors import BaseRoIExtractor, SingleRoIExtractor 4 | from .standard_roi_head import StandardRoIHead 5 | 6 | __all__ = [ 7 | 'BaseRoIHead', 'StandardRoIHead', 'BBoxHead', 'BaseRoIExtractor', 8 | 'SingleRoIExtractor' 9 | ] 10 | -------------------------------------------------------------------------------- /jittordet/datasets/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .formatting import PackDetInputs 2 | from .loading import LoadAnnotations, LoadImageFromFile 3 | from .transforms import RandomChoiceResize, RandomFlip, RandomResize, Resize 4 | 5 | __all__ = [ 6 | 'PackDetInputs', 'LoadAnnotations', 'LoadImageFromFile', 'Resize', 7 | 'RandomResize', 'RandomChoiceResize', 'RandomFlip' 8 | ] 9 | -------------------------------------------------------------------------------- /jittordet/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * # noqa:
F401, F403 2 | from .evaluator import * # noqa: F401, F403 3 | from .hooks import * # noqa: F401, F403 4 | from .logger import * # noqa: F401, F403 5 | from .loops import * # noqa: F401, F403 6 | from .optim import * # noqa: F401, F403 7 | from .register import * # noqa: F401, F403 8 | from .runner import * # noqa: F401, F403 9 | -------------------------------------------------------------------------------- /configs/_common_/loop_1x.yml: -------------------------------------------------------------------------------- 1 | train_loop: 2 | type: EpochTrainLoop 3 | max_epoch: 12 4 | val_interval: 1 5 | 6 | val_loop: 7 | type: ValLoop 8 | 9 | test_loop: 10 | type: TestLoop 11 | 12 | scheduler: 13 | - type: WarmUpLR 14 | warmup_ratio: 0.001 15 | warmup_iters: 500 16 | warmup: linear 17 | - type: MultiStepLR 18 | milestones: [8, 11] 19 | gamma: 0.1 20 | -------------------------------------------------------------------------------- /configs/_common_/loop_2x.yml: -------------------------------------------------------------------------------- 1 | train_loop: 2 | type: EpochTrainLoop 3 | max_epoch: 24 4 | val_interval: 1 5 | 6 | val_loop: 7 | type: ValLoop 8 | 9 | test_loop: 10 | type: TestLoop 11 | 12 | scheduler: 13 | - type: WarmUpLR 14 | warmup_ratio: 0.001 15 | warmup_iters: 500 16 | warmup: linear 17 | - type: MultiStepLR 18 | milestones: [16, 22] 19 | gamma: 0.1 20 | -------------------------------------------------------------------------------- /jittordet/structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_data_element import BaseDataElement 2 | from .det_data_sample import DetDataSample, OptSampleList, SampleList 3 | from .instance_data import InstanceData, InstanceList, OptInstanceList 4 | 5 | __all__ = [ 6 | 'BaseDataElement', 'InstanceData', 'DetDataSample', 'SampleList', 7 | 'OptSampleList', 'InstanceList', 'OptInstanceList' 8 | ] 9 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/assigners/__init__.py: -------------------------------------------------------------------------------- 1 | from .assign_result import AssignResult 2 | from .atss_assigner import ATSSAssigner 3 | from .base_assigner import BaseAssigner 4 | from .iou2d_calculator import BboxOverlaps2D 5 | from .max_iou_assigner import MaxIoUAssigner 6 | 7 | __all__ = [ 8 | 'BaseAssigner', 'MaxIoUAssigner', 'AssignResult', 'BboxOverlaps2D', 9 | 'ATSSAssigner' 10 | ] 11 | -------------------------------------------------------------------------------- /jittordet/engine/register/__init__.py: -------------------------------------------------------------------------------- 1 | from .fields import (BATCH_SAMPLERS, BRICKS, DATASETS, EVALUATORS, HOOKS, 2 | LOOPS, MODELS, OPTIMIZERS, SCHEDULERS, TASK_UTILS, 3 | TRANSFORMS) 4 | from .register import Register 5 | 6 | __all__ = [ 7 | 'Register', 'LOOPS', 'HOOKS', 'OPTIMIZERS', 'SCHEDULERS', 'DATASETS', 8 | 'MODELS', 'EVALUATORS', 'TRANSFORMS', 'BATCH_SAMPLERS', 'BRICKS', 9 | 'TASK_UTILS' 10 | ] 11 | -------------------------------------------------------------------------------- /jittordet/models/frameworks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_framework import BaseFramework 2 | from .gfl_kd_stage import GFLKDFramework 3 | from .kd_stage import KDSingleStageFramework 4 | from .multi_stage import MultiStageFramework 5 | from .rpn import RPNFramework 6 | from
.single_stage import SingleStageFramework 7 | 8 | __all__ = [ 9 | 'BaseFramework', 'SingleStageFramework', 'MultiStageFramework', 10 | 'RPNFramework', 'KDSingleStageFramework', 'GFLKDFramework' 11 | ] 12 | -------------------------------------------------------------------------------- /jittordet/engine/config/dumpers.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import yaml 4 | 5 | 6 | def yaml_dumper(cfg, filepath): 7 | cfg = cfg.to_dict() 8 | with open(filepath, 'w') as f: 9 | yaml.dump(cfg, f) 10 | 11 | 12 | def json_dumper(cfg, filepath): 13 | cfg = cfg.to_dict() 14 | with open(filepath, 'w') as f: 15 | json.dump(cfg, f, indent=4) 16 | 17 | 18 | cfg_dumpers = { 19 | '.yml': yaml_dumper, 20 | '.yaml': yaml_dumper, 21 | '.json': json_dumper, 22 | } 23 | -------------------------------------------------------------------------------- /jittordet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401, F403 2 | from .dense_heads import * # noqa: F401, F403 3 | from .frameworks import * # noqa: F401, F403 4 | from .layers import * # noqa: F401, F403 5 | from .losses import * # noqa: F401, F403 6 | from .necks import * # noqa: F401, F403 7 | from .preprocessors import * # noqa: F401, F403 8 | from .roi_heads import * # noqa: F401, F403 9 | from .task_utils import * # noqa: F401, F403 10 | from .utils import * # noqa: F401, F403 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = jittordet/ops/* 3 | 4 | [isort] 5 | line_length = 79 6 | multi_line_output = 0 7 | extra_standard_library = setuptools 8 | known_first_party = jittordet 9 | known_third_party = yaml 10 | no_lines_before = STDLIB,LOCALFOLDER 11 | default_section = THIRDPARTY 12 | 13 | [yapf] 14 | BASED_ON_STYLE = pep8 15 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 16 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 17 | 18 | [codespell] 19 | ignore-words-list = warmup,dout,combinate,colums 20 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.9 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.9 14 | - name: Install pre-commit hook 15 | run: | 16 | pip install pre-commit 17 | pre-commit install 18 | - name: Linting 19 | run: pre-commit run --all-files 20 | -------------------------------------------------------------------------------- /jittordet/engine/optim/optimizers.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import jittor.optim as optim 4 | 5 | from ..register import OPTIMIZERS 6 | 7 | 8 | def register_jittor_optim(): 9 | for name, module in optim.__dict__.items(): 10 | if not inspect.isclass(module): 11 | continue 12 | 13 | if issubclass(module, 14 | optim.Optimizer) and module is not optim.Optimizer: 15 | OPTIMIZERS.register_module(name=name, module=module) 16 | 17 | 18 | register_jittor_optim() 19 | -------------------------------------------------------------------------------- /jittordet/engine/register/fields.py:
-------------------------------------------------------------------------------- 1 | from .register import Register 2 | 3 | # engine 4 | LOOPS = Register('loops') 5 | HOOKS = Register('hooks') 6 | OPTIMIZERS = Register('optimizers') 7 | SCHEDULERS = Register('schedulers') 8 | EVALUATORS = Register('evaluators') 9 | 10 | # dataset 11 | DATASETS = Register('datasets') 12 | TRANSFORMS = Register('transforms') 13 | BATCH_SAMPLERS = Register('batch_sampler') 14 | 15 | # model 16 | MODELS = Register('models') 17 | BRICKS = Register('bricks') 18 | TASK_UTILS = Register('task_utils') 19 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .assigners import (AssignResult, BaseAssigner, BboxOverlaps2D, 2 | MaxIoUAssigner) 3 | from .bbox_coders import BaseBBoxCoder, DeltaXYWHBBoxCoder 4 | from .prior_generators import AnchorGenerator, anchor_inside_flags 5 | from .samplers import BaseSampler, PseudoSampler, RandomSampler, SamplingResult 6 | 7 | __all__ = [ 8 | 'AnchorGenerator', 'anchor_inside_flags', 'BaseAssigner', 'MaxIoUAssigner', 9 | 'AssignResult', 'BboxOverlaps2D', 'BaseBBoxCoder', 'DeltaXYWHBBoxCoder', 10 | 'BaseSampler', 'PseudoSampler', 'RandomSampler', 'SamplingResult' 11 | ] 12 | -------------------------------------------------------------------------------- /jittordet/models/layers/scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import jittor as jt 3 | import jittor.nn as nn 4 | 5 | 6 | class Scale(nn.Module): 7 | """A learnable scale parameter. 8 | 9 | This layer scales the input by a learnable factor. It multiplies a 10 | learnable scale parameter of shape (1,) with input of any shape. 11 | 12 | Args: 13 | scale (float): Initial value of scale factor. Default: 1.0 14 | """ 15 | 16 | def __init__(self, scale: float = 1.0): 17 | super().__init__() 18 | self.scale = jt.float32(scale) 19 | 20 | def execute(self, x): 21 | return x * self.scale 22 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/assigners/base_assigner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
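2 | # Illustrative usage of a concrete assigner (a sketch only; the keyword 3 | # arguments mirror the pos_iou_thr/neg_iou_thr keys used by the configs in 4 | # this repo, and the InstanceData field names follow the assigners here): 5 | #   pred_instances = InstanceData(priors=anchors)  # (n, 4) prior boxes 6 | #   gt_instances = InstanceData(bboxes=gt_bboxes, labels=gt_labels) 7 | #   result = MaxIoUAssigner(pos_iou_thr=0.5, neg_iou_thr=0.4).assign(pred_instances, gt_instances)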
8 | from abc import ABCMeta, abstractmethod 9 | from typing import Optional 10 | 11 | from jittordet.structures import InstanceData 12 | 13 | 14 | class BaseAssigner(metaclass=ABCMeta): 15 | """Base assigner that assigns boxes to ground truth boxes.""" 16 | 17 | @abstractmethod 18 | def assign(self, 19 | pred_instances: InstanceData, 20 | gt_instances: InstanceData, 21 | gt_instances_ignore: Optional[InstanceData] = None, 22 | **kwargs): 23 | """Assign boxes to either a ground truth box or a negative sample.""" 24 | -------------------------------------------------------------------------------- /jittordet/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_overlaps import bbox_overlaps 2 | from .bbox_transforms import bbox2distance, distance2bbox 3 | from .dist import reduce_mean 4 | from .image import (_scale_size, imflip, imnormalize, impad, impad_to_multiple, 5 | imrescale, imresize, rescale_size) 6 | from .types import is_list_of, is_seq_of, is_tuple_of 7 | from .util_random import ensure_rng 8 | 9 | __all__ = [ 10 | '_scale_size', 'imresize', 'rescale_size', 'imrescale', 'imflip', 11 | 'imnormalize', 'impad', 'impad_to_multiple', 'is_seq_of', 'is_list_of', 12 | 'is_tuple_of', 'bbox_overlaps', 'ensure_rng', 'distance2bbox', 13 | 'bbox2distance', 'reduce_mean' 14 | ] 15 | -------------------------------------------------------------------------------- /jittordet/datasets/samplers/base_batch_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | from jittordet.engine import BATCH_SAMPLERS 4 | 5 | 6 | @BATCH_SAMPLERS.register_module() 7 | class BaseBatchSampler(metaclass=ABCMeta): 8 | 9 | def __init__(self, dataset): 10 | self.total_bs = dataset.batch_size 11 | self.num_data_list = len(dataset.data_list) 12 | 13 | def __len__(self): 14 | length = int((self.num_data_list - 0.5) // self.total_bs) 15 | if hasattr(self, 'drop_last') and not self.drop_last: 16 | length += 1 17 | return length 18 | 19 | @abstractmethod 20 | def get_index_list(self, rng=None): 21 | pass 22 | -------------------------------------------------------------------------------- /jittordet/datasets/sardet100k.py: -------------------------------------------------------------------------------- 1 | # Modified from mmdetection.dataset.coco 2 | from ..engine import DATASETS 3 | from .coco import CocoDataset 4 | 5 | 6 | @DATASETS.register_module() 7 | class Sardet100k(CocoDataset): 8 | """Dataset for Sardet100k. 9 | 10 | Download dataset at: https://liveuclac-my.sharepoint.com/:f:/g/personal/zcablii_ucl_ac_uk/EuYYZWXL_bJGvd8s9rGH2KYBV1GM5pIOCngnzlyuB_3e5A?e=bgoINm 11 | """ 12 | 13 | METAINFO = { 14 | 'classes': 15 | ('ship', 'aircraft', 'car', 'tank', 'bridge', 'harbor'), 16 | # palette is a list of color tuples, which is used for visualization.
17 | 'palette': 18 | [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), 19 | (0, 60, 100)] 20 | } 21 | -------------------------------------------------------------------------------- /jittordet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDetDataset 2 | from .coco import CocoDataset 3 | from .samplers import (AspectRatioBatchSampler, BaseBatchSampler, 4 | PadBatchSampler) 5 | from .sardet100k import Sardet100k 6 | from .transforms import (LoadAnnotations, LoadImageFromFile, PackDetInputs, 7 | RandomChoiceResize, RandomFlip, RandomResize, Resize) 8 | from .voc import VocDataset 9 | from .wrappers import ConcatDataset 10 | 11 | __all__ = [ 12 | 'BaseDetDataset', 'CocoDataset', 'VocDataset', 'BaseBatchSampler', 13 | 'PadBatchSampler', 'AspectRatioBatchSampler', 'PackDetInputs', 'Resize', 14 | 'LoadAnnotations', 'LoadImageFromFile', 'RandomResize', 'RandomFlip', 15 | 'RandomChoiceResize', 'ConcatDataset', 'Sardet100k' 16 | ] 17 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/bbox_coders/base_bbox_coder.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/task_modules/coders/base_bbox_coder.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | from abc import ABCMeta, abstractmethod 5 | 6 | 7 | class BaseBBoxCoder(metaclass=ABCMeta): 8 | """Base bounding box coder.""" 9 | 10 | @abstractmethod 11 | def encode(self, bboxes, gt_bboxes): 12 | """Encode deltas between bboxes and ground truth boxes.""" 13 | 14 | @abstractmethod 15 | def decode(self, bboxes, bboxes_pred): 16 | """Decode the predicted bboxes according to prediction and base 17 | boxes.""" 18 | -------------------------------------------------------------------------------- /jittordet/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy, accuracy 2 | from .cross_entropy_loss import CrossEntropyLoss 3 | from .focal_loss import FocalLoss 4 | from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss 5 | from .iou_loss import GIoULoss, IoULoss 6 | from .kd_loss import KDQualityFocalLoss, KnowledgeDistillationKLDivLoss 7 | from .pkd_loss import PKDLoss 8 | from .smooth_l1_loss import L1Loss, SmoothL1Loss 9 | from .utils import reduce_loss, weight_reduce_loss, weighted_loss 10 | 11 | __all__ = [ 12 | 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'CrossEntropyLoss', 13 | 'Accuracy', 'accuracy', 'FocalLoss', 'IoULoss', 'GIoULoss', 'SmoothL1Loss', 14 | 'L1Loss', 'QualityFocalLoss', 'DistributionFocalLoss', 15 | 'KDQualityFocalLoss', 'KnowledgeDistillationKLDivLoss', 'PKDLoss' 16 | ] 17 | -------------------------------------------------------------------------------- /jittordet/engine/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import (ConfigDict, ConfigType, MultiConfig, OptConfigType, 2 | OptMultiConfig, dump_cfg, load_cfg, merge_cfg) 3 | from .dumpers import cfg_dumpers, json_dumper, yaml_dumper 4 | from .parsers import (cfg_parsers, default_var_parser, env_variable_parser, 5 | python_eval_parser, tuple_parser) 6 | from .readers import cfg_readers, yaml_reader 7 | from .utils
import delete_node, iter_leaves, set_leaf 8 | 9 | __all__ = [ 10 | 'load_cfg', 'merge_cfg', 'dump_cfg', 'yaml_reader', 'default_var_parser', 11 | 'cfg_readers', 'env_variable_parser', 'cfg_parsers', 'iter_leaves', 12 | 'set_leaf', 'cfg_dumpers', 'yaml_dumper', 'json_dumper', 'delete_node', 13 | 'ConfigType', 'OptConfigType', 'MultiConfig', 'OptMultiConfig', 14 | 'ConfigDict', 'python_eval_parser', 'tuple_parser' 15 | ] 16 | -------------------------------------------------------------------------------- /jittordet/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .initialize import (bias_init_with_prob, caffe2_xavier_init, 2 | constant_init, kaiming_init, normal_init, 3 | trunc_normal_init, uniform_init, xavier_init) 4 | from .misc import (empty_instances, filter_scores_and_topk, images_to_levels, 5 | multi_apply, select_single_mlvl, unmap, unpack_gt_instances) 6 | from .nms import batched_nms, multiclass_nms 7 | from .transforms import bbox2roi 8 | 9 | __all__ = [ 10 | 'normal_init', 'constant_init', 'xavier_init', 'trunc_normal_init', 11 | 'uniform_init', 'kaiming_init', 'caffe2_xavier_init', 12 | 'bias_init_with_prob', 'unpack_gt_instances', 'empty_instances', 13 | 'select_single_mlvl', 'filter_scores_and_topk', 'batched_nms', 14 | 'multi_apply', 'unmap', 'images_to_levels', 'bbox2roi', 'multiclass_nms' 15 | ] 16 | -------------------------------------------------------------------------------- /jittordet/engine/loops/base_loop.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BaseLoop(metaclass=ABCMeta): 5 | 6 | def __init__(self, runner): 7 | self._runner = runner 8 | self._max_epoch = 1 9 | self._epoch = 0 10 | self._iter = 0 11 | 12 | @property 13 | def runner(self): 14 | return self._runner 15 | 16 | @property 17 | def cur_epoch(self): 18 | return self._epoch 19 | 20 | @property 21 | def max_epoch(self): 22 | return self._max_epoch 23 | 24 | @property 25 | def cur_iter(self): 26 | return self._iter 27 | 28 | @abstractmethod 29 | def run(self): 30 | pass 31 | 32 | def state_dict(self): 33 | return dict(epoch=self._epoch, iter=self._iter) 34 | 35 | def load_state_dict(self, data): 36 | assert isinstance(data, dict) 37 | self._epoch = data['epoch'] 38 | self._iter = data['iter'] 39 | -------------------------------------------------------------------------------- /jittordet/models/utils/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import jittor as jt 4 | 5 | 6 | def bbox2roi(bbox_list: List[jt.Var]) -> jt.Var: 7 | """Convert a list of bboxes to roi format. 8 | 9 | Args: 10 | bbox_list (List[Union[Tensor, :obj:`BaseBoxes`]]): a list of bboxes 11 | corresponding to a batch of images. 12 | 13 | Returns: 14 | Tensor: shape (n, box_dim + 1), where ``box_dim`` depends on the 15 | different box types. For example, if the box type in ``bbox_list`` 16 | is HorizontalBoxes, the output shape is (n, 5). Each row of data 17 | indicates [batch_ind, x1, y1, x2, y2].
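18 | 19 | Example (an illustrative sketch; shapes only, values are random): 20 | >>> boxes = [jt.rand(2, 4), jt.rand(1, 4)] 21 | >>> rois = bbox2roi(boxes) 22 | >>> # rois.shape is [3, 5]; rois[:, 0] holds the image index (0, 0, 1)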
18 | """ 19 | rois_list = [] 20 | for img_id, bboxes in enumerate(bbox_list): 21 | img_inds = jt.full((bboxes.size(0), 1), img_id, dtype=bboxes.dtype) 22 | rois = jt.concat([img_inds, bboxes], dim=-1) 23 | rois_list.append(rois) 24 | rois = jt.concat(rois_list, 0) 25 | return rois 26 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/flake8 3 | rev: 5.0.4 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/PyCQA/isort 7 | rev: 5.13.0 8 | hooks: 9 | - id: isort 10 | - repo: https://github.com/pre-commit/mirrors-yapf 11 | rev: v0.32.0 12 | hooks: 13 | - id: yapf 14 | - repo: https://github.com/pre-commit/pre-commit-hooks 15 | rev: v4.3.0 16 | hooks: 17 | - id: trailing-whitespace 18 | - id: check-yaml 19 | - id: end-of-file-fixer 20 | - id: requirements-txt-fixer 21 | - id: double-quote-string-fixer 22 | - id: check-merge-conflict 23 | - id: fix-encoding-pragma 24 | args: ["--remove"] 25 | - id: mixed-line-ending 26 | args: ["--fix=lf"] 27 | - repo: https://github.com/codespell-project/codespell 28 | rev: v2.2.1 29 | hooks: 30 | - id: codespell 31 | - repo: https://github.com/PyCQA/docformatter 32 | rev: v1.3.1 33 | hooks: 34 | - id: docformatter 35 | args: ["--in-place", "--wrap-descriptions", "79"] 36 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/prior_generators/utils.py: -------------------------------------------------------------------------------- 1 | def anchor_inside_flags(flat_anchors, 2 | valid_flags, 3 | img_shape, 4 | allowed_border=0): 5 | """Check whether the anchors are inside the border. 6 | 7 | Args: 8 | flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4). 9 | valid_flags (torch.Tensor): An existing valid flags of anchors. 10 | img_shape (tuple(int)): Shape of current image. 11 | allowed_border (int, optional): The border to allow the valid anchor. 12 | Defaults to 0. 13 | 14 | Returns: 15 | torch.Tensor: Flags indicating whether the anchors are inside a \ 16 | valid range. 17 | """ 18 | img_h, img_w = img_shape[:2] 19 | if allowed_border >= 0: 20 | inside_flags = valid_flags & \ 21 | (flat_anchors[:, 0] >= -allowed_border) & \ 22 | (flat_anchors[:, 1] >= -allowed_border) & \ 23 | (flat_anchors[:, 2] < img_w + allowed_border) & \ 24 | (flat_anchors[:, 3] < img_h + allowed_border) 25 | else: 26 | inside_flags = valid_flags 27 | return inside_flags 28 | -------------------------------------------------------------------------------- /jittordet/utils/util_random.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. mmdet/utils/util_random.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | """Helpers for random number generators.""" 4 | import numpy as np 5 | 6 | 7 | def ensure_rng(rng=None): 8 | """Coerces input into a random number generator. 9 | 10 | If the input is None, then a global random state is returned. 11 | 12 | If the input is a numeric value, then that is used as a seed to construct a 13 | random state. Otherwise the input is returned as-is. 14 | 15 | Adapted from [1]_. 16 | 17 | Args: 18 | rng (int | numpy.random.RandomState | None): 19 | if None, then defaults to the global rng. 
Otherwise this can be an 20 | integer or a RandomState class 21 | Returns: 22 | (numpy.random.RandomState) : rng - 23 | a numpy random number generator 24 | 25 | References: 26 | .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501 27 | """ 28 | 29 | if rng is None: 30 | rng = np.random.mtrand._rand 31 | elif isinstance(rng, int): 32 | rng = np.random.RandomState(rng) 33 | else: 34 | rng = rng 35 | return rng 36 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | from jittordet.engine import Runner, load_cfg 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Test a detector') 9 | parser.add_argument('config', help='test config file path') 10 | parser.add_argument('checkpoint', help='checkpoint to load from') 11 | parser.add_argument('--work-dir', help='the dir to save logs') 12 | parser.add_argument( 13 | '--disable-cuda', 14 | action='store_true', 15 | help='disable cuda and use cpu to test net.') 16 | return parser.parse_args() 17 | 18 | 19 | def main(): 20 | args = parse_args() 21 | cfg = load_cfg(args.config) 22 | 23 | cfg.load_from = args.checkpoint 24 | if args.work_dir is not None: 25 | cfg.work_dir = args.work_dir 26 | elif cfg.get('work_dir', None) is None: 27 | # use config filename as default work_dir if cfg.work_dir is None 28 | cfg.work_dir = osp.join('./work_dirs', 29 | osp.splitext(osp.basename(args.config))[0]) 30 | # set disable cuda 31 | cfg.disable_cuda = args.disable_cuda 32 | 33 | runner = Runner.from_cfg(cfg) 34 | runner.test() 35 | 36 | 37 | if __name__ == '__main__': 38 | main() 39 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | 4 | from jittordet.engine import Runner, load_cfg 5 | 6 | 7 | def parse_args(): 8 | parser = argparse.ArgumentParser(description='Train a detector') 9 | parser.add_argument('config', help='train config file path') 10 | parser.add_argument('--work-dir', help='the dir to save logs and models') 11 | parser.add_argument('--seed', type=int, help='random seed for training') 12 | parser.add_argument( 13 | '--disable-cuda', 14 | action='store_true', 15 | help='disable cuda and use cpu to train net.') 16 | return parser.parse_args() 17 | 18 | 19 | def main(): 20 | args = parse_args() 21 | cfg = load_cfg(args.config) 22 | 23 | if args.work_dir is not None: 24 | cfg.work_dir = args.work_dir 25 | elif cfg.get('work_dir', None) is None: 26 | # use config filename as default work_dir if cfg.work_dir is None 27 | cfg.work_dir = osp.join('./work_dirs', 28 | osp.splitext(osp.basename(args.config))[0]) 29 | # set seed 30 | if args.seed is not None: 31 | cfg.seed = args.seed 32 | # set disable cuda 33 | cfg.disable_cuda = args.disable_cuda 34 | 35 | runner = Runner.from_cfg(cfg) 36 | runner.train() 37 | 38 | 39 | if __name__ == '__main__': 40 | main() 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | def readme(): 5 | with open('README.md', encoding='utf-8') as f: 6 | content = f.read() 7 | return content 8 | 9 | 10 | def
get_version(): 11 | with open('jittordet/version.py', 'r') as f: 12 | exec(compile(f.read(), 'jittordet/version.py', 'exec')) 13 | return locals()['__version__'] 14 | 15 | 16 | def get_requires(): 17 | with open('requirements.txt', encoding='utf-8') as f: 18 | requires = [line.strip() for line in f if line.strip()] 19 | return requires 20 | 21 | 22 | if __name__ == '__main__': 23 | setup( 24 | name='jittordet', 25 | version=get_version(), 26 | description='An object detection codebase based on Jittor.', 27 | long_description=readme(), 28 | long_description_content_type='text/markdown', 29 | packages=find_packages(exclude=('configs', 'tools')), 30 | classifiers=[ 31 | 'Development Status :: 3 - Alpha', 32 | 'License :: OSI Approved :: Apache Software License', 33 | 'Operating System :: OS Independent', 34 | 'Programming Language :: Python :: 3', 35 | 'Programming Language :: Python :: 3.8', 36 | 'Programming Language :: Python :: 3.9', 37 | ], 38 | license='Apache License 2.0', 39 | install_requires=get_requires()) 40 | -------------------------------------------------------------------------------- /jittordet/engine/config/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | 3 | 4 | def iter_leaves(obj): 5 | """A generator to visit all leaves of obj.""" 6 | if isinstance(obj, Mapping): 7 | for key, value in obj.items(): 8 | for k, v in iter_leaves(value): 9 | k.insert(0, key) 10 | yield (k, v) 11 | elif isinstance(obj, (list, tuple)): 12 | for i, value in enumerate(obj): 13 | for k, v in iter_leaves(value): 14 | k.insert(0, i) 15 | yield (k, v) 16 | else: 17 | yield [], obj 18 | 19 | 20 | def set_leaf(obj, keys, value): 21 | if isinstance(keys, str): 22 | keys = keys.split('.') 23 | for k in keys[:-1]: 24 | obj = obj[k] 25 | obj[keys[-1]] = value 26 | 27 | 28 | def delete_node(obj, keys): 29 | if isinstance(keys, (tuple, list)): 30 | for key in keys[:-1]: 31 | obj = obj[key] 32 | del obj[keys[-1]] 33 | else: 34 | assert isinstance(keys, str), 'only support search str node' 35 | if isinstance(obj, Mapping): 36 | if keys in obj: 37 | del obj[keys] 38 | for value in obj.values(): 39 | delete_node(value, keys) 40 | elif isinstance(obj, (tuple, list)): 41 | for value in obj: 42 | delete_node(value, keys) 43 | -------------------------------------------------------------------------------- /jittordet/engine/loops/val_loop.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | from ..register import LOOPS 4 | from .base_loop import BaseLoop 5 | 6 | 7 | @LOOPS.register_module() 8 | class ValLoop(BaseLoop): 9 | 10 | def run(self): 11 | """Launch validation.""" 12 | self.runner.call_hook('before_val') 13 | self.runner.call_hook('before_val_epoch') 14 | self.runner.model.eval() 15 | self._iter = 0 16 | 17 | for idx, data_batch in enumerate(self.runner.val_dataset): 18 | self.run_iter(idx, data_batch) 19 | 20 | # compute metrics 21 | metrics = self.runner.val_evaluator.evaluate(self.runner.val_dataset, 22 | self.runner.logger) 23 | 24 | self.runner.call_hook('after_val_epoch', metrics=metrics) 25 | self.runner.call_hook('after_val') 26 | 27 | @jt.no_grad() 28 | def run_iter(self, idx, data_batch): 29 | self.runner.call_hook( 30 | 'before_val_iter', batch_idx=idx, data_batch=data_batch) 31 | 32 | outputs = self.runner.model(data_batch, phase='predict') 33 | self.runner.val_evaluator.process(self.runner.val_dataset, outputs) 34 | 35 | self.runner.call_hook( 36 | 'after_val_iter', 37 |
batch_idx=idx, 38 | data_batch=data_batch, 39 | outputs=outputs) 40 | self._iter += 1 41 | -------------------------------------------------------------------------------- /jittordet/engine/loops/test_loop.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | from ..register import LOOPS 4 | from .base_loop import BaseLoop 5 | 6 | 7 | @LOOPS.register_module() 8 | class TestLoop(BaseLoop): 9 | 10 | def run(self): 11 | """Launch test.""" 12 | self.runner.call_hook('before_test') 13 | self.runner.call_hook('before_test_epoch') 14 | self.runner.model.eval() 15 | self._iter = 0 16 | 17 | for idx, data_batch in enumerate(self.runner.test_dataset): 18 | self.run_iter(idx, data_batch) 19 | 20 | # compute metrics 21 | metrics = self.runner.test_evaluator.evaluate(self.runner.test_dataset, 22 | self.runner.logger) 23 | 24 | self.runner.call_hook('after_test_epoch', metrics=metrics) 25 | self.runner.call_hook('after_test') 26 | 27 | @jt.no_grad() 28 | def run_iter(self, idx, data_batch): 29 | self.runner.call_hook( 30 | 'before_test_iter', batch_idx=idx, data_batch=data_batch) 31 | 32 | outputs = self.runner.model(data_batch, phase='predict') 33 | self.runner.test_evaluator.process(self.runner.test_dataset, outputs) 34 | 35 | self.runner.call_hook( 36 | 'after_test_iter', 37 | batch_idx=idx, 38 | data_batch=data_batch, 39 | outputs=outputs) 40 | self._iter += 1 41 | -------------------------------------------------------------------------------- /jittordet/models/losses/pkd_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import jittor.nn as nn 3 | 4 | from jittordet.engine import MODELS 5 | from .utils import weighted_loss 6 | 7 | 8 | def norm(feat): 9 | assert len(feat.shape) == 4 10 | N, C, H, W = feat.shape 11 | feat = feat.permute(1, 0, 2, 3).reshape(C, -1) 12 | mean = feat.mean(dim=-1, keepdims=True) 13 | std = feat.std() 14 | feat = (feat - mean) / (std + 1e-6) 15 | return feat.reshape(C, N, H, W).permute(1, 0, 2, 3) 16 | 17 | 18 | @weighted_loss 19 | def pkd_loss(pred, target): 20 | pred = norm(pred) 21 | target = norm(target) 22 | return (pred - target).sqr() / 2 23 | 24 | 25 | @MODELS.register_module() 26 | class PKDLoss(nn.Module): 27 | 28 | def __init__(self, reduction='mean', loss_weight=1.0): 29 | super(PKDLoss, self).__init__() 30 | self.reduction = reduction 31 | self.loss_weight = loss_weight 32 | 33 | def execute(self, 34 | pred, 35 | target, 36 | weight=None, 37 | avg_factor=None, 38 | reduction_override=None): 39 | assert reduction_override in (None, 'none', 'mean', 'sum') 40 | reduction = ( 41 | reduction_override if reduction_override else self.reduction) 42 | loss = self.loss_weight * pkd_loss( 43 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 44 | return loss 45 | -------------------------------------------------------------------------------- /jittordet/utils/types.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | 3 | 4 | def is_seq_of(seq, expected_type, seq_type=None): 5 | """Check whether it is a sequence of some type. 6 | 7 | Args: 8 | seq (Sequence): The sequence to be checked. 9 | expected_type (type or tuple): Expected type of sequence items. 10 | seq_type (type, optional): Expected sequence type. Defaults to None. 11 | 12 | Returns: 13 | bool: Return True if ``seq`` is valid else False.
14 | 15 | Examples: 16 | >>> from jittordet.utils import is_seq_of 17 | >>> seq = ['a', 'b', 'c'] 18 | >>> is_seq_of(seq, str) 19 | True 20 | >>> is_seq_of(seq, int) 21 | False 22 | """ 23 | if seq_type is None: 24 | exp_seq_type = abc.Sequence 25 | else: 26 | assert isinstance(seq_type, type) 27 | exp_seq_type = seq_type 28 | if not isinstance(seq, exp_seq_type): 29 | return False 30 | for item in seq: 31 | if not isinstance(item, expected_type): 32 | return False 33 | return True 34 | 35 | 36 | def is_list_of(seq, expected_type): 37 | """Check whether it is a list of some type. 38 | 39 | A partial method of :func:`is_seq_of`. 40 | """ 41 | return is_seq_of(seq, expected_type, seq_type=list) 42 | 43 | 44 | def is_tuple_of(seq, expected_type): 45 | """Check whether it is a tuple of some type. 46 | 47 | A partial method of :func:`is_seq_of`. 48 | """ 49 | return is_seq_of(seq, expected_type, seq_type=tuple) 50 | -------------------------------------------------------------------------------- /configs/_dataset_/coco_detection.yml: -------------------------------------------------------------------------------- 1 | data_root: $DATA_ROOT:data/coco/ 2 | num_gpus: $NUM_GPUS:8 3 | 4 | train_dataset: 5 | type: CocoDataset 6 | batch_size: <2 * num_gpus> 7 | num_workers: <1 * num_gpus> 8 | data_root: <data_root> 9 | data_path: 10 | ann_file: annotations/instances_train2017.json 11 | img_path: train2017 12 | filter_cfg: 13 | filter_empty_gt: true 14 | min_size: 32 15 | batch_sampler: 16 | type: AspectRatioBatchSampler 17 | transforms: 18 | - type: 'LoadImageFromFile' 19 | - type: 'LoadAnnotations' 20 | with_bbox: true 21 | - type: 'Resize' 22 | scale: (1333, 800) 23 | keep_ratio: true 24 | - type: 'RandomFlip' 25 | prob: 0.5 26 | - type: 'PackDetInputs' 27 | 28 | val_dataset: &val_dataset 29 | type: CocoDataset 30 | batch_size: <1 * num_gpus> 31 | num_workers: <1 * num_gpus> 32 | data_root: <data_root> 33 | data_path: 34 | ann_file: annotations/instances_val2017.json 35 | img_path: val2017 36 | test_mode: true 37 | batch_sampler: 38 | type: PadBatchSampler 39 | transforms: 40 | - type: 'LoadImageFromFile' 41 | - type: 'LoadAnnotations' 42 | with_bbox: true 43 | - type: 'Resize' 44 | scale: (1333, 800) 45 | keep_ratio: true 46 | - type: 'PackDetInputs' 47 | meta_keys: [img_id, img_path, ori_shape, img_shape, 'scale_factor', 'sample_idx'] 48 | 49 | test_dataset: *val_dataset 50 | 51 | val_evaluator: &val_evaluator 52 | type: CocoEvaluator 53 | ann_file: 54 | metric: bbox 55 | format_only: false 56 | test_evaluator: *val_evaluator 57 | -------------------------------------------------------------------------------- /jittordet/engine/logger.py: -------------------------------------------------------------------------------- 1 | """Modified from mmcv.utils.logging.""" 2 | import logging 3 | 4 | import jittor as jt 5 | 6 | __all__ = ['get_logger', 'print_log'] 7 | initialized_loggers = set() 8 | 9 | 10 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 11 | logger = logging.getLogger(name) 12 | if name in initialized_loggers: 13 | return logger 14 | 15 | stream_handler = logging.StreamHandler() 16 | handlers = [stream_handler] 17 | 18 | if jt.rank == 0 and log_file is not None: 19 | file_handler = logging.FileHandler(log_file, file_mode) 20 | handlers.append(file_handler) 21 | 22 | formatter = logging.Formatter( 23 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 24 | for handler in handlers: 25 | handler.setFormatter(formatter) 26 | handler.setLevel(log_level) 27 | logger.addHandler(handler) 28
| 29 | if jt.rank == 0: 30 | logger.setLevel(log_level) 31 | else: 32 | logger.setLevel(logging.ERROR) 33 | 34 | initialized_loggers.add(name) 35 | 36 | return logger 37 | 38 | 39 | def print_log(msg, logger=None, level=logging.INFO): 40 | if logger is None: 41 | print(msg) 42 | elif isinstance(logger, logging.Logger): 43 | logger.log(level, msg) 44 | elif logger == 'silent': 45 | pass 46 | elif isinstance(logger, str): 47 | _logger = get_logger(logger) 48 | _logger.log(level, msg) 49 | else: 50 | raise TypeError( 51 | 'logger should be either a logging.Logger object, str, ' 52 | f'"silent" or None, but got {type(logger)}') 53 | -------------------------------------------------------------------------------- /jittordet/engine/config/readers.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os.path as osp 4 | import sys 5 | import types 6 | from importlib import import_module 7 | 8 | import yaml 9 | 10 | 11 | def yaml_reader(filepath): 12 | with open(filepath, 'r') as f: 13 | cfg = yaml.load(f, Loader=yaml.Loader) 14 | return cfg 15 | 16 | 17 | def json_reader(filepath): 18 | with open(filepath, 'r') as f: 19 | cfg = json.load(f) 20 | return cfg 21 | 22 | 23 | def python_reader(filepath): 24 | """Read a python-type config. 25 | 26 | Refer to mmcv.utils.config. 27 | """ 28 | # validate python syntax 29 | with open(filepath, 'r', encoding='utf-8') as f: 30 | # Setting encoding explicitly to resolve coding issue on windows 31 | content = f.read() 32 | try: 33 | ast.parse(content) 34 | except SyntaxError as e: 35 | raise SyntaxError('There are syntax errors in config ' 36 | f'file {filepath}: {e}') 37 | 38 | filepath = osp.splitext(filepath)[0] 39 | dir_name = osp.dirname(filepath) 40 | module_name = osp.basename(filepath) 41 | sys.path.insert(0, dir_name) 42 | mod = import_module(module_name) 43 | sys.path.pop(0) 44 | cfg = { 45 | name: value 46 | for name, value in mod.__dict__.items() if not name.startswith('__') 47 | and not isinstance(value, types.ModuleType) 48 | and not isinstance(value, types.FunctionType) 49 | } 50 | # delete imported module 51 | del sys.modules[module_name] 52 | return cfg 53 | 54 | 55 | cfg_readers = { 56 | '.yml': yaml_reader, 57 | '.yaml': yaml_reader, 58 | '.json': json_reader, 59 | '.py': python_reader 60 | } 61 | -------------------------------------------------------------------------------- /jittordet/models/frameworks/kd_stage.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. mmdet/models/detectors/single_stage.py 2 | # Copyright (c) OpenMMLab. All rights reserved.
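3 | # Illustrative kd_cfg sketch (the keys match those consumed in __init__ 4 | # below; the loss types are example picks from jittordet.models.losses and 5 | # the head index value is arbitrary): 6 | #   kd_cfg = dict(loss_cls_kd=dict(type='KDQualityFocalLoss'), 7 | #                 loss_reg_kd=dict(type='PKDLoss'), 8 | #                 loss_feat_kd=dict(type='PKDLoss'),  # optional 9 | #                 reused_teacher_head_idx=3)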
3 | from pathlib import Path 4 | from typing import Optional, Union 5 | 6 | from jittordet.engine import MODELS, ConfigType, OptConfigType, load_cfg 7 | from .single_stage import SingleStageFramework 8 | 9 | 10 | @MODELS.register_module() 11 | class KDSingleStageFramework(SingleStageFramework): 12 | 13 | def __init__(self, 14 | *args, 15 | teacher_config: Union[ConfigType, str, Path], 16 | teacher_ckpt: Optional[str] = None, 17 | eval_teacher: bool = True, 18 | kd_cfg: OptConfigType = None, 19 | **kwargs) -> None: 20 | super().__init__(*args, **kwargs) 21 | 22 | if isinstance(teacher_config, (str, Path)): 23 | teacher_config = load_cfg(teacher_config) 24 | self.teacher = MODELS.build(teacher_config['model']) 25 | if teacher_ckpt is not None: 26 | self.teacher.load(teacher_ckpt) 27 | if eval_teacher: 28 | self.freeze(self.teacher) 29 | self.loss_cls_kd = MODELS.build(kd_cfg['loss_cls_kd']) 30 | self.loss_reg_kd = MODELS.build(kd_cfg['loss_reg_kd']) 31 | self.with_feat_distill = False 32 | if kd_cfg.get('loss_feat_kd', None): 33 | self.loss_feat_kd = MODELS.build(kd_cfg['loss_feat_kd']) 34 | self.with_feat_distill = True 35 | self.reused_teacher_head_idx = kd_cfg['reused_teacher_head_idx'] 36 | 37 | @staticmethod 38 | def freeze(model): 39 | """Freeze the model.""" 40 | model.eval() 41 | for param in model.parameters(): 42 | param.requires_grad = False 43 | -------------------------------------------------------------------------------- /configs/retinanet/_model_/retinanet_r50_fpn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: SingleStageFramework 3 | preprocessor: 4 | type: Preprocessor 5 | mean: [123.675, 116.28, 103.53] 6 | std: [58.395, 57.12, 57.375] 7 | bgr_to_rgb: true 8 | pad_size_divisor: 32 9 | backbone: 10 | type: ResNet 11 | depth: 50 12 | frozen_stages: 1 13 | norm_eval: true 14 | return_stages: ['layer1', 'layer2', 'layer3', 'layer4'] 15 | pretrained: 'jittorhub://resnet50.pkl' 16 | neck: 17 | type: FPN 18 | in_channels: [256, 512, 1024, 2048] 19 | out_channels: 256 20 | start_level: 1 21 | add_extra_convs: on_input 22 | num_outs: 5 23 | bbox_head: 24 | type: RetinaHead 25 | num_classes: 80 26 | in_channels: 256 27 | stacked_convs: 4 28 | feat_channels: 256 29 | anchor_generator: 30 | type: 'AnchorGenerator' 31 | octave_base_scale: 4 32 | scales_per_octave: 3 33 | ratios: [0.5, 1.0, 2.0] 34 | strides: [8, 16, 32, 64, 128] 35 | bbox_coder: 36 | type: DeltaXYWHBBoxCoder 37 | target_means: [.0, .0, .0, .0] 38 | target_stds: [1.0, 1.0, 1.0, 1.0] 39 | loss_cls: 40 | type: FocalLoss 41 | use_sigmoid: true 42 | gamma: 2.0 43 | alpha: 0.25 44 | loss_weight: 1.0 45 | loss_bbox: 46 | type: L1Loss 47 | loss_weight: 1.0 48 | train_cfg: 49 | assigner: 50 | type: 'MaxIoUAssigner' 51 | pos_iou_thr: 0.5 52 | neg_iou_thr: 0.4 53 | min_pos_iou: 0 54 | ignore_iof_thr: -1 55 | sampler: 56 | type: PseudoSampler 57 | allowed_border: -1 58 | pos_weight: -1 59 | test_cfg: 60 | nms_pre: 1000 61 | min_bbox_size: 0 62 | score_thr: 0.05 63 | nms: 64 | type: nms 65 | thresh: 0.5 66 | max_per_img: 100 67 | -------------------------------------------------------------------------------- /configs/rpn/_model_/rpn_r50_fpn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: RPNFramework 3 | preprocessor: 4 | type: Preprocessor 5 | mean: [123.675, 116.28, 103.53] 6 | std: [58.395, 57.12, 57.375] 7 | bgr_to_rgb: true 8 | pad_size_divisor: 32 9 | backbone: 10 | type: ResNet 11 | depth: 50 12 | 
frozen_stages: 1 13 | norm_eval: true 14 | return_stages: ['layer1', 'layer2', 'layer3', 'layer4'] 15 | pretrained: 'jittorhub://resnet50.pkl' 16 | neck: 17 | type: FPN 18 | in_channels: [256, 512, 1024, 2048] 19 | out_channels: 256 20 | num_outs: 5 21 | rpn_head: 22 | type: RPNHead 23 | num_classes: 1 24 | in_channels: 256 25 | feat_channels: 256 26 | anchor_generator: 27 | type: AnchorGenerator 28 | scales: [8] 29 | ratios: [0.5, 1.0, 2.0] 30 | strides: [4, 8, 16, 32, 64] 31 | bbox_coder: 32 | type: DeltaXYWHBBoxCoder 33 | target_means: [.0, .0, .0, .0] 34 | target_stds: [1.0, 1.0, 1.0, 1.0] 35 | loss_cls: 36 | type: CrossEntropyLoss 37 | use_sigmoid: true 38 | loss_weight: 1.0 39 | loss_bbox: 40 | type: L1Loss 41 | loss_weight: 1.0 42 | train_cfg: 43 | rpn: 44 | assigner: 45 | type: MaxIoUAssigner 46 | pos_iou_thr: 0.7 47 | neg_iou_thr: 0.3 48 | min_pos_iou: 0.3 49 | match_low_quality: true 50 | ignore_iof_thr: -1 51 | sampler: 52 | type: RandomSampler 53 | num: 256 54 | pos_fraction: 0.5 55 | neg_pos_ub: -1 56 | add_gt_as_proposals: false 57 | allowed_border: -1 58 | pos_weight: -1 59 | test_cfg: 60 | rpn: 61 | nms_pre: 1000 62 | max_per_img: 1000 63 | nms: 64 | type: nms 65 | thresh: 0.7 66 | min_bbox_size: 0 67 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/assigners/iou2d_calculator.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/task_modules/assigners/iou2d_calculator.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | from jittordet.engine import TASK_UTILS 5 | from jittordet.utils import bbox_overlaps 6 | 7 | 8 | @TASK_UTILS.register_module() 9 | class BboxOverlaps2D: 10 | """2D Overlaps (e.g. IoUs, GIoUs) Calculator.""" 11 | 12 | def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False): 13 | """Calculate IoU between 2D bboxes. 14 | 15 | Args: 16 | bboxes1 (Tensor or :obj:`BaseBoxes`): bboxes have shape (m, 4) 17 | in <x1, y1, x2, y2> format, or shape (m, 5) in 18 | <x1, y1, x2, y2, score> format. 19 | bboxes2 (Tensor or :obj:`BaseBoxes`): bboxes have shape (n, 4) 20 | in <x1, y1, x2, y2> format, shape (n, 5) in 21 | <x1, y1, x2, y2, score> format, or be empty. If ``is_aligned`` 22 | is ``True``, then m and n must be equal. 23 | mode (str): "iou" (intersection over union), "iof" (intersection 24 | over foreground), or "giou" (generalized intersection over 25 | union). 26 | is_aligned (bool, optional): If True, then m and n must be equal. 27 | Default False. 28 | 29 | Returns: 30 | Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) 31 | """ 32 | assert bboxes1.size(-1) in [0, 4, 5] 33 | assert bboxes2.size(-1) in [0, 4, 5] 34 | if bboxes2.size(-1) == 5: 35 | bboxes2 = bboxes2[..., :4] 36 | if bboxes1.size(-1) == 5: 37 | bboxes1 = bboxes1[..., :4] 38 | 39 | return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned) 40 | -------------------------------------------------------------------------------- /jittordet/models/backbones/MSFA.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
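# Usage sketch (illustrative note, not part of the original file): MSFA wraps
# an existing backbone and optionally prepends 81 wavelet-scattering channels
# computed from the input, so a config only needs to nest the real backbone:
#
#   backbone = dict(type='MSFA', use_sar=True, use_wavelet=True,
#                   input_size=(800, 800),
#                   backbone=dict(type='ResNet', depth=50))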
7 | 8 | import jittor as jt 9 | from jittor import nn 10 | from jittordet.engine import MODELS 11 | from kymatio.torch import Scattering2D 12 | 13 | @MODELS.register_module() 14 | class MSFA(nn.Module): 15 | def __init__(self, backbone, use_sar=True, use_wavelet=False, input_size=(800, 800)): 16 | super().__init__() 17 | self.use_sar = use_sar 18 | self.use_wavelet = use_wavelet 19 | self.input_size = input_size 20 | if use_sar and not use_wavelet: 21 | self.in_channels = 3 22 | elif use_sar and use_wavelet: 23 | self.in_channels = 1 24 | elif not use_sar: 25 | self.in_channels = 0 26 | if use_wavelet: 27 | self.in_channels += 81 28 | self.wavelet_trans = Scattering2D(J=2, shape=self.input_size) 29 | backbone['in_channels'] = self.in_channels 30 | self.backbone = MODELS.build(backbone) 31 | 32 | def execute(self, x): 33 | xs = [] 34 | if self.use_sar and not self.use_wavelet: 35 | return self.backbone(x) 36 | x_ = x.mean(1, keepdim=True) 37 | with jt.no_grad(): 38 | if self.use_sar and self.use_wavelet: 39 | xs.append(x_) 40 | if self.use_wavelet: 41 | out = nn.functional.interpolate(self.wavelet_trans(x_).squeeze(1), self.input_size, mode='bilinear') 42 | xs.append(out) 43 | x = jt.cat(xs, 1) 44 | x = self.backbone(x) 45 | return x 46 | 47 | def init_weights(self): 48 | # jittor's nn.Module does not define init_weights; forward the call only if a parent class provides it. 49 | if hasattr(super(), 'init_weights'): 50 | super().init_weights() 51 | -------------------------------------------------------------------------------- /jittordet/engine/hooks/checkpoint_hook.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from ..register import HOOKS 4 | from .base_hook import BaseHook 5 | 6 | 7 | @HOOKS.register_module() 8 | class CheckpointHook(BaseHook): 9 | 10 | def __init__(self, interval=1, by_iter=False): 11 | assert isinstance(interval, (int, list)) 12 | self.interval = interval 13 | self.by_iter = by_iter 14 | 15 | def after_train_epoch(self, runner): 16 | 17 | if self.by_iter: 18 | return None 19 | cur_epoch = runner.train_loop.cur_epoch 20 | 21 | if isinstance(self.interval, int): 22 | save_ckpt = self.every_n_interval(cur_epoch, self.interval) 23 | else: 24 | save_ckpt = cur_epoch in self.interval 25 | 26 | if save_ckpt: 27 | ckpt_filepath = osp.join(runner.log_dir, 28 | f'epoch_{cur_epoch + 1}.pkl') 29 | runner.save_checkpoint(ckpt_filepath) 30 | runner.logger.info(f'save checkpoint to {ckpt_filepath}') 31 | 32 | def after_train_iter(self, 33 | runner, 34 | batch_idx, 35 | data_batch=None, 36 | outputs=None): 37 | if not self.by_iter: 38 | return None 39 | cur_iter = runner.train_loop.cur_iter 40 | 41 | if isinstance(self.interval, int): 42 | save_ckpt = self.every_n_interval(cur_iter, self.interval) 43 | else: 44 | save_ckpt = cur_iter in self.interval 45 | 46 | if save_ckpt: 47 | ckpt_filepath = osp.join(runner.log_dir, 48 | f'iter_{cur_iter + 1}.pkl') 49 | runner.save_checkpoint(ckpt_filepath) 50 | runner.logger.info(f'save checkpoint to {ckpt_filepath}') 51 | -------------------------------------------------------------------------------- /jittordet/datasets/samplers/pad_batch_sampler.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | 3 | import jittor as jt 4 | import numpy as np 5 | 6 | from jittordet.engine import BATCH_SAMPLERS 7 | from .base_batch_sampler import BaseBatchSampler 8 | 9 | 10 | @BATCH_SAMPLERS.register_module() 11 | class PadBatchSampler(BaseBatchSampler): 12 | 13 | def __init__(self, dataset, shuffle=False, drop_last=False): 14 | super().__init__(dataset=dataset) 15 | self.shuffle = shuffle
16 | self.drop_last = drop_last 17 | 18 | def get_index_list(self, rng=None): 19 | if rng is None: 20 | rng = np.random.default_rng() 21 | 22 | index = rng.permutation(self.num_data_list) if self.shuffle \ 23 | else np.arange(self.num_data_list) 24 | 25 | mod_size = self.num_data_list % self.total_bs 26 | if mod_size != 0: 27 | if self.drop_last: 28 | index = index[:-mod_size] 29 | else: 30 | padded_size = int( 31 | ceil(self.num_data_list / self.total_bs) * self.total_bs) 32 | # repeat index in case data_list is shorter than the batch size 33 | repeat_num = int((self.total_bs - 0.5) // self.num_data_list + 34 | 1) 35 | repeat_num = max(repeat_num, 2) 36 | index = np.concatenate([index] * repeat_num) 37 | index = index[:padded_size] 38 | 39 | if jt.in_mpi: 40 | rank, world_size = jt.rank, jt.world_size 41 | real_bs = int(self.total_bs // world_size) 42 | assert real_bs * world_size == self.total_bs 43 | index = index.reshape(-1, self.total_bs) 44 | index = index[:, rank * real_bs:(rank + 1) * real_bs] 45 | index = index.flatten() 46 | 47 | return index 48 | -------------------------------------------------------------------------------- /configs/gfl/_model_/gfl_r50_fpn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: SingleStageFramework 3 | preprocessor: 4 | type: Preprocessor 5 | mean: [123.675, 116.28, 103.53] 6 | std: [58.395, 57.12, 57.375] 7 | bgr_to_rgb: true 8 | pad_size_divisor: 32 9 | backbone: 10 | type: ResNet 11 | depth: 50 12 | frozen_stages: 1 13 | norm_eval: true 14 | return_stages: ['layer1', 'layer2', 'layer3', 'layer4'] 15 | pretrained: 'jittorhub://resnet50.pkl' 16 | neck: 17 | type: FPN 18 | in_channels: [256, 512, 1024, 2048] 19 | out_channels: 256 20 | start_level: 1 21 | add_extra_convs: on_input 22 | num_outs: 5 23 | bbox_head: 24 | type: GFLHead 25 | num_classes: 80 26 | in_channels: 256 27 | stacked_convs: 4 28 | feat_channels: 256 29 | anchor_generator: 30 | type: AnchorGenerator 31 | ratios: [1.0] 32 | octave_base_scale: 8 33 | scales_per_octave: 1 34 | strides: [8, 16, 32, 64, 128] 35 | loss_cls: 36 | type: QualityFocalLoss 37 | use_sigmoid: true 38 | beta: 2.0 39 | loss_weight: 1.0 40 | loss_dfl: 41 | type: DistributionFocalLoss 42 | loss_weight: 0.25 43 | reg_max: 16 44 | loss_bbox: 45 | type: GIoULoss 46 | loss_weight: 2.0 47 | train_cfg: 48 | assigner: 49 | type: ATSSAssigner 50 | topk: 9 51 | allowed_border: -1 52 | pos_weight: -1 53 | debug: false 54 | test_cfg: 55 | nms_pre: 1000 56 | min_bbox_size: 0 57 | score_thr: 0.05 58 | nms: 59 | type: nms 60 | iou_threshold: 0.6 61 | max_per_img: 100 62 | -------------------------------------------------------------------------------- /jittordet/structures/det_data_sample.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab mmdet/structures/det_data_sample.py 2 | # Copyright (c) OpenMMLab. All rights reserved.
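# Usage sketch (illustrative note, not part of the original file):
#
#   sample = DetDataSample()
#   sample.gt_instances = InstanceData(bboxes=gt_bboxes, labels=gt_labels)
#   sample.pred_instances = InstanceData(bboxes=boxes, scores=scores)
#   del sample.gt_instances  # the deleters below remove a field entirely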
3 | from typing import List, Optional 4 | 5 | from .base_data_element import BaseDataElement 6 | from .instance_data import InstanceData 7 | 8 | 9 | class DetDataSample(BaseDataElement): 10 | 11 | @property 12 | def proposals(self) -> InstanceData: 13 | return self._proposals 14 | 15 | @proposals.setter 16 | def proposals(self, value: InstanceData): 17 | self.set_field(value, '_proposals', dtype=InstanceData) 18 | 19 | @proposals.deleter 20 | def proposals(self): 21 | del self._proposals 22 | 23 | @property 24 | def gt_instances(self) -> InstanceData: 25 | return self._gt_instances 26 | 27 | @gt_instances.setter 28 | def gt_instances(self, value: InstanceData): 29 | self.set_field(value, '_gt_instances', dtype=InstanceData) 30 | 31 | @gt_instances.deleter 32 | def gt_instances(self): 33 | del self._gt_instances 34 | 35 | @property 36 | def pred_instances(self) -> InstanceData: 37 | return self._pred_instances 38 | 39 | @pred_instances.setter 40 | def pred_instances(self, value: InstanceData): 41 | self.set_field(value, '_pred_instances', dtype=InstanceData) 42 | 43 | @pred_instances.deleter 44 | def pred_instances(self): 45 | del self._pred_instances 46 | 47 | @property 48 | def ignored_instances(self) -> InstanceData: 49 | return self._ignored_instances 50 | 51 | @ignored_instances.setter 52 | def ignored_instances(self, value: InstanceData): 53 | self.set_field(value, '_ignored_instances', dtype=InstanceData) 54 | 55 | @ignored_instances.deleter 56 | def ignored_instances(self): 57 | del self._ignored_instances 58 | 59 | 60 | SampleList = List[DetDataSample] 61 | OptSampleList = Optional[SampleList] 62 | -------------------------------------------------------------------------------- /jittordet/engine/loops/train_loop.py: -------------------------------------------------------------------------------- 1 | from ..register import LOOPS 2 | from .base_loop import BaseLoop 3 | 4 | 5 | @LOOPS.register_module() 6 | class EpochTrainLoop(BaseLoop): 7 | 8 | def __init__(self, runner, max_epoch, val_interval=1): 9 | super().__init__(runner=runner) 10 | self.val_interval = val_interval 11 | self._max_epoch = max_epoch 12 | 13 | def run(self): 14 | self.runner.call_hook('before_train') 15 | 16 | while self._epoch < self._max_epoch: 17 | self.run_epoch() 18 | 19 | if self._epoch % self.val_interval == 0: 20 | if self.runner.val_loop is not None: 21 | self.runner.val_loop.run() 22 | 23 | self.runner.call_hook('after_train') 24 | 25 | def run_epoch(self): 26 | """Iterate over one epoch.""" 27 | self.runner.call_hook('before_train_epoch') 28 | 29 | self.runner.model.train() 30 | for idx, data_batch in enumerate(self.runner.train_dataset): 31 | self.run_iter(idx, data_batch) 32 | 33 | for _scheduler in self.runner.scheduler: 34 | if not getattr(_scheduler, 'by_iter', False): 35 | _scheduler.step() 36 | 37 | self.runner.call_hook('after_train_epoch') 38 | self._epoch += 1 39 | 40 | def run_iter(self, idx, data_batch): 41 | """Iterate over one mini-batch.""" 42 | self.runner.call_hook( 43 | 'before_train_iter', batch_idx=idx, data_batch=data_batch) 44 | 45 | loss, loss_vars = self.runner.model(data_batch, phase='loss') 46 | self.runner.optimizer.step(loss) 47 | for _scheduler in self.runner.scheduler: 48 | # for warmup scheduler 49 | if getattr(_scheduler, 'by_iter', False): 50 | _scheduler.step() 51 | 52 | self.runner.call_hook( 53 | 'after_train_iter', 54 | batch_idx=idx, 55 | data_batch=data_batch, 56 | outputs=loss_vars) 57 | self._iter += 1 58 | 
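# Illustrative sketch (added note, not part of the repository): EpochTrainLoop
# assumes every scheduler exposes a `by_iter` flag and a `step()` method.
# Per-iteration schedulers (e.g. WarmUpLR in engine/optim/schedulers.py) set
# `by_iter = True` and are stepped inside run_iter; all others are stepped
# once per epoch. A hypothetical per-epoch scheduler would look like:
#
#   class DecayEveryEpoch(BaseScheduler):
#       by_iter = False              # stepped at the end of each epoch
#
#       def __init__(self, optimizer, gamma=0.99):
#           self.optimizer = optimizer
#           self.gamma = gamma
#
#       def step(self):
#           self.optimizer.lr *= self.gamma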
-------------------------------------------------------------------------------- /configs/_dataset_/voc0712.yml: -------------------------------------------------------------------------------- 1 | data_root: $DATA_ROOT:data/VOCdevkit/ 2 | num_gpus: $NUM_GPUS:8 3 | 4 | train_transforms: &train_transforms 5 | - type: 'LoadImageFromFile' 6 | - type: 'LoadAnnotations' 7 | with_bbox: true 8 | - type: 'Resize' 9 | scale: (1000, 600) 10 | keep_ratio: true 11 | - type: 'RandomFlip' 12 | prob: 0.5 13 | - type: 'PackDetInputs' 14 | 15 | train_dataset: 16 | type: ConcatDataset 17 | batch_size: <2 * num_gpus> 18 | num_workers: <1 * num_gpus> 19 | datasets: 20 | - type: VocDataset 21 | data_root: 22 | data_path: 23 | ann_file: 'VOC2007/ImageSets/Main/trainval.txt' 24 | img_path: 'VOC2007/JPEGImages' 25 | xml_path: 'VOC2007/Annotations' 26 | filter_cfg: 27 | filter_empty_gt: true 28 | min_size: 32 29 | transforms: *train_transforms 30 | - type: VocDataset 31 | data_root: 32 | data_path: 33 | ann_file: 'VOC2012/ImageSets/Main/trainval.txt' 34 | img_path: 'VOC2012/JPEGImages' 35 | xml_path: 'VOC2012/Annotations' 36 | filter_cfg: 37 | filter_empty_gt: true 38 | min_size: 32 39 | transforms: *train_transforms 40 | batch_sampler: 41 | type: AspectRatioBatchSampler 42 | 43 | val_dataset: &val_dataset 44 | type: VocDataset 45 | batch_size: <1 * num_gpus> 46 | num_workers: <1 * num_gpus> 47 | data_root: 48 | data_path: 49 | ann_file: 'VOC2007/ImageSets/Main/test.txt' 50 | img_path: 'VOC2007/JPEGImages' 51 | xml_path: 'VOC2007/Annotations' 52 | test_mode: true 53 | batch_sampler: 54 | type: PadBatchSampler 55 | transforms: 56 | - type: 'LoadImageFromFile' 57 | - type: 'LoadAnnotations' 58 | with_bbox: true 59 | - type: 'Resize' 60 | scale: (1000, 600) 61 | keep_ratio: true 62 | - type: 'PackDetInputs' 63 | meta_keys: [img_id, img_path, ori_shape, img_shape, 'scale_factor', 'sample_idx'] 64 | 65 | test_dataset: *val_dataset 66 | 67 | val_evaluator: &val_evaluator 68 | type: VocEvaluator 69 | test_evaluator: *val_evaluator 70 | -------------------------------------------------------------------------------- /jittordet/utils/bbox_transforms.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | 4 | def distance2bbox(points, distance, max_shape=None): 5 | """Decode distance prediction to bounding box. 6 | 7 | Args: 8 | points (Tensor): Shape (B, N, 2) or (N, 2). 9 | distance (Tensor): Distance from the given point to 4 10 | boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4) 11 | max_shape (Union[Sequence[int], Tensor, Sequence[Sequence[int]]], 12 | optional): Maximum bounds for boxes, specifies 13 | (H, W, C) or (H, W). If priors shape is (B, N, 4), then 14 | the max_shape should be a Sequence[Sequence[int]] 15 | and the length of max_shape should also be B. 16 | 17 | Returns: 18 | Tensor: Boxes with shape (N, 4) or (B, N, 4) 19 | """ 20 | 21 | x1 = points[..., 0] - distance[..., 0] 22 | y1 = points[..., 1] - distance[..., 1] 23 | x2 = points[..., 0] + distance[..., 2] 24 | y2 = points[..., 1] + distance[..., 3] 25 | 26 | bboxes = jt.stack([x1, y1, x2, y2], -1) 27 | 28 | if max_shape is not None: 29 | bboxes[..., 0::2].clamp_(min=0, max=max_shape[1]) 30 | bboxes[..., 1::2].clamp_(min=0, max=max_shape[0]) 31 | 32 | return bboxes 33 | 34 | 35 | def bbox2distance(points, bbox, max_dis, eps: float = 0.1): 36 | """Encode bounding box to distances. 37 | 38 | Args: 39 | points (Tensor): Shape (n, 2) or (b, n, 2), [x, y].
40 | bbox (Tensor): Shape (n, 4) or (b, n, 4), "xyxy" format 41 | max_dis (float): Upper bound of the distance. 42 | eps (float): a small value to ensure target < max_dis instead of <=. 43 | 44 | Returns: 45 | Tensor: Encoded distances. 46 | """ 47 | left = points[..., 0] - bbox[..., 0] 48 | top = points[..., 1] - bbox[..., 1] 49 | right = bbox[..., 2] - points[..., 0] 50 | bottom = bbox[..., 3] - points[..., 1] 51 | if max_dis is not None: 52 | left = left.clamp(min_v=0, max_v=max_dis - eps) 53 | top = top.clamp(min_v=0, max_v=max_dis - eps) 54 | right = right.clamp(min_v=0, max_v=max_dis - eps) 55 | bottom = bottom.clamp(min_v=0, max_v=max_dis - eps) 56 | return jt.stack([left, top, right, bottom], -1) 57 | -------------------------------------------------------------------------------- /configs/_dataset_/sardet100k.yml: -------------------------------------------------------------------------------- 1 | data_root: $DATA_ROOT:data/SARDet-100K/ 2 | num_gpus: $NUM_GPUS:8 3 | 4 | train_dataset: 5 | type: Sardet100k 6 | batch_size: <2 * num_gpus> 7 | num_workers: <2 * num_gpus> 8 | data_root: 9 | data_path: 10 | ann_file: Annotations/train.json 11 | img_path: JPEGImages 12 | filter_cfg: 13 | filter_empty_gt: true 14 | min_size: 32 15 | batch_sampler: 16 | type: AspectRatioBatchSampler 17 | transforms: 18 | - type: 'LoadImageFromFile' 19 | - type: 'LoadAnnotations' 20 | with_bbox: true 21 | - type: 'Resize' 22 | scale: (800, 800) 23 | keep_ratio: false 24 | - type: 'RandomFlip' 25 | prob: 0.5 26 | - type: 'PackDetInputs' 27 | 28 | val_dataset: &val_dataset 29 | type: Sardet100k 30 | batch_size: <1 * num_gpus> 31 | num_workers: <1 * num_gpus> 32 | data_root: 33 | data_path: 34 | ann_file: Annotations/val.json 35 | img_path: JPEGImages 36 | test_mode: true 37 | batch_sampler: 38 | type: PadBatchSampler 39 | transforms: 40 | - type: 'LoadImageFromFile' 41 | - type: 'LoadAnnotations' 42 | with_bbox: true 43 | - type: 'Resize' 44 | scale: (800, 800) 45 | keep_ratio: true 46 | - type: 'PackDetInputs' 47 | meta_keys: [img_id, img_path, ori_shape, img_shape, 'scale_factor', 'sample_idx'] 48 | 49 | test_dataset: 50 | type: Sardet100k 51 | batch_size: <1 * num_gpus> 52 | num_workers: <1 * num_gpus> 53 | data_root: 54 | data_path: 55 | ann_file: Annotations/test.json 56 | img_path: JPEGImages 57 | test_mode: true 58 | batch_sampler: 59 | type: PadBatchSampler 60 | transforms: 61 | - type: 'LoadImageFromFile' 62 | - type: 'LoadAnnotations' 63 | with_bbox: true 64 | - type: 'Resize' 65 | scale: (800, 800) 66 | keep_ratio: true 67 | - type: 'PackDetInputs' 68 | meta_keys: [img_id, img_path, ori_shape, img_shape, 'scale_factor', 'sample_idx'] 69 | 70 | val_evaluator: &val_evaluator 71 | type: CocoEvaluator 72 | ann_file: 73 | metric: bbox 74 | format_only: false 75 | 76 | test_evaluator: 77 | type: CocoEvaluator 78 | ann_file: 79 | metric: bbox 80 | format_only: false 81 | -------------------------------------------------------------------------------- /jittordet/datasets/transforms/loading.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from jittordet.engine import TRANSFORMS 5 | 6 | 7 | @TRANSFORMS.register_module() 8 | class LoadImageFromFile: 9 | """Load an image from file.""" 10 | 11 | def __init__(self, to_float32=False): 12 | self.to_float32 = to_float32 13 | 14 | def __call__(self, results): 15 | img = cv2.imread(results['img_path']) 16 | if self.to_float32: 17 | img = img.astype(np.float32) 18 | 
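        # Note (added comment): cv2.imread returns images in BGR channel
        # order; the dataset configs set `bgr_to_rgb: true` on the
        # Preprocessor to convert to RGB later in the pipeline.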
19 | results['img'] = img 20 | results['img_shape'] = img.shape[:2] 21 | results['ori_shape'] = img.shape[:2] 22 | return results 23 | 24 | def __repr__(self): 25 | repr_str = (f'{self.__class__.__name__}(' 26 | f'to_float32={self.to_float32})') 27 | return repr_str 28 | 29 | 30 | @TRANSFORMS.register_module() 31 | class LoadAnnotations: 32 | """Load multiple types of annotations.""" 33 | 34 | def __init__(self, with_bbox=True, with_label=True): 35 | self.with_bbox = with_bbox 36 | self.with_label = with_label 37 | 38 | def _load_bboxes(self, results): 39 | gt_bboxes = [] 40 | gt_ignore_flags = [] 41 | for instance in results.get('instances', []): 42 | gt_bboxes.append(instance['bbox']) 43 | gt_ignore_flags.append(instance['ignore_flag']) 44 | 45 | results['gt_bboxes'] = np.array( 46 | gt_bboxes, dtype=np.float32).reshape((-1, 4)) 47 | results['gt_ignore_flags'] = np.array(gt_ignore_flags, dtype=bool) 48 | 49 | def _load_labels(self, results): 50 | gt_bboxes_labels = [] 51 | for instance in results.get('instances', []): 52 | gt_bboxes_labels.append(instance['bbox_label']) 53 | results['gt_bboxes_labels'] = np.array( 54 | gt_bboxes_labels, dtype=np.int64) 55 | 56 | def __call__(self, results): 57 | if self.with_bbox: 58 | self._load_bboxes(results) 59 | if self.with_label: 60 | self._load_labels(results) 61 | return results 62 | 63 | def __repr__(self) -> str: 64 | repr_str = self.__class__.__name__ 65 | repr_str += f'(with_bbox={self.with_bbox}, ' 66 | repr_str += f'with_label={self.with_label})' 67 | return repr_str 68 | -------------------------------------------------------------------------------- /jittordet/datasets/wrappers.py: -------------------------------------------------------------------------------- 1 | from ..engine import BATCH_SAMPLERS, DATASETS 2 | from ..utils import is_list_of 3 | from .base import BaseDetDataset 4 | 5 | 6 | class PartCompose: 7 | 8 | def __init__(self, composes, partition): 9 | assert len(composes) == len(partition) 10 | self.partition = partition 11 | self.composes = composes 12 | 13 | def __call__(self, data): 14 | assert 'sample_idx' in data 15 | sample_idx = data['sample_idx'] 16 | for compose, part in zip(self.composes, self.partition): 17 | if sample_idx >= part: 18 | sample_idx -= part 19 | continue 20 | 21 | return compose(data) 22 | 23 | 24 | @DATASETS.register_module() 25 | class ConcatDataset(BaseDetDataset): 26 | 27 | def __init__(self, 28 | datasets, 29 | batch_size, 30 | num_workers=0, 31 | metainfo=None, 32 | test_mode=False, 33 | batch_sampler=None, 34 | max_refetch=100, 35 | **kwargs): 36 | super(BaseDetDataset, self).__init__( 37 | batch_size=batch_size, num_workers=num_workers, **kwargs) 38 | 39 | # override some setting in sub dataset 40 | assert is_list_of(datasets, dict) 41 | self.data_list, self.lengths = [], [] 42 | transforms = [] 43 | for dataset in datasets: 44 | dataset['batch_size'] = 1 45 | dataset['num_workers'] = 0 46 | dataset['metainfo'] = metainfo 47 | dataset['test_mode'] = test_mode 48 | dataset['batch_sampler'] = None 49 | dataset = DATASETS.build(dataset) 50 | self.data_list.extend(dataset.data_list) 51 | self.lengths.append(len(dataset.data_list)) 52 | transforms.append(dataset.transforms) 53 | 54 | self.transforms = PartCompose(transforms, self.lengths) 55 | 56 | self.test_mode = test_mode 57 | self.max_refetch = max_refetch 58 | 59 | # set total length for jittor.utils.dataset 60 | self.total_len = len(self.data_list) 61 | 62 | if batch_sampler is not None: 63 | self.batch_sampler = BATCH_SAMPLERS.build( 64 | batch_sampler, dataset=self)
65 | else: 66 | self.batch_sampler = None 67 | -------------------------------------------------------------------------------- /jittordet/ops/nms.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | 3 | 4 | def nms(boxes, scores, thresh): 5 | assert boxes.shape[-1] == 4 and len(scores) == len(boxes) 6 | if scores.ndim == 1: 7 | scores = scores.unsqueeze(-1) 8 | dets = jt.concat([boxes, scores], dim=1) 9 | return jt.nms(dets, thresh) 10 | 11 | 12 | def multiclass_nms(mlvl_bboxes, mlvl_scores, score_thr, nms, max_per_img): 13 | """NMS for multi-class bboxes. 14 | 15 | Args: 16 | mlvl_bboxes (Var): shape (n, #class*4) or (n, 4) 17 | mlvl_scores (Var): shape (n, #class), where the last column 18 | contains scores of the background class, but this will be ignored. 19 | score_thr (float): bbox threshold, bboxes with scores lower than it 20 | will not be considered. 21 | nms (dict): NMS config; the IoU threshold is read from nms['thresh']. 22 | max_per_img (int): if there are more than max_per_img bboxes after 23 | NMS, only the top max_per_img will be kept. 24 | 25 | Returns: 26 | tuple: (dets, labels), Vars of shape (k, 5) and (k). 27 | Dets are boxes with scores. Labels are 0-based. 28 | """ 29 | boxes = [] 30 | scores = [] 31 | labels = [] 32 | n_class = mlvl_scores.size(1) 33 | if mlvl_bboxes.shape[1] > 4: 34 | mlvl_bboxes = mlvl_bboxes.view(mlvl_bboxes.size(0), -1, 4) 35 | else: 36 | mlvl_bboxes = mlvl_bboxes.unsqueeze(1) 37 | mlvl_bboxes = mlvl_bboxes.expand((mlvl_bboxes.size(0), n_class, 4)) 38 | for j in range(0, n_class - 1): 39 | bbox_j = mlvl_bboxes[:, j, :] 40 | score_j = mlvl_scores[:, j:j + 1] 41 | mask = jt.where(score_j > score_thr)[0] 42 | bbox_j = bbox_j[mask, :] 43 | score_j = score_j[mask] 44 | dets = jt.concat([bbox_j, score_j], dim=1) 45 | keep = jt.nms(dets, nms['thresh']) 46 | bbox_j = bbox_j[keep] 47 | score_j = score_j[keep] 48 | label_j = jt.ones_like(score_j).int32() * j 49 | boxes.append(bbox_j) 50 | scores.append(score_j) 51 | labels.append(label_j) 52 | 53 | boxes = jt.concat(boxes, dim=0) 54 | scores = jt.concat(scores, dim=0) 55 | index, _ = jt.argsort(scores, dim=0, descending=True) 56 | index = index[:max_per_img, 0] 57 | boxes = jt.concat([boxes, scores], dim=1)[index] 58 | labels = jt.concat(labels, dim=0).squeeze(1)[index] 59 | return boxes, labels 60 | -------------------------------------------------------------------------------- /configs/crosskd/_model_/crosskd_gfl_r50_r18_fpn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: GFLKDFramework 3 | preprocessor: 4 | type: Preprocessor 5 | mean: [123.675, 116.28, 103.53] 6 | std: [58.395, 57.12, 57.375] 7 | bgr_to_rgb: true 8 | pad_size_divisor: 32 9 | backbone: 10 | type: ResNet 11 | depth: 50 12 | frozen_stages: 1 13 | norm_eval: true 14 | return_stages: ['layer1', 'layer2', 'layer3', 'layer4'] 15 | pretrained: 'jittorhub://resnet50.pkl' 16 | teacher_config: configs/gfl/_model_/gfl_r50_fpn.yml 17 | teacher_ckpt: teacher.pkl 18 | eval_teacher: true 19 | kd_cfg: 20 | loss_cls_kd: 21 | type: KDQualityFocalLoss 22 | beta: 1 23 | loss_weight: 1.0 24 | loss_reg_kd: 25 | type: KnowledgeDistillationKLDivLoss 26 | class_reduction: sum 27 | T: 1 28 | loss_weight: 4.0 29 | loss_feat_kd: 30 | type: PKDLoss 31 | loss_weight: 1.0 32 | reused_teacher_head_idx: 3 33 | neck: 34 | type: FPN 35 | in_channels: [256, 512, 1024, 2048] 36 | out_channels: 256 37 | start_level: 1 38 | add_extra_convs: 
on_input 39 | num_outs: 5 40 | bbox_head: 41 | type: GFLHead 42 | num_classes: 80 43 | in_channels: 256 44 | stacked_convs: 4 45 | feat_channels: 256 46 | anchor_generator: 47 | type: AnchorGenerator 48 | ratios: [1.0] 49 | octave_base_scale: 8 50 | scales_per_octave: 1 51 | strides: [8, 16, 32, 64, 128] 52 | loss_cls: 53 | type: QualityFocalLoss 54 | use_sigmoid: true 55 | beta: 2.0 56 | loss_weight: 1.0 57 | loss_dfl: 58 | type: DistributionFocalLoss 59 | loss_weight: 0.25 60 | reg_max: 16 61 | loss_bbox: 62 | type: GIoULoss 63 | loss_weight: 2.0 64 | train_cfg: 65 | assigner: 66 | type: ATSSAssigner 67 | topk: 9 68 | allowed_border: -1 69 | pos_weight: -1 70 | debug: false 71 | test_cfg: 72 | nms_pre: 1000 73 | min_bbox_size: 0 74 | score_thr: 0.05 75 | nms: 76 | type: nms 77 | iou_threshold: 0.6 78 | max_per_img: 100 79 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/samplers/pseudo_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/task_modules/samplers/pseudo_sampler.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | import jittor as jt 5 | 6 | from jittordet.engine import TASK_UTILS 7 | from jittordet.structures import InstanceData 8 | from ..assigners import AssignResult 9 | from .base_sampler import BaseSampler 10 | from .sampling_result import SamplingResult 11 | 12 | 13 | @TASK_UTILS.register_module() 14 | class PseudoSampler(BaseSampler): 15 | """A pseudo sampler that does not do sampling actually.""" 16 | 17 | def __init__(self, **kwargs): 18 | pass 19 | 20 | def _sample_pos(self, **kwargs): 21 | """Sample positive samples.""" 22 | raise NotImplementedError 23 | 24 | def _sample_neg(self, **kwargs): 25 | """Sample negative samples.""" 26 | raise NotImplementedError 27 | 28 | def sample(self, assign_result: AssignResult, pred_instances: InstanceData, 29 | gt_instances: InstanceData, *args, **kwargs): 30 | """Directly returns the positive and negative indices of samples. 31 | 32 | Args: 33 | assign_result (:obj:`AssignResult`): Bbox assigning results. 34 | pred_instances (:obj:`InstanceData`): Instances of model 35 | predictions. It includes ``priors``, and the priors can 36 | be anchors, points, or bboxes predicted by the model, 37 | shape(n, 4). 38 | gt_instances (:obj:`InstanceData`): Ground truth of instance 39 | annotations. It usually includes ``bboxes`` and ``labels`` 40 | attributes. 
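        Example (illustrative sketch; ``assigner`` stands for any configured
        assigner)::

            >>> assign_result = assigner.assign(pred_instances, gt_instances)
            >>> result = PseudoSampler().sample(assign_result,
            ...                                 pred_instances, gt_instances)
            >>> result.pos_inds  # indices of priors assigned to a gt box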
41 | 42 | Returns: 43 | :obj:`SamplingResult`: sampler results 44 | """ 45 | gt_bboxes = gt_instances.bboxes 46 | priors = pred_instances.priors 47 | 48 | pos_inds = jt.nonzero(assign_result.gt_inds > 0).squeeze(-1).unique() 49 | neg_inds = jt.nonzero(assign_result.gt_inds == 0).squeeze(-1).unique() 50 | 51 | gt_flags = jt.zeros(priors.shape[0], dtype=jt.uint8) 52 | sampling_result = SamplingResult( 53 | pos_inds=pos_inds, 54 | neg_inds=neg_inds, 55 | priors=priors, 56 | gt_bboxes=gt_bboxes, 57 | assign_result=assign_result, 58 | gt_flags=gt_flags, 59 | avg_factor_with_neg=False) 60 | return sampling_result 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | work_dirs/ 132 | -------------------------------------------------------------------------------- /jittordet/datasets/transforms/formatting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from jittordet.engine import TRANSFORMS 4 | from jittordet.structures import DetDataSample, InstanceData 5 | 6 | 7 | @TRANSFORMS.register_module() 8 | class PackDetInputs: 9 | """Pack the inputs data for detection.""" 10 | mapping_table = { 11 | 'gt_bboxes': 'bboxes', 12 | 'gt_bboxes_labels': 'labels', 13 | } 14 | 15 | def __init__(self, 16 | meta_keys=('sample_idx', 'img_id', 'img_path', 'ori_shape', 17 | 'img_shape', 'scale_factor', 'flip', 18 | 'flip_direction')): 19 | self.meta_keys = meta_keys 20 | 21 | def __call__(self, results): 22 | packed_results = dict() 23 | if 'img' in results: 24 | img = results['img'] 25 | if len(img.shape) < 3: 26 | img = np.expand_dims(img, -1) 27 | img = np.ascontiguousarray(img.transpose(2, 0, 1)) 28 | packed_results['inputs'] = img 29 | 30 | if 'gt_ignore_flags' in results: 31 | valid_idx = np.where(results['gt_ignore_flags'] == 0)[0] 32 | ignore_idx = np.where(results['gt_ignore_flags'] == 1)[0] 33 | 34 | data_sample = DetDataSample() 35 | instance_data = InstanceData() 36 | ignore_instance_data = InstanceData() 37 | 38 | for key in self.mapping_table.keys(): 39 | if key not in results: 40 | continue 41 | if 'gt_ignore_flags' in results: 42 | instance_data[ 43 | self.mapping_table[key]] = results[key][valid_idx] 44 | ignore_instance_data[ 45 | self.mapping_table[key]] = results[key][ignore_idx] 46 | else: 47 | instance_data[self.mapping_table[key]] = results[key] 48 | data_sample.gt_instances = instance_data 49 | data_sample.ignored_instances = ignore_instance_data 50 | 51 | img_meta = {} 52 | for key in self.meta_keys: 53 | assert key in results, f'`{key}` is not found in `results`, ' \ 54 | f'the valid keys are {list(results)}.' 55 | img_meta[key] = results[key] 56 | 57 | data_sample.set_metainfo(img_meta) 58 | packed_results['data_samples'] = data_sample 59 | 60 | return packed_results 61 | 62 | def __repr__(self) -> str: 63 | repr_str = self.__class__.__name__ 64 | repr_str += f'(meta_keys={self.meta_keys})' 65 | return repr_str 66 | -------------------------------------------------------------------------------- /jittordet/datasets/samplers/aspect_ratio_batch_sampler.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | import numpy as np 3 | 4 | from jittordet.engine import BATCH_SAMPLERS 5 | from .base_batch_sampler import BaseBatchSampler 6 | 7 | 8 | @BATCH_SAMPLERS.register_module() 9 | class AspectRatioBatchSampler(BaseBatchSampler): 10 | 11 | def __init__(self, dataset, drop_last=False): 12 | super().__init__(dataset=dataset) 13 | self.drop_last = drop_last 14 | # group image indices into two buckets by aspect ratio.
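        # Added illustrative note: an image dict {'width': 640, 'height': 480}
        # lands in idx_bucket1 (landscape) and {'width': 480, 'height': 640}
        # in idx_bucket2 (portrait), so every batch holds images of similar
        # orientation and needs less padding.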
15 | idx_bucket1, idx_bucket2 = [], [] 16 | for idx, data in enumerate(dataset.data_list): 17 | if data['width'] > data['height']: 18 | idx_bucket1.append(idx) 19 | else: 20 | idx_bucket2.append(idx) 21 | self.idx_bucket1 = np.array(idx_bucket1) 22 | self.idx_bucket2 = np.array(idx_bucket2) 23 | 24 | def get_index_list(self, rng=None): 25 | if rng is None: 26 | rng = np.random.default_rng() 27 | 28 | # shuffle 29 | shuffle_idx = rng.permutation(self.idx_bucket1.size) 30 | idx_bucket1 = self.idx_bucket1[shuffle_idx] 31 | shuffle_idx = rng.permutation(self.idx_bucket2.size) 32 | idx_bucket2 = self.idx_bucket2[shuffle_idx] 33 | 34 | # drop last size 35 | world_size = 1 if not jt.in_mpi else jt.world_size 36 | total_bs = self.total_bs 37 | real_bs = int(total_bs // world_size) 38 | assert real_bs * world_size == total_bs 39 | if idx_bucket1.size % real_bs != 0: 40 | mod_size = idx_bucket1.size % real_bs 41 | idx_bucket1 = idx_bucket1[:-mod_size] 42 | idx_bucket1 = idx_bucket1.reshape(-1, real_bs) 43 | if idx_bucket2.size % real_bs != 0: 44 | mod_size = idx_bucket2.size % real_bs 45 | idx_bucket2 = idx_bucket2[:-mod_size] 46 | idx_bucket2 = idx_bucket2.reshape(-1, real_bs) 47 | 48 | index = np.concatenate([idx_bucket1, idx_bucket2], axis=0) 49 | shuffle_idx = rng.permutation(index.shape[0]) 50 | index = index[shuffle_idx] 51 | 52 | real_bs_num = len(self) * world_size 53 | repeat_num = int((real_bs_num - 0.5) // index.shape[0] + 1) 54 | index = np.concatenate([index] * repeat_num, axis=0) 55 | index = index[:real_bs_num] 56 | 57 | if jt.in_mpi: 58 | index = index.reshape(-1, total_bs) 59 | index = index[:, jt.rank * real_bs:(jt.rank + 1) * real_bs] 60 | 61 | index = index.flatten() 62 | return index 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JittorDet 2 | 3 | ## Introduction 4 | 5 | JittorDet is an object detection benchmark based on [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/). 6 | 7 | ## Supported Models 8 | 9 | JittorDet supports commonly used datasets (COCO, VOC) and models (RetinaNet, Faster R-CNN) out of the box. 10 | 11 | Currently supported models are listed below: 12 | 13 | - RetinaNet 14 | - Faster R-CNN 15 | - GFocalLoss 16 | - PKD 17 | - ⭐ CrossKD: [https://github.com/jbwang1997/CrossKD](https://github.com/jbwang1997/CrossKD) 18 | 19 | New state-of-the-art models are also being implemented: 20 | 21 | - SARDet100K 22 | 23 | ## Getting Started 24 | 25 | ### Install 26 | 27 | Please first follow the [tutorial](https://github.com/Jittor/jittor) to install jittor. 28 | We recommend jittor==1.3.6.10, which is the version we have tested. 29 | 30 | Then, install `jittordet` by running: 31 | ``` 32 | pip install -v -e . 33 | ``` 34 | 35 | If you want to use multi-GPU training or testing, please install OpenMPI: 36 | ``` 37 | sudo apt install openmpi-bin openmpi-common libopenmpi-dev 38 | ``` 39 | 40 | ### Training 41 | 42 | We support both single-GPU and multi-GPU training. 43 | ``` 44 | # Single-GPU 45 | python tools/train.py {CONFIG_PATH} 46 | 47 | # Multi-GPU 48 | bash tools/dist_train.sh {CONFIG_PATH} {NUM_GPUS} 49 | ``` 50 | 51 | ### Testing 52 | 53 | We support both single-GPU and multi-GPU testing.
54 | ``` 55 | # Single-GPU 56 | python tools/test.py {CONFIG_PATH} 57 | 58 | # Multi-GPU 59 | bash tools/dist_test.sh {CONFIG_PATH} {NUM_GPUS} 60 | ``` 61 | 62 | ## Citation 63 | 64 | If this work is helpful for your research, please consider citing the following entries. 65 | 66 | ``` 67 | @article{hu2020jittor, 68 | title={Jittor: a novel deep learning framework with meta-operators and unified graph execution}, 69 | author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang}, 70 | journal={Science China Information Sciences}, 71 | volume={63}, 72 | number={222103}, 73 | pages={1--21}, 74 | year={2020} 75 | } 76 | 77 | @inproceedings{wang2024crosskd, 78 | title={CrossKD: Cross-head knowledge distillation for object detection}, 79 | author={Wang, Jiabao and Chen, Yuming and Zheng, Zhaohui and Li, Xiang and Cheng, Ming-Ming and Hou, Qibin}, 80 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 81 | pages={16520--16530}, 82 | year={2024} 83 | } 84 | 85 | @article{li2024sardet, 86 | title={Sardet-100k: Towards open-source benchmark and toolkit for large-scale sar object detection}, 87 | author={Li, Yuxuan and Li, Xiang and Li, Weijie and Hou, Qibin and Liu, Li and Cheng, Ming-Ming and Yang, Jian}, 88 | journal={arXiv preprint arXiv:2403.06534}, 89 | year={2024} 90 | } 91 | ``` 92 | 93 | ## Acknowledgements 94 | 95 | Our code is developed on top of the following open-source codebases: 96 | 97 | - [Jittor](https://github.com/Jittor/jittor) 98 | - [JDet](https://github.com/Jittor/JDet) 99 | - [MMCV](https://github.com/open-mmlab/mmcv) 100 | - [MMDetection](https://github.com/open-mmlab/mmdetection) 101 | 102 | We sincerely appreciate their amazing work. 103 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/bbox_coders/distance_point_bbox_coder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Optional, Sequence, Union 3 | 4 | import jittor as jt 5 | 6 | from jittordet.engine import TASK_UTILS 7 | from jittordet.utils import bbox2distance, distance2bbox 8 | from .base_bbox_coder import BaseBBoxCoder 9 | 10 | 11 | @TASK_UTILS.register_module() 12 | class DistancePointBBoxCoder(BaseBBoxCoder): 13 | """Distance Point BBox coder. 14 | 15 | This coder encodes gt bboxes (x1, y1, x2, y2) into (left, top, right, 16 | bottom) distances and decodes them back to the original boxes. 17 | 18 | Args: 19 | clip_border (bool, optional): Whether clip the objects outside the 20 | border of the image. Defaults to True. 21 | """ 22 | 23 | def __init__(self, clip_border: Optional[bool] = True, **kwargs) -> None: 24 | super().__init__(**kwargs) 25 | self.clip_border = clip_border 26 | 27 | def encode(self, 28 | points, 29 | gt_bboxes, 30 | max_dis: Optional[float] = None, 31 | eps: float = 0.1) -> jt.Var: 32 | """Encode bounding box to distances. 33 | 34 | Args: 35 | points (Tensor): Shape (N, 2), The format is [x, y]. 36 | gt_bboxes (Tensor or :obj:`BaseBoxes`): Shape (N, 4), The format 37 | is "xyxy" 38 | max_dis (float): Upper bound of the distance. Default None. 39 | eps (float): a small value to ensure target < max_dis instead 40 | of <=. Default 0.1. 41 | 42 | Returns: 43 | Tensor: Box transformation deltas. The shape is (N, 4).
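        Example (illustrative)::

            >>> points = jt.array([[5., 5.]])
            >>> gt_bboxes = jt.array([[0., 0., 10., 10.]])
            >>> coder.encode(points, gt_bboxes)  # (left, top, right, bottom)
            >>> # -> [[5., 5., 5., 5.]]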
44 | """ 45 | assert points.size(0) == gt_bboxes.size(0) 46 | assert points.size(-1) == 2 47 | assert gt_bboxes.size(-1) == 4 48 | return bbox2distance(points, gt_bboxes, max_dis, eps) 49 | 50 | def decode(self, 51 | points: jt.Var, 52 | pred_bboxes: jt.Var, 53 | max_shape: Optional[Union[Sequence[int], jt.Var, 54 | Sequence[Sequence[int]]]] = None): 55 | """Decode distance prediction to bounding box. 56 | 57 | Args: 58 | points (Tensor): Shape (B, N, 2) or (N, 2). 59 | pred_bboxes (Tensor): Distance from the given point to 4 60 | boundaries (left, top, right, bottom). Shape (B, N, 4) 61 | or (N, 4) 62 | max_shape (Sequence[int] or torch.Tensor or Sequence[ 63 | Sequence[int]],optional): Maximum bounds for boxes, specifies 64 | (H, W, C) or (H, W). If priors shape is (B, N, 4), then 65 | the max_shape should be a Sequence[Sequence[int]], 66 | and the length of max_shape should also be B. 67 | Default None. 68 | Returns: 69 | Union[Tensor, :obj:`BaseBoxes`]: Boxes with shape (N, 4) or 70 | (B, N, 4) 71 | """ 72 | assert points.size(0) == pred_bboxes.size(0) 73 | assert points.size(-1) == 2 74 | assert pred_bboxes.size(-1) == 4 75 | if self.clip_border is False: 76 | max_shape = None 77 | bboxes = distance2bbox(points, pred_bboxes, max_shape) 78 | 79 | return bboxes 80 | -------------------------------------------------------------------------------- /jittordet/engine/config/parsers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import os.path as osp 4 | import re 5 | 6 | from .utils import iter_leaves, set_leaf 7 | 8 | 9 | def env_variable_parser(cfg): 10 | """use environment variables in cfg.""" 11 | regexp1 = r'^\s*\$(\w+)\s*\:\s*(\S*?)\s*$' 12 | regexp2 = r'\<\s*\$(\w+)\s*\:\s*(\S*?)\s*\>' 13 | for keys, value in iter_leaves(cfg): 14 | if not isinstance(value, str): 15 | continue 16 | # entire string is environment 17 | results = re.match(regexp1, value) 18 | if results: 19 | var_name, def_value = results.groups() 20 | new_value = os.environ[var_name] \ 21 | if var_name in os.environ else def_value 22 | if new_value.isdigit(): 23 | new_value = eval(new_value) 24 | set_leaf(cfg, keys, new_value) 25 | continue 26 | # partial string is environment 27 | results = re.findall(regexp2, value) 28 | _value = value 29 | for var_name, def_value in results: 30 | regexp = r'\<\s*\$' + var_name + r'\s*\:\s*' \ 31 | + def_value + r'\s*\>' 32 | new_value = os.environ[var_name] if var_name in os.environ \ 33 | else def_value 34 | _value = re.sub(regexp, new_value, _value) 35 | if value != _value: 36 | set_leaf(cfg, keys, _value) 37 | 38 | 39 | def default_var_parser(cfg): 40 | """set some default value in cfg.""" 41 | filename = cfg['filename'] 42 | file_dirname = osp.dirname(filename) 43 | file_basename = osp.basename(filename) 44 | file_basename_no_extension = osp.splitext(file_basename)[0] 45 | file_extname = osp.splitext(filename)[1] 46 | support_templates = dict( 47 | fileDirname=file_dirname, 48 | fileBasename=file_basename, 49 | fileBasenameNoExtension=file_basename_no_extension, 50 | fileExtname=file_extname) 51 | 52 | for keys, value in iter_leaves(cfg): 53 | if not isinstance(value, str): 54 | continue 55 | 56 | new_value = copy.copy(value) 57 | for k, v in support_templates.items(): 58 | regexp = r'\<\s*' + str(k) + r'\s*\>' 59 | v = v.replace('\\', '/') 60 | new_value = re.sub(regexp, v, new_value) 61 | 62 | if value != new_value: 63 | set_leaf(cfg, keys, value) 64 | 65 | 66 | def tuple_parser(cfg): 67 | for keys, value 
in iter_leaves(cfg): 68 | if not isinstance(value, str): 69 | continue 70 | 71 | if value.startswith('(') and value.endswith(')'): 72 | try: 73 | value = eval(value) 74 | except: # noqa: E722 75 | pass 76 | 77 | if isinstance(value, tuple): 78 | set_leaf(cfg, keys, value) 79 | 80 | 81 | def python_eval_parser(cfg): 82 | eval_global = copy.deepcopy(cfg) 83 | for keys, value in iter_leaves(cfg): 84 | if not isinstance(value, str): 85 | continue 86 | if not value.startswith('<') or not value.endswith('>'): 87 | continue 88 | 89 | value = value[1:-1] 90 | value = eval(value, eval_global) 91 | set_leaf(cfg, keys, value) 92 | 93 | 94 | cfg_parsers = [ 95 | env_variable_parser, default_var_parser, tuple_parser, python_eval_parser 96 | ] 97 | -------------------------------------------------------------------------------- /jittordet/models/utils/initialize.py: -------------------------------------------------------------------------------- 1 | # Modified from jdet/models/utils/weight_init.py 2 | import jittor as jt 3 | import numpy as np 4 | from jittor import init, nn 5 | 6 | 7 | def normal_init(module, mean=0, std=1, bias=0): 8 | if hasattr(module, 'weight') and module.weight is not None: 9 | init.gauss_(module.weight, mean, std) 10 | if hasattr(module, 'bias') and isinstance( 11 | module.bias, jt.Var) and module.bias is not None: 12 | init.constant_(module.bias, bias) 13 | 14 | 15 | def constant_init(module, val, bias=0): 16 | if hasattr(module, 'weight') and module.weight is not None: 17 | init.constant_(module.weight, val) 18 | if hasattr(module, 'bias') and module.bias is not None: 19 | init.constant_(module.bias, bias) 20 | 21 | 22 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 23 | assert distribution in ['uniform', 'normal'] 24 | if hasattr(module, 'weight') and module.weight is not None: 25 | if distribution == 'uniform': 26 | init.xavier_uniform_(module.weight, gain=gain) 27 | else: 28 | init.xavier_normal_(module.weight, gain=gain) 29 | if hasattr(module, 'bias') and module.bias is not None: 30 | init.constant_(module.bias, bias) 31 | 32 | 33 | def trunc_normal_init(module: nn.Module, 34 | mean: float = 0, 35 | std: float = 1, 36 | a: float = -2, 37 | b: float = 2, 38 | bias: float = 0) -> None: 39 | if hasattr(module, 'weight') and module.weight is not None: 40 | init.trunc_normal_(module.weight, mean, std, a, b) # type: ignore 41 | if hasattr(module, 'bias') and module.bias is not None: 42 | init.constant_(module.bias, bias) # type: ignore 43 | 44 | 45 | def uniform_init(module, a=0, b=1, bias=0): 46 | if hasattr(module, 'weight') and module.weight is not None: 47 | init.uniform_(module.weight, a, b) 48 | if hasattr(module, 'bias') and module.bias is not None: 49 | init.constant_(module.bias, bias) 50 | 51 | 52 | def kaiming_init(module, 53 | a=0, 54 | mode='fan_out', 55 | nonlinearity='relu', 56 | bias=0, 57 | distribution='normal'): 58 | assert distribution in ['uniform', 'normal'] 59 | if hasattr(module, 'weight') and module.weight is not None: 60 | if distribution == 'uniform': 61 | init.kaiming_uniform_( 62 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 63 | else: 64 | init.kaiming_normal_( 65 | module.weight, a=a, mode=mode, nonlinearity=nonlinearity) 66 | if hasattr(module, 'bias') and module.bias is not None: 67 | init.constant_(module.bias, bias) 68 | 69 | 70 | def caffe2_xavier_init(module, bias=0): 71 | # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch 72 | # Acknowledgment to FAIR's internal code 73 | kaiming_init( 74 | 
module, 75 | a=1, 76 | mode='fan_in', 77 | nonlinearity='leaky_relu', 78 | bias=bias, 79 | distribution='uniform') 80 | 81 | 82 | def bias_init_with_prob(prior_prob): 83 | """Initialize conv/fc bias value according to a given probability value.""" 84 | bias_init = float(-np.log((1 - prior_prob) / prior_prob)) 85 | return bias_init 86 | -------------------------------------------------------------------------------- /jittordet/engine/optim/schedulers.py: -------------------------------------------------------------------------------- 1 | import jittor.lr_scheduler as scheduler 2 | 3 | from ..register import SCHEDULERS 4 | 5 | 6 | @SCHEDULERS.register_module() 7 | class BaseScheduler: 8 | 9 | def state_dict(self): 10 | state_dict = {} 11 | exclude = ['optimizer'] 12 | for key, value in self.__dict__.items(): 13 | if key in exclude or callable(value): 14 | continue 15 | state_dict[key] = value 16 | return state_dict 17 | 18 | def load_state_dict(self, data): 19 | assert isinstance(data, dict) 20 | for key, value in data.items(): 21 | if key in self.__dict__: 22 | self.__dict__[key] = value 23 | 24 | 25 | @SCHEDULERS.register_module() 26 | class WarmUpLR(BaseScheduler): 27 | """Warmup LR scheduler, copied from JDet. It is used by default as the 28 | base lr_scheduler. 29 | 30 | Args: 31 | optimizer (Optimizer): the optimizer to optimize the model 32 | warmup (string): Type of warmup used. It can be None (use no warmup), 33 | 'constant', 'linear' or 'exp' 34 | warmup_iters (int): The number of iterations or epochs that warmup 35 | lasts 36 | warmup_ratio (float): LR used at the beginning of warmup equals to 37 | warmup_ratio * initial_lr 38 | """ 39 | 40 | def __init__(self, 41 | optimizer, 42 | warmup_ratio=1.0 / 3, 43 | warmup_iters=500, 44 | warmup='linear', 45 | last_iter=-1): 46 | self.optimizer = optimizer 47 | self.warmup_ratio = warmup_ratio 48 | self.warmup_iters = warmup_iters 49 | self.warmup = warmup 50 | self.base_lr = optimizer.lr 51 | self.base_lr_pg = [ 52 | pg.get('lr', optimizer.lr) for pg in optimizer.param_groups 53 | ] 54 | self.by_iter = True 55 | self.last_iter = last_iter 56 | 57 | def get_warmup_lr(self, lr, cur_iters): 58 | if self.warmup == 'constant': 59 | k = self.warmup_ratio 60 | elif self.warmup == 'linear': 61 | k = 1 - (1 - cur_iters / self.warmup_iters) * (1 - 62 | self.warmup_ratio) 63 | elif self.warmup == 'exp': 64 | k = self.warmup_ratio**(1 - cur_iters / self.warmup_iters) 65 | return k * lr 66 | 67 | def _update_lr(self, steps): 68 | self.optimizer.lr = self.get_warmup_lr(self.base_lr, steps) 69 | for i, param_group in enumerate(self.optimizer.param_groups): 70 | param_group['lr'] = self.get_warmup_lr(self.base_lr_pg[i], steps) 71 | 72 | def step(self): 73 | self.last_iter += 1 74 | if self.last_iter <= self.warmup_iters: 75 | self._update_lr(self.last_iter) 76 | 77 | 78 | @SCHEDULERS.register_module() 79 | class CosineAnnealingLR(scheduler.CosineAnnealingLR, BaseScheduler): 80 | """CosineAnnealing LR Scheduler.""" 81 | 82 | 83 | @SCHEDULERS.register_module() 84 | class ExponentialLR(scheduler.ExponentialLR, BaseScheduler): 85 | """Exponential LR Scheduler.""" 86 | 87 | 88 | @SCHEDULERS.register_module() 89 | class StepLR(scheduler.StepLR, BaseScheduler): 90 | """Step LR Scheduler.""" 91 | 92 | 93 | @SCHEDULERS.register_module() 94 | class MultiStepLR(scheduler.MultiStepLR, BaseScheduler): 95 | """Multiple Step LR Scheduler.""" 96 | -------------------------------------------------------------------------------- /jittordet/models/losses/accuracy.py: 
-------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab mmdet/models/losses/accuracy.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import jittor as jt 4 | import jittor.nn as nn 5 | 6 | from jittordet.engine import MODELS 7 | 8 | 9 | def accuracy(pred, target, topk=1, thresh=None): 10 | """Calculate accuracy according to the prediction and target. 11 | 12 | Args: 13 | pred (torch.Tensor): The model prediction, shape (N, num_class) 14 | target (torch.Tensor): The target of each prediction, shape (N, ) 15 | topk (int | tuple[int], optional): If the predictions in ``topk`` 16 | matches the target, the predictions will be regarded as 17 | correct ones. Defaults to 1. 18 | thresh (float, optional): If not None, predictions with scores under 19 | this threshold are considered incorrect. Default to None. 20 | 21 | Returns: 22 | float | tuple[float]: If the input ``topk`` is a single integer, 23 | the function will return a single float as accuracy. If 24 | ``topk`` is a tuple containing multiple integers, the 25 | function will return a tuple containing accuracies of 26 | each ``topk`` number. 27 | """ 28 | assert isinstance(topk, (int, tuple)) 29 | if isinstance(topk, int): 30 | topk = (topk, ) 31 | return_single = True 32 | else: 33 | return_single = False 34 | 35 | maxk = max(topk) 36 | if pred.size(0) == 0: 37 | accu = [jt.array(0, dtype=pred.dtype) for i in range(len(topk))] 38 | return accu[0] if return_single else accu 39 | assert pred.ndim == 2 and target.ndim == 1 40 | assert pred.size(0) == target.size(0) 41 | assert maxk <= pred.size(1), \ 42 | f'maxk {maxk} exceeds pred dimension {pred.size(1)}' 43 | pred_value, pred_label = pred.topk(maxk, dim=1) 44 | pred_label = pred_label.t() # transpose to shape (maxk, N) 45 | 46 | correct = jt.equal(pred_label, target.view(1, -1).expand_as(pred_label)) 47 | if thresh is not None: 48 | # Only prediction values larger than thresh are counted as correct 49 | correct = correct & (pred_value > thresh).t() 50 | res = [] 51 | for k in topk: 52 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdims=True) 53 | res.append(jt.multiply(correct_k, 100.0 / pred.size(0))) 54 | return res[0] if return_single else res 55 | 56 | 57 | @MODELS.register_module() 58 | class Accuracy(nn.Module): 59 | 60 | def __init__(self, topk=(1, ), thresh=None): 61 | """Module to calculate the accuracy. 62 | 63 | Args: 64 | topk (tuple, optional): The criterion used to calculate the 65 | accuracy. Defaults to (1,). 66 | thresh (float, optional): If not None, predictions with scores 67 | under this threshold are considered incorrect. Default to None. 68 | """ 69 | super().__init__() 70 | self.topk = topk 71 | self.thresh = thresh 72 | 73 | def execute(self, pred, target): 74 | """Forward function to calculate accuracy. 75 | 76 | Args: 77 | pred (torch.Tensor): Prediction of models. 78 | target (torch.Tensor): Target for each prediction. 79 | 80 | Returns: 81 | tuple[float]: The accuracies under different topk criterions. 
82 | """ 83 | return accuracy(pred, target, self.topk, self.thresh) 84 | -------------------------------------------------------------------------------- /configs/faster_rcnn/_model_/faster_rcnn_r50_fpn.yml: -------------------------------------------------------------------------------- 1 | model: 2 | type: MultiStageFramework 3 | preprocessor: 4 | type: Preprocessor 5 | mean: [123.675, 116.28, 103.53] 6 | std: [58.395, 57.12, 57.375] 7 | bgr_to_rgb: true 8 | pad_size_divisor: 32 9 | backbone: 10 | type: ResNet 11 | depth: 50 12 | frozen_stages: 1 13 | norm_eval: true 14 | return_stages: ['layer1', 'layer2', 'layer3', 'layer4'] 15 | pretrained: 'jittorhub://resnet50.pkl' 16 | neck: 17 | type: FPN 18 | in_channels: [256, 512, 1024, 2048] 19 | out_channels: 256 20 | num_outs: 5 21 | rpn_head: 22 | type: RPNHead 23 | num_classes: 1 24 | in_channels: 256 25 | feat_channels: 256 26 | anchor_generator: 27 | type: AnchorGenerator 28 | scales: [8] 29 | ratios: [0.5, 1.0, 2.0] 30 | strides: [4, 8, 16, 32, 64] 31 | bbox_coder: 32 | type: DeltaXYWHBBoxCoder 33 | target_means: [.0, .0, .0, .0] 34 | target_stds: [1.0, 1.0, 1.0, 1.0] 35 | loss_cls: 36 | type: CrossEntropyLoss 37 | use_sigmoid: true 38 | loss_weight: 1.0 39 | loss_bbox: 40 | type: L1Loss 41 | loss_weight: 1.0 42 | roi_head: 43 | type: StandardRoIHead 44 | bbox_roi_extractor: 45 | type: SingleRoIExtractor 46 | roi_layer: 47 | type: ROIAlign 48 | output_size: 7 49 | sampling_ratio: 0 50 | out_channels: 256 51 | featmap_strides: [4, 8, 16, 32] 52 | bbox_head: 53 | type: Shared2FCBBoxHead 54 | in_channels: 256 55 | fc_out_channels: 1024 56 | roi_feat_size: 7 57 | num_classes: 80 58 | bbox_coder: 59 | type: DeltaXYWHBBoxCoder 60 | target_means: [.0, .0, .0, .0] 61 | target_stds: [0.1, 0.1, 0.2, 0.2] 62 | reg_class_agnostic: false 63 | loss_cls: 64 | type: CrossEntropyLoss 65 | use_sigmoid: false 66 | loss_weight: 1.0 67 | loss_bbox: 68 | type: L1Loss 69 | loss_weight: 1.0 70 | train_cfg: 71 | rpn: 72 | assigner: 73 | type: MaxIoUAssigner 74 | pos_iou_thr: 0.7 75 | neg_iou_thr: 0.3 76 | min_pos_iou: 0.3 77 | match_low_quality: true 78 | ignore_iof_thr: -1 79 | sampler: 80 | type: RandomSampler 81 | num: 256 82 | pos_fraction: 0.5 83 | neg_pos_ub: -1 84 | add_gt_as_proposals: false 85 | allowed_border: -1 86 | pos_weight: -1 87 | rpn_proposal: 88 | nms_pre: 2000 89 | max_per_img: 1000 90 | nms: 91 | type: 'nms' 92 | thresh: 0.7 93 | min_bbox_size: 0 94 | rcnn: 95 | assigner: 96 | type: MaxIoUAssigner 97 | pos_iou_thr: 0.5 98 | neg_iou_thr: 0.5 99 | min_pos_iou: 0.5 100 | match_low_quality: false 101 | ignore_iof_thr: -1 102 | sampler: 103 | type: RandomSampler 104 | num: 512 105 | pos_fraction: 0.25 106 | neg_pos_ub: -1 107 | add_gt_as_proposals: true 108 | pos_weight: -1 109 | test_cfg: 110 | rpn: 111 | nms_pre: 1000 112 | max_per_img: 1000 113 | nms: 114 | type: nms 115 | thresh: 0.7 116 | min_bbox_size: 0 117 | rcnn: 118 | score_thr: 0.05 119 | max_per_img: 100 120 | nms: 121 | type: nms 122 | thresh: 0.5 123 | -------------------------------------------------------------------------------- /jittordet/models/frameworks/rpn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
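
Configs such as the Faster R-CNN model file above are consumed through the registry machinery shown later in this dump: `load_cfg` resolves the YAML into a `ConfigDict`, and `MODELS.build` pops the `type` key to pick a registered class, forwarding every remaining key as a constructor kwarg. A sketch under those assumptions (that `MultiStageFramework` is registered and accepts exactly these kwargs is taken on faith from the config itself):

from jittordet.engine import MODELS
from jittordet.engine.config.config import load_cfg

cfg = load_cfg('configs/faster_rcnn/faster_rcnn_r50_fpn_coco_1x.yml')
# 'type: MultiStageFramework' selects the class from the MODELS registry;
# 'backbone', 'neck', 'rpn_head', ... are passed through as keyword
# arguments and built recursively by the framework's own __init__.
model = MODELS.build(cfg.model)
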
2 | import copy 3 | import warnings 4 | 5 | import jittor as jt 6 | 7 | from jittordet.engine import MODELS, ConfigType, OptConfigType 8 | from jittordet.structures import SampleList 9 | from .single_stage import SingleStageFramework 10 | 11 | 12 | @MODELS.register_module() 13 | class RPNFramework(SingleStageFramework): 14 | """Implementation of Region Proposal Network. 15 | 16 | Args: 17 | backbone (:obj:`ConfigDict` or dict): The backbone config. 18 | neck (:obj:`ConfigDict` or dict): The neck config. 19 | bbox_head (:obj:`ConfigDict` or dict): The bbox head config. 20 | train_cfg (:obj:`ConfigDict` or dict, optional): The training config. 21 | test_cfg (:obj:`ConfigDict` or dict, optional): The testing config. 22 | data_preprocessor (:obj:`ConfigDict` or dict, optional): Config of 23 | :class:`DetDataPreprocessor` to process the input data. 24 | Defaults to None. 25 | init_cfg (:obj:`ConfigDict` or list[:obj:`ConfigDict`] or dict or 26 | list[dict], optional): Initialization config dict. 27 | Defaults to None. 28 | """ 29 | 30 | def __init__(self, 31 | backbone: ConfigType, 32 | neck: ConfigType, 33 | rpn_head: ConfigType, 34 | train_cfg: ConfigType, 35 | test_cfg: ConfigType, 36 | preprocessor: OptConfigType = None, 37 | **kwargs) -> None: 38 | super(SingleStageFramework, self).__init__(preprocessor=preprocessor) 39 | self.backbone = MODELS.build(backbone) 40 | self.neck = MODELS.build(neck) if neck is not None else None 41 | rpn_train_cfg = train_cfg['rpn'] if train_cfg is not None else None 42 | rpn_head_num_classes = rpn_head.get('num_classes', 1) 43 | if rpn_head_num_classes != 1: 44 | warnings.warn('The `num_classes` should be 1 in RPN, but get ' 45 | f'{rpn_head_num_classes}, please set ' 46 | 'rpn_head.num_classes = 1 in your config file.') 47 | rpn_head.update(num_classes=1) 48 | rpn_head.update(train_cfg=rpn_train_cfg) 49 | rpn_head.update(test_cfg=test_cfg['rpn']) 50 | self.bbox_head = MODELS.build(rpn_head) 51 | self.train_cfg = train_cfg 52 | self.test_cfg = test_cfg 53 | 54 | def loss(self, batch_inputs: jt.Var, 55 | batch_data_samples: SampleList) -> dict: 56 | """Calculate losses from a batch of inputs and data samples. 57 | 58 | Args: 59 | batch_inputs (Tensor): Input images of shape (N, C, H, W). 60 | These should usually be mean centered and std scaled. 61 | batch_data_samples (list[:obj:`DetDataSample`]): The batch 62 | data samples. It usually includes information such 63 | as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 64 | 65 | Returns: 66 | dict[str, Tensor]: A dictionary of loss components. 67 | """ 68 | x = self.extract_feat(batch_inputs) 69 | 70 | # set cat_id of gt_labels to 0 in RPN 71 | rpn_data_samples = copy.deepcopy(batch_data_samples) 72 | for data_sample in rpn_data_samples: 73 | data_sample.gt_instances.labels = \ 74 | jt.zeros_like(data_sample.gt_instances.labels) 75 | 76 | losses = self.bbox_head.loss(x, rpn_data_samples) 77 | return losses 78 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/assigners/assign_result.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import jittor as jt 3 | 4 | 5 | class AssignResult: 6 | """Stores assignments between predicted and truth boxes. 7 | 8 | Attributes: 9 | num_gts (int): the number of truth boxes considered when computing this 10 | assignment 11 | gt_inds (Tensor): for each predicted box indicates the 1-based 12 | index of the assigned truth box. 
0 means unassigned and -1 means
13 |             ignore.
14 |         max_overlaps (Tensor): the iou between the predicted box and its
15 |             assigned truth box.
16 |         labels (Tensor): If specified, for each predicted box
17 |             indicates the category label of the assigned truth box.
18 | 
19 |     Example:
20 |         >>> # An assign result between 4 predicted boxes and 9 true boxes
21 |         >>> # where only two boxes were assigned.
22 |         >>> num_gts = 9
23 |         >>> max_overlaps = jt.array([0., .5, .9, 0.])
24 |         >>> gt_inds = jt.array([-1, 1, 2, 0], dtype=jt.int64)
25 |         >>> labels = jt.array([0, 3, 4, 0], dtype=jt.int64)
26 |         >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
27 |         >>> print(str(self))  # xdoctest: +IGNORE_WANT
28 |         <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
29 |                       labels.shape=(4,))>
30 |         >>> # Force addition of gt labels (when adding gt as proposals)
31 |         >>> new_labels = jt.array([3, 4, 5], dtype=jt.int64)
32 |         >>> self.add_gt_(new_labels)
33 |         >>> print(str(self))  # xdoctest: +IGNORE_WANT
34 |         <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
35 |                       labels.shape=(7,))>
36 |     """
37 | 
38 |     def __init__(self, num_gts: int, gt_inds: jt.Var, max_overlaps: jt.Var,
39 |                  labels: jt.Var) -> None:
40 |         self.num_gts = num_gts
41 |         self.gt_inds = gt_inds
42 |         self.max_overlaps = max_overlaps
43 |         self.labels = labels
44 |         # Interface for possible user-defined properties
45 |         self._extra_properties = {}
46 | 
47 |     @property
48 |     def num_preds(self):
49 |         """int: the number of predictions in this assignment"""
50 |         return len(self.gt_inds)
51 | 
52 |     def set_extra_property(self, key, value):
53 |         """Set user-defined new property."""
54 |         assert key not in self.info
55 |         self._extra_properties[key] = value
56 | 
57 |     def get_extra_property(self, key):
58 |         """Get user-defined property."""
59 |         return self._extra_properties.get(key, None)
60 | 
61 |     @property
62 |     def info(self):
63 |         """dict: a dictionary of info about the object"""
64 |         basic_info = {
65 |             'num_gts': self.num_gts,
66 |             'num_preds': self.num_preds,
67 |             'gt_inds': self.gt_inds,
68 |             'max_overlaps': self.max_overlaps,
69 |             'labels': self.labels,
70 |         }
71 |         basic_info.update(self._extra_properties)
72 |         return basic_info
73 | 
74 |     def add_gt_(self, gt_labels):
75 |         """Add ground truth as assigned results.
76 | 
77 |         Args:
78 |             gt_labels (Tensor): Labels of gt boxes
79 |         """
80 |         self_inds = jt.arange(1, len(gt_labels) + 1, dtype=jt.int64)
81 |         self.gt_inds = jt.concat([self_inds, self.gt_inds])
82 | 
83 |         self.max_overlaps = jt.concat([
84 |             jt.ones(len(gt_labels), dtype=self.max_overlaps.dtype),
85 |             self.max_overlaps
86 |         ])
87 | 
88 |         self.labels = jt.concat([gt_labels, self.labels])
89 | 
-------------------------------------------------------------------------------- /jittordet/models/losses/utils.py: --------------------------------------------------------------------------------
1 | # Modified from OpenMMLab mmdet/models/losses/utils.py
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | import functools
4 | 
5 | 
6 | def reduce_loss(loss, reduction):
7 |     """Reduce loss as specified.
8 | 
9 |     Args:
10 |         loss (Tensor): Elementwise loss tensor.
11 |         reduction (str): Options are "none", "mean" and "sum".
12 | 
13 |     Return:
14 |         Tensor: Reduced loss tensor.
15 |     """
16 |     # none: 0, elementwise_mean:1, sum: 2
17 |     if reduction == 'none':
18 |         return loss
19 |     elif reduction == 'mean':
20 |         return loss.mean()
21 |     elif reduction == 'sum':
22 |         return loss.sum()
23 | 
24 | 
25 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
26 |     """Apply element-wise weight and reduce loss.
27 | 
28 |     Args:
29 |         loss (Tensor): Element-wise loss.
30 |         weight (Tensor): Element-wise weights.
31 | reduction (str): Same as built-in losses of PyTorch. 32 | avg_factor (float): Average factor when computing the mean of losses. 33 | 34 | Returns: 35 | Tensor: Processed loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | loss = loss * weight 40 | 41 | # if avg_factor is not specified, just reduce the loss 42 | if avg_factor is None: 43 | loss = reduce_loss(loss, reduction) 44 | else: 45 | # if reduction is mean, then average the loss by avg_factor 46 | if reduction == 'mean': 47 | # Avoid causing ZeroDivisionError when avg_factor is 0.0, 48 | # i.e., all labels of an image belong to ignore index. 49 | eps = 1e-6 50 | loss = loss.sum() / (avg_factor + eps) 51 | # if reduction is 'none', then do nothing, otherwise raise an error 52 | elif reduction != 'none': 53 | raise ValueError('avg_factor can not be used with reduction="sum"') 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | avg_factor=None, **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.) 
82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, avg_factor=2) 85 | tensor(1.5000) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, 90 | target, 91 | weight=None, 92 | reduction='mean', 93 | avg_factor=None, 94 | **kwargs): 95 | # get element-wise loss 96 | loss = loss_func(pred, target, **kwargs) 97 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 98 | return loss 99 | 100 | return wrapper 101 | -------------------------------------------------------------------------------- /configs/faster_rcnn/MSFA_faster_rcnn_r50_fpn_sardet_1x.yml: -------------------------------------------------------------------------------- 1 | _base_: 2 | - ../_dataset_/sardet100k.yml 3 | - ../_common_/default_setting.yml 4 | - ../_common_/loop_1x.yml 5 | - ../_common_/adamw_0_0001.yml 6 | 7 | model: 8 | type: MultiStageFramework 9 | preprocessor: 10 | type: Preprocessor 11 | mean: [123.675, 116.28, 103.53] 12 | std: [58.395, 57.12, 57.375] 13 | bgr_to_rgb: true 14 | pad_size_divisor: 32 15 | backbone: 16 | use_sar: True 17 | use_wavelet: True 18 | backbone: 19 | type: ResNet 20 | depth: 50 21 | frozen_stages: 1 22 | norm_eval: true 23 | return_stages: ['layer1', 'layer2', 'layer3', 'layer4'] 24 | pretrained: 'jittorhub://resnet50.pkl' 25 | neck: 26 | type: FPN 27 | in_channels: [256, 512, 1024, 2048] 28 | out_channels: 256 29 | num_outs: 5 30 | rpn_head: 31 | type: RPNHead 32 | num_classes: 1 33 | in_channels: 256 34 | feat_channels: 256 35 | anchor_generator: 36 | type: AnchorGenerator 37 | scales: [8] 38 | ratios: [0.5, 1.0, 2.0] 39 | strides: [4, 8, 16, 32, 64] 40 | bbox_coder: 41 | type: DeltaXYWHBBoxCoder 42 | target_means: [.0, .0, .0, .0] 43 | target_stds: [1.0, 1.0, 1.0, 1.0] 44 | loss_cls: 45 | type: CrossEntropyLoss 46 | use_sigmoid: true 47 | loss_weight: 1.0 48 | loss_bbox: 49 | type: L1Loss 50 | loss_weight: 1.0 51 | roi_head: 52 | type: StandardRoIHead 53 | bbox_roi_extractor: 54 | type: SingleRoIExtractor 55 | roi_layer: 56 | type: ROIAlign 57 | output_size: 7 58 | sampling_ratio: 0 59 | out_channels: 256 60 | featmap_strides: [4, 8, 16, 32] 61 | bbox_head: 62 | type: Shared2FCBBoxHead 63 | in_channels: 256 64 | fc_out_channels: 1024 65 | roi_feat_size: 7 66 | num_classes: 80 67 | bbox_coder: 68 | type: DeltaXYWHBBoxCoder 69 | target_means: [.0, .0, .0, .0] 70 | target_stds: [0.1, 0.1, 0.2, 0.2] 71 | reg_class_agnostic: false 72 | loss_cls: 73 | type: CrossEntropyLoss 74 | use_sigmoid: false 75 | loss_weight: 1.0 76 | loss_bbox: 77 | type: L1Loss 78 | loss_weight: 1.0 79 | train_cfg: 80 | rpn: 81 | assigner: 82 | type: MaxIoUAssigner 83 | pos_iou_thr: 0.7 84 | neg_iou_thr: 0.3 85 | min_pos_iou: 0.3 86 | match_low_quality: true 87 | ignore_iof_thr: -1 88 | sampler: 89 | type: RandomSampler 90 | num: 256 91 | pos_fraction: 0.5 92 | neg_pos_ub: -1 93 | add_gt_as_proposals: false 94 | allowed_border: -1 95 | pos_weight: -1 96 | rpn_proposal: 97 | nms_pre: 2000 98 | max_per_img: 1000 99 | nms: 100 | type: 'nms' 101 | thresh: 0.7 102 | min_bbox_size: 0 103 | rcnn: 104 | assigner: 105 | type: MaxIoUAssigner 106 | pos_iou_thr: 0.5 107 | neg_iou_thr: 0.5 108 | min_pos_iou: 0.5 109 | match_low_quality: false 110 | ignore_iof_thr: -1 111 | sampler: 112 | type: RandomSampler 113 | num: 512 114 | pos_fraction: 0.25 115 | neg_pos_ub: -1 116 | add_gt_as_proposals: true 117 | pos_weight: -1 118 | test_cfg: 119 | rpn: 120 | nms_pre: 1000 121 | max_per_img: 1000 122 | nms: 123 | type: nms 124 | thresh: 0.7 125 | 
min_bbox_size: 0 126 | rcnn: 127 | score_thr: 0.05 128 | max_per_img: 100 129 | nms: 130 | type: nms 131 | thresh: 0.5 132 | 133 | -------------------------------------------------------------------------------- /jittordet/engine/evaluator/base_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import pickle 4 | import shutil 5 | import tempfile 6 | from abc import ABCMeta, abstractmethod 7 | 8 | import jittor as jt 9 | 10 | 11 | class BaseEvaluator(metaclass=ABCMeta): 12 | 13 | @abstractmethod 14 | def process(self, dataset, data_samples): 15 | pass 16 | 17 | @abstractmethod 18 | def compute_metrics(self, dataset, results, logger): 19 | pass 20 | 21 | def evaluate(self, dataset, logger): 22 | results = self.collect_results(self.results, len(dataset.data_list)) 23 | if not jt.in_mpi or jt.rank == 0: 24 | metrics = self.compute_metrics(dataset, results, logger) 25 | else: 26 | metrics = None 27 | metrics = self.broadcast_metrics(metrics) 28 | self.results.clear() 29 | return metrics 30 | 31 | def gen_broadcasted_tmpdir(self): 32 | MAX_LEN = 512 33 | # 32 is whitespace 34 | dir_tensor = jt.full((MAX_LEN, ), 32, dtype=jt.uint8) 35 | if jt.rank == 0: 36 | if not osp.exists('.dist_test'): 37 | os.makedirs('.dist_test') 38 | tmpdir = tempfile.mkdtemp(dir='.dist_test') 39 | tmpdir = jt.array(bytearray(tmpdir.encode()), dtype=jt.uint8) 40 | dir_tensor[:len(tmpdir)] = tmpdir 41 | dir_tensor = dir_tensor.mpi_broadcast(root=0) 42 | tmpdir = dir_tensor.numpy().tobytes().decode().rstrip() 43 | return tmpdir 44 | 45 | def collect_results(self, result_part, size): 46 | """Collect results under cpu mode.""" 47 | rank, world_size = jt.rank, jt.world_size 48 | if world_size == 1: 49 | return result_part[:size] 50 | 51 | tmpdir = self.gen_broadcasted_tmpdir() 52 | # dump the part result to the dir 53 | with open(osp.join(tmpdir, f'part_{rank}.pkl'), 54 | 'wb') as f: # type: ignore 55 | pickle.dump(result_part, f, protocol=2) 56 | 57 | self.barrier() 58 | 59 | if rank != 0: 60 | return None 61 | else: 62 | # load results of all parts from tmp dir 63 | part_list = [] 64 | for i in range(world_size): 65 | path = osp.join(tmpdir, f'part_{i}.pkl') # type: ignore 66 | with open(path, 'rb') as f: 67 | part_list.append(pickle.load(f)) 68 | # sort the results 69 | ordered_results = [] 70 | for res in zip(*part_list): 71 | ordered_results.extend(list(res)) 72 | # the dataloader may pad some samples 73 | ordered_results = ordered_results[:size] 74 | # remove tmp dir 75 | shutil.rmtree(tmpdir) # type: ignore 76 | return ordered_results 77 | 78 | def broadcast_metrics(self, metrics): 79 | if jt.world_size == 1: 80 | return metrics 81 | 82 | tmpdir = self.gen_broadcasted_tmpdir() 83 | if jt.rank == 0: 84 | with open(osp.join(tmpdir, 'metrics.pkl'), 'wb') as f: 85 | pickle.dump(metrics, f, protocol=2) 86 | 87 | self.barrier() 88 | 89 | with open(osp.join(tmpdir, 'metrics.pkl'), 'rb') as f: 90 | metrics = pickle.load(f) 91 | 92 | if jt.rank == 0: 93 | shutil.rmtree(tmpdir) # type: ignore 94 | return metrics 95 | 96 | @staticmethod 97 | def barrier(): 98 | if jt.in_mpi: 99 | lock = jt.array([1]) 100 | lock = lock.mpi_all_reduce('mean') 101 | lock.sync(device_sync=True) 102 | del lock 103 | -------------------------------------------------------------------------------- /jittordet/engine/hooks/base_hook.py: -------------------------------------------------------------------------------- 1 | # copy from OpenMMLab mmengine.hooks 2 | # 
Copyright (c) OpenMMLab. All rights reserved. 3 | 4 | 5 | class BaseHook: 6 | 7 | priority = 50 8 | 9 | def before_train(self, runner): 10 | pass 11 | 12 | def after_train(self, runner): 13 | pass 14 | 15 | def before_val(self, runner): 16 | pass 17 | 18 | def after_val(self, runner): 19 | pass 20 | 21 | def before_test(self, runner): 22 | pass 23 | 24 | def after_test(self, runner): 25 | pass 26 | 27 | def before_train_epoch(self, runner): 28 | self._before_epoch(runner, mode='train') 29 | 30 | def before_val_epoch(self, runner): 31 | self._before_epoch(runner, mode='val') 32 | 33 | def before_test_epoch(self, runner): 34 | self._before_epoch(runner, mode='test') 35 | 36 | def after_train_epoch(self, runner): 37 | self._after_epoch(runner, mode='train') 38 | 39 | def after_val_epoch(self, runner, metrics=None): 40 | self._after_epoch(runner, mode='val') 41 | 42 | def after_test_epoch(self, runner, metrics=None): 43 | self._after_epoch(runner, mode='test') 44 | 45 | def before_train_iter(self, runner, batch_idx, data_batch=None): 46 | self._before_iter( 47 | runner, batch_idx=batch_idx, data_batch=data_batch, mode='train') 48 | 49 | def before_val_iter(self, runner, batch_idx, data_batch=None): 50 | self._before_iter( 51 | runner, batch_idx=batch_idx, data_batch=data_batch, mode='val') 52 | 53 | def before_test_iter(self, runner, batch_idx, data_batch=None): 54 | self._before_iter( 55 | runner, batch_idx=batch_idx, data_batch=data_batch, mode='test') 56 | 57 | def after_train_iter(self, 58 | runner, 59 | batch_idx, 60 | data_batch=None, 61 | outputs=None): 62 | self._after_iter( 63 | runner, 64 | batch_idx=batch_idx, 65 | data_batch=data_batch, 66 | outputs=outputs, 67 | mode='train') 68 | 69 | def after_val_iter(self, 70 | runner, 71 | batch_idx, 72 | data_batch=None, 73 | outputs=None) -> None: 74 | self._after_iter( 75 | runner, 76 | batch_idx=batch_idx, 77 | data_batch=data_batch, 78 | outputs=outputs, 79 | mode='val') 80 | 81 | def after_test_iter(self, 82 | runner, 83 | batch_idx, 84 | data_batch=None, 85 | outputs=None) -> None: 86 | self._after_iter( 87 | runner, 88 | batch_idx=batch_idx, 89 | data_batch=data_batch, 90 | outputs=outputs, 91 | mode='test') 92 | 93 | def _before_epoch(self, runner, mode='train'): 94 | pass 95 | 96 | def _after_epoch(self, runner, mode='train'): 97 | pass 98 | 99 | def _before_iter(self, 100 | runner, 101 | batch_idx, 102 | data_batch=None, 103 | mode='train') -> None: 104 | pass 105 | 106 | def _after_iter(self, 107 | runner, 108 | batch_idx, 109 | data_batch=None, 110 | outputs=None, 111 | mode='train'): 112 | pass 113 | 114 | def every_n_interval(self, idx, n): 115 | return (idx + 1) % n == 0 if n > 0 else False 116 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/samplers/sampling_result.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/task_modules/samplers/sampling_result.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | import jittor as jt 5 | 6 | from ..assigners import AssignResult 7 | 8 | 9 | class SamplingResult: 10 | """Bbox sampling result. 11 | 12 | Args: 13 | pos_inds (Tensor): Indices of positive samples. 14 | neg_inds (Tensor): Indices of negative samples. 15 | priors (Tensor): The priors can be anchors or points, 16 | or the bboxes predicted by the previous stage. 17 | gt_bboxes (Tensor): Ground truth of bboxes. 18 | assign_result (:obj:`AssignResult`): Assigning results. 
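
Every method on `BaseHook` above is a no-op, so a concrete hook only overrides the events it cares about, using `every_n_interval` for the usual 1-based throttling. A minimal sketch of a custom hook; how it gets registered and handed to the runner is assumed, since the hooks registry itself is not shown here:

from jittordet.engine.hooks.base_hook import BaseHook


class LossLoggerHook(BaseHook):

    priority = 60  # relative to the BaseHook default of 50

    def after_train_iter(self, runner, batch_idx, data_batch=None,
                         outputs=None):
        # fires on iterations 50, 100, 150, ...
        if self.every_n_interval(batch_idx, 50):
            print(f'iter {batch_idx + 1}: {outputs}')
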
19 |         gt_flags (Tensor): The ground truth flags.
20 |         avg_factor_with_neg (bool): If True, ``avg_factor`` is equal to
21 |             the number of total priors; otherwise, it is the number of
22 |             positive priors. Defaults to True.
23 | 
24 |     Example:
25 |         >>> # xdoctest: +IGNORE_WANT
26 |         >>> from mmdet.models.task_modules.samplers.sampling_result import *  # NOQA
27 |         >>> self = SamplingResult.random(rng=10)
28 |         >>> print(f'self = {self}')
29 |         self = <SamplingResult({
30 |             'neg_inds': tensor([1,  2,  3,  5,  6,  7,  8,
31 |                                 9, 10, 11, 12, 13]),
32 |             'neg_priors': torch.Size([12, 4]),
33 |             'num_gts': 1,
34 |             'num_neg': 12,
35 |             'num_pos': 1,
36 |             'pos_assigned_gt_inds': tensor([0]),
37 |             'pos_inds': tensor([0]),
38 |             'pos_is_gt': tensor([1], dtype=torch.uint8),
39 |             'pos_priors': torch.Size([1, 4])
40 |         })>
41 | 
42 |     """
43 | 
44 |     def __init__(self,
45 |                  pos_inds: jt.Var,
46 |                  neg_inds: jt.Var,
47 |                  priors: jt.Var,
48 |                  gt_bboxes: jt.Var,
49 |                  assign_result: AssignResult,
50 |                  gt_flags: jt.Var,
51 |                  avg_factor_with_neg: bool = True) -> None:
52 |         self.pos_inds = pos_inds
53 |         self.neg_inds = neg_inds
54 |         self.num_pos = max(pos_inds.numel(), 1)
55 |         self.num_neg = max(neg_inds.numel(), 1)
56 |         self.avg_factor_with_neg = avg_factor_with_neg
57 |         self.avg_factor = self.num_pos + self.num_neg \
58 |             if avg_factor_with_neg else self.num_pos
59 |         self.pos_priors = priors[pos_inds]
60 |         self.neg_priors = priors[neg_inds]
61 |         self.pos_is_gt = gt_flags[pos_inds]
62 | 
63 |         self.num_gts = gt_bboxes.shape[0]
64 |         self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
65 |         self.pos_gt_labels = assign_result.labels[pos_inds]
66 |         if gt_bboxes.numel() == 0:
67 |             # hack for index error case
68 |             assert self.pos_assigned_gt_inds.numel() == 0
69 |             self.pos_gt_bboxes = gt_bboxes.view(-1, 4)
70 |         else:
71 |             if len(gt_bboxes.shape) < 2:
72 |                 gt_bboxes = gt_bboxes.view(-1, 4)
73 |             self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()]
74 | 
75 |     @property
76 |     def priors(self):
77 |         """Tensor: concatenated positive and negative priors"""
78 |         return jt.concat([self.pos_priors, self.neg_priors])
79 | 
80 |     @property
81 |     def info(self):
82 |         """Returns a dictionary of info about the object."""
83 |         return {
84 |             'pos_inds': self.pos_inds,
85 |             'neg_inds': self.neg_inds,
86 |             'pos_priors': self.pos_priors,
87 |             'neg_priors': self.neg_priors,
88 |             'pos_is_gt': self.pos_is_gt,
89 |             'num_gts': self.num_gts,
90 |             'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
91 |             'num_pos': self.num_pos,
92 |             'num_neg': self.num_neg,
93 |             'avg_factor': self.avg_factor
94 |         }
95 | 
-------------------------------------------------------------------------------- /jittordet/engine/config/config.py: --------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from collections.abc import Mapping
4 | from typing import List, Optional, Union
5 | 
6 | from addict import Dict
7 | 
8 | from .dumpers import cfg_dumpers
9 | from .parsers import cfg_parsers
10 | from .readers import cfg_readers
11 | from .utils import delete_node
12 | 
13 | BASE_KEY = '_base_'
14 | COVER_KEY = '_cover_'
15 | RESERVED_KEYS = ['filename']
16 | 
17 | 
18 | class ConfigDict(Dict):
19 |     """Copied from mmengine https://github.com/open-
20 |     mmlab/mmengine/blob/main/mmengine/config/config.py.
21 | 
22 |     A dictionary for config which has the same interface as python's built-in
23 |     dictionary and can be used as a normal dictionary. The Config class would
24 |     transform the nested fields (dictionary-like fields) in config file into
25 |     ``ConfigDict``.
26 |     """
27 | 
28 |     def __missing__(self, name):
29 |         raise KeyError(name)
30 | 
31 |     def __getattr__(self, name):
32 |         try:
33 |             value = super().__getattr__(name)
34 |         except KeyError:
35 |             raise AttributeError(f"'{self.__class__.__name__}' object has no "
36 |                                  f"attribute '{name}'")
37 |         except Exception as e:
38 |             raise e
39 |         else:
40 |             return value
41 | 
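
Because `__missing__` raises instead of creating empty nodes the way plain `addict.Dict` does, misspelled config keys fail fast. A small sketch of that behavior, together with the `_base_` handling in `load_cfg` just below (the config path is one that exists in this repo's `configs/` tree):

from jittordet.engine.config.config import ConfigDict, load_cfg

cfg = ConfigDict(dict(optimizer=dict(type='SGD', lr=0.01)))
print(cfg.optimizer.lr)  # 0.01, attribute-style access into nested dicts
try:
    cfg.optimzer  # misspelled: raises instead of silently returning {}
except AttributeError as err:
    print(err)

# load_cfg resolves `_base_:` entries recursively, deep-merges the child
# config over its bases via merge_cfg, and honors `_cover_: true` by
# replacing a node instead of merging into it.
cfg = load_cfg('configs/faster_rcnn/faster_rcnn_r50_fpn_coco_1x.yml')
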
42 | 
43 | ConfigType = Union[ConfigDict, dict]
44 | OptConfigType = Optional[ConfigType]
45 | 
46 | MultiConfig = Union[ConfigType, List[ConfigType]]
47 | OptMultiConfig = Optional[MultiConfig]
48 | 
49 | 
50 | def load_cfg(filepath):
51 |     """Load a config from a file."""
52 |     assert osp.isfile(filepath), f'{filepath} is not an existing file'
53 |     ext = osp.splitext(filepath)[-1]
54 |     if ext not in cfg_readers:
55 |         raise NotImplementedError(
56 |             f'Cannot parse "{filepath}" with {ext} type yet')
57 |     cfg = ConfigDict(cfg_readers[ext](filepath))
58 |     for key in RESERVED_KEYS:
59 |         if key in cfg:
60 |             raise KeyError(f'"{key}" is a reserved key')
61 | 
62 |     # use parsers to translate some leaves
63 |     cfg['filename'] = filepath
64 |     for parser in cfg_parsers:
65 |         parser(cfg)
66 | 
67 |     if BASE_KEY in cfg:
68 |         base_cfg_paths = cfg.pop(BASE_KEY)
69 |         if isinstance(base_cfg_paths, str):
70 |             base_cfg_paths = [base_cfg_paths]
71 |         all_cfg = ConfigDict()
72 |         root_path = osp.dirname(filepath)
73 |         for base_cfg_path in base_cfg_paths:
74 |             if base_cfg_path.startswith('~'):
75 |                 base_cfg_path = osp.expanduser(base_cfg_path)
76 |             if base_cfg_path.startswith('.'):
77 |                 base_cfg_path = osp.join(root_path, base_cfg_path)
78 |             all_cfg.update(load_cfg(base_cfg_path))
79 |         merge_cfg(all_cfg, cfg)
80 |         cfg = all_cfg
81 | 
82 |     delete_node(cfg, COVER_KEY)
83 |     return cfg
84 | 
85 | 
86 | def merge_cfg(cfg_a, cfg_b):
87 |     """Merge cfg_b into cfg_a."""
88 |     for k, v in cfg_b.items():
89 |         if k in cfg_a and (isinstance(v, Mapping)
90 |                            and isinstance(cfg_a[k], Mapping)):
91 |             if v.pop(COVER_KEY, False):
92 |                 cfg_a[k] = v
93 |             else:
94 |                 merge_cfg(cfg_a[k], v)
95 |         else:
96 |             cfg_a[k] = v
97 | 
98 | 
99 | def dump_cfg(cfg, filepath, allow_exist=False, create_dir=True):
100 |     """Dump a config into a file."""
101 |     ext = osp.splitext(filepath)[-1]
102 |     dir_name = osp.dirname(filepath)
103 |     if ext not in cfg_dumpers:
104 |         raise NotImplementedError(f'Cannot dump cfg to {ext} type file yet')
105 | 
106 |     if osp.exists(filepath) and not allow_exist:
107 |         raise FileExistsError('The target file already exists')
108 | 
109 |     if dir_name and not osp.exists(dir_name) and create_dir:
110 |         os.makedirs(dir_name)
111 | 
112 |     cfg_dumpers[ext](cfg, filepath)
113 | 
-------------------------------------------------------------------------------- /jittordet/models/task_utils/samplers/random_sampler.py: --------------------------------------------------------------------------------
1 | # Modified from OpenMMLab.
2 | # mmdet/models/task_modules/samplers/random_sampler.py
3 | # Copyright (c) OpenMMLab. All rights reserved.
4 | from typing import Union
5 | 
6 | import jittor as jt
7 | from numpy import ndarray
8 | 
9 | from jittordet.engine import TASK_UTILS
10 | from ..assigners import AssignResult
11 | from .base_sampler import BaseSampler
12 | 
13 | 
14 | @TASK_UTILS.register_module()
15 | class RandomSampler(BaseSampler):
16 |     """Random sampler.
17 | 
18 |     Args:
19 |         num (int): Number of samples
20 |         pos_fraction (float): Fraction of positive samples
21 |         neg_pos_ub (int): Upper bound number of negative and
22 |             positive samples. Defaults to -1.
23 | add_gt_as_proposals (bool): Whether to add ground truth 24 | boxes as proposals. Defaults to True. 25 | """ 26 | 27 | def __init__(self, 28 | num: int, 29 | pos_fraction: float, 30 | neg_pos_ub: int = -1, 31 | add_gt_as_proposals: bool = True, 32 | **kwargs): 33 | from jittordet.utils import ensure_rng 34 | super().__init__( 35 | num=num, 36 | pos_fraction=pos_fraction, 37 | neg_pos_ub=neg_pos_ub, 38 | add_gt_as_proposals=add_gt_as_proposals) 39 | self.rng = ensure_rng(kwargs.get('rng', None)) 40 | 41 | def random_choice(self, gallery: Union[jt.Var, ndarray, list], 42 | num: int) -> Union[jt.Var, ndarray]: 43 | """Random select some elements from the gallery. 44 | 45 | If `gallery` is a Tensor, the returned indices will be a Tensor; 46 | If `gallery` is a ndarray or list, the returned indices will be a 47 | ndarray. 48 | 49 | Args: 50 | gallery (Tensor | ndarray | list): indices pool. 51 | num (int): expected sample num. 52 | 53 | Returns: 54 | Tensor or ndarray: sampled indices. 55 | """ 56 | assert len(gallery) >= num 57 | 58 | is_tensor = isinstance(gallery, jt.Var) 59 | if not is_tensor: 60 | gallery = jt.array(gallery, dtype=jt.int64) 61 | # This is a temporary fix. We can revert the following code 62 | # when PyTorch fixes the abnormal return of torch.randperm. 63 | # See: https://github.com/open-mmlab/mmdetection/pull/5014 64 | perm = jt.randperm(gallery.numel())[:num] 65 | rand_inds = gallery[perm] 66 | if not is_tensor: 67 | rand_inds = rand_inds.numpy() 68 | return rand_inds 69 | 70 | def _sample_pos(self, assign_result: AssignResult, num_expected: int, 71 | **kwargs) -> Union[jt.Var, ndarray]: 72 | """Randomly sample some positive samples. 73 | 74 | Args: 75 | assign_result (:obj:`AssignResult`): Bbox assigning results. 76 | num_expected (int): The number of expected positive samples 77 | 78 | Returns: 79 | Tensor or ndarray: sampled indices. 80 | """ 81 | pos_inds = jt.nonzero(assign_result.gt_inds > 0) 82 | if pos_inds.numel() != 0: 83 | pos_inds = pos_inds.squeeze(1) 84 | if pos_inds.numel() <= num_expected: 85 | return pos_inds 86 | else: 87 | return self.random_choice(pos_inds, num_expected) 88 | 89 | def _sample_neg(self, assign_result: AssignResult, num_expected: int, 90 | **kwargs) -> Union[jt.Var, ndarray]: 91 | """Randomly sample some negative samples. 92 | 93 | Args: 94 | assign_result (:obj:`AssignResult`): Bbox assigning results. 95 | num_expected (int): The number of expected positive samples 96 | 97 | Returns: 98 | Tensor or ndarray: sampled indices. 99 | """ 100 | neg_inds = jt.nonzero(assign_result.gt_inds == 0) 101 | if neg_inds.numel() != 0: 102 | neg_inds = neg_inds.squeeze(1) 103 | if len(neg_inds) <= num_expected: 104 | return neg_inds 105 | else: 106 | return self.random_choice(neg_inds, num_expected) 107 | -------------------------------------------------------------------------------- /jittordet/models/roi_heads/roi_extractors/base_roi_extractor.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | from abc import ABCMeta, abstractmethod 5 | from typing import List, Optional, Tuple 6 | 7 | import jittor as jt 8 | import jittor.nn as nn 9 | 10 | from jittordet import ops 11 | from jittordet.engine import ConfigType 12 | 13 | 14 | class BaseRoIExtractor(nn.Module, metaclass=ABCMeta): 15 | """Base class for RoI extractor. 
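
In the usual pipeline an assigner produces an `AssignResult` and the sampler then draws positive and negative indices from it via `randperm`. A minimal sketch built through the `TASK_UTILS` registry, using the same hyperparameters as the Faster R-CNN configs above; `random_choice` is exercised directly since it is self-contained:

import jittor as jt

from jittordet.engine import TASK_UTILS

sampler = TASK_UTILS.build(
    dict(type='RandomSampler', num=256, pos_fraction=0.5,
         neg_pos_ub=-1, add_gt_as_proposals=False))

# Draw 8 distinct indices from a pool of 100, without replacement.
inds = sampler.random_choice(jt.arange(100), 8)
print(inds)  # 8 sampled indices
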
16 | 17 | Args: 18 | roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and 19 | arguments. 20 | out_channels (int): Output channels of RoI layers. 21 | featmap_strides (list[int]): Strides of input feature maps. 22 | init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ 23 | dict], optional): Initialization config dict. Defaults to None. 24 | """ 25 | 26 | def __init__(self, roi_layer: ConfigType, out_channels: int, 27 | featmap_strides: List[int]) -> None: 28 | super().__init__() 29 | self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) 30 | self.out_channels = out_channels 31 | self.featmap_strides = featmap_strides 32 | 33 | @property 34 | def num_inputs(self) -> int: 35 | """int: Number of input feature maps.""" 36 | return len(self.featmap_strides) 37 | 38 | def build_roi_layers(self, layer_cfg: ConfigType, 39 | featmap_strides: List[int]) -> nn.ModuleList: 40 | """Build RoI operator to extract feature from each level feature map. 41 | 42 | Args: 43 | layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and 44 | config RoI layer operation. Options are modules under 45 | ``mmcv/ops`` such as ``RoIAlign``. 46 | featmap_strides (list[int]): The stride of input feature map w.r.t 47 | to the original image size, which would be used to scale RoI 48 | coordinate (original image coordinate system) to feature 49 | coordinate system. 50 | Returns: 51 | :obj:`nn.ModuleList`: The RoI extractor modules for each level 52 | feature map. 53 | """ 54 | 55 | cfg = layer_cfg.copy() 56 | layer_type = cfg.pop('type') 57 | assert hasattr(ops, layer_type) 58 | layer_cls = getattr(ops, layer_type) 59 | roi_layers = nn.ModuleList( 60 | [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides]) 61 | return roi_layers 62 | 63 | def roi_rescale(self, rois: jt.Var, scale_factor: float) -> jt.Var: 64 | """Scale RoI coordinates by scale factor. 65 | 66 | Args: 67 | rois (Tensor): RoI (Region of Interest), shape (n, 5) 68 | scale_factor (float): Scale factor that RoI will be multiplied by. 69 | Returns: 70 | Tensor: Scaled RoI. 71 | """ 72 | 73 | cx = (rois[:, 1] + rois[:, 3]) * 0.5 74 | cy = (rois[:, 2] + rois[:, 4]) * 0.5 75 | w = rois[:, 3] - rois[:, 1] 76 | h = rois[:, 4] - rois[:, 2] 77 | new_w = w * scale_factor 78 | new_h = h * scale_factor 79 | x1 = cx - new_w * 0.5 80 | x2 = cx + new_w * 0.5 81 | y1 = cy - new_h * 0.5 82 | y2 = cy + new_h * 0.5 83 | new_rois = jt.stack((rois[:, 0], x1, y1, x2, y2), dim=-1) 84 | return new_rois 85 | 86 | @abstractmethod 87 | def execute(self, 88 | feats: Tuple[jt.Var], 89 | rois: jt.Var, 90 | roi_scale_factor: Optional[float] = None) -> jt.Var: 91 | """Extractor ROI feats. 92 | 93 | Args: 94 | feats (Tuple[Tensor]): Multi-scale features. 95 | rois (Tensor): RoIs with the shape (n, 5) where the first 96 | column indicates batch id of each RoI. 97 | roi_scale_factor (Optional[float]): RoI scale factor. 98 | Defaults to None. 99 | Returns: 100 | Tensor: RoI feature. 
101 | """ 102 | pass 103 | -------------------------------------------------------------------------------- /jittordet/engine/register/register.py: -------------------------------------------------------------------------------- 1 | # modified from openmmlab mmengine 2 | # https://github.com/open-mmlab/mmengine/blob/main/mmengine/registry/registry.py 3 | 4 | import copy 5 | import inspect 6 | from functools import partial 7 | 8 | 9 | class Register: 10 | 11 | def __init__(self, name): 12 | self._name = name 13 | self._module_dict = dict() 14 | 15 | def __len__(self): 16 | return len(self._module_dict) 17 | 18 | def __contains__(self, key): 19 | return key in self._module_dict 20 | 21 | def __repr__(self): 22 | format_str = self.__class__.__name__ + \ 23 | f'(name={self._name}, ' \ 24 | f'items={self._module_dict})' 25 | return format_str 26 | 27 | @property 28 | def name(self): 29 | return self._name 30 | 31 | @property 32 | def module_dict(self): 33 | return self._module_dict 34 | 35 | def get(self, key): 36 | return self._module_dict.get(key) 37 | 38 | def _register_module(self, module, module_name=None, force=False): 39 | if not inspect.isclass(module) and not inspect.isfunction(module): 40 | raise TypeError('module must be a class or a function, ' 41 | f'but got {type(module)}') 42 | 43 | if module_name is None: 44 | module_name = module.__name__ 45 | if isinstance(module_name, str): 46 | module_name = [module_name] 47 | for name in module_name: 48 | if not force and name in self._module_dict: 49 | existed_module = self.module_dict[name] 50 | raise KeyError(f'{name} is already registered in {self.name} ' 51 | f'at {existed_module.__module__}') 52 | self._module_dict[name] = module 53 | 54 | def register_module(self, name=None, force=False, module=None): 55 | if not isinstance(force, bool): 56 | raise TypeError(f'force must be a boolean, but got {type(force)}') 57 | 58 | # raise the error ahead of time 59 | if not (name is None or isinstance(name, str) 60 | or isinstance(name, (list, tuple))): 61 | raise TypeError( 62 | 'name must be None, an instance of str, or a sequence of str, ' 63 | f'but got {type(name)}') 64 | 65 | # use it as a normal method: x.register_module(module=SomeClass) 66 | if module is not None: 67 | self._register_module(module=module, module_name=name, force=force) 68 | return module 69 | 70 | # use it as a decorator: @x.register_module() 71 | def _register(module): 72 | self._register_module(module=module, module_name=name, force=force) 73 | return module 74 | 75 | return _register 76 | 77 | def build(self, cfg, **default_args): 78 | if not isinstance(cfg, dict): 79 | raise TypeError(f'cfg should be a dict, but got {type(cfg)}') 80 | 81 | if not (isinstance(default_args, dict) or default_args is None): 82 | raise TypeError( 83 | f'default_args should be a dict, but got {type(default_args)}') 84 | 85 | if 'type' not in cfg: 86 | if default_args is None or 'type' not in default_args: 87 | raise KeyError( 88 | '`cfg` or `default_args` must contain the key "type", ' 89 | f'but got {cfg}\n{default_args}') 90 | 91 | args = copy.deepcopy(cfg) 92 | if default_args is not None: 93 | for name, value in default_args.items(): 94 | args.setdefault(name, value) 95 | 96 | obj_type = args.pop('type') 97 | if isinstance(obj_type, str): 98 | obj_cls = self.get(obj_type) 99 | if obj_cls is None: 100 | raise KeyError( 101 | f'{obj_type} is not in the {self.name} registry.') 102 | elif inspect.isclass(obj_type) or inspect.isfunction(obj_type): 103 | obj_cls = obj_type 104 | else: 105 | 
raise TypeError(
106 |                 f'type must be a str or valid type, but got {type(obj_type)}')
107 | 
108 |         if inspect.isclass(obj_cls):
109 |             if hasattr(obj_cls, 'from_cfg'):
110 |                 # `args` is the config dict; pass it whole rather than
111 |                 # unpacking its keys positionally
112 |                 return obj_cls.from_cfg(args)
113 |             else:
114 |                 return obj_cls(**args)
115 |         elif inspect.isfunction(obj_cls):
116 |             return partial(obj_cls, **args)
117 | 
-------------------------------------------------------------------------------- /jittordet/ops/bbox_geometry.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import shapely.geometry as shgeo
3 | from ..utils.bbox_transforms import *  # noqa: F401,F403
4 | 
5 | 
6 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
7 |     assert mode in ['iou', 'iof']
8 |     assert get_bbox_type(bboxes1) != 'notype'
9 |     assert get_bbox_type(bboxes2) != 'notype'
10 |     rows = bboxes1.shape[0]
11 |     cols = bboxes2.shape[0]
12 |     if is_aligned:
13 |         assert rows == cols
14 | 
15 |     if rows * cols == 0:
16 |         return np.zeros((rows, 1), dtype=np.float32) \
17 |             if is_aligned else np.zeros((rows, cols), dtype=np.float32)
18 | 
19 |     hbboxes1 = bbox2type(bboxes1, 'hbb')
20 |     hbboxes2 = bbox2type(bboxes2, 'hbb')
21 |     if not is_aligned:
22 |         hbboxes1 = hbboxes1[:, None, :]
23 |     lt = np.maximum(hbboxes1[..., :2], hbboxes2[..., :2])
24 |     rb = np.minimum(hbboxes1[..., 2:], hbboxes2[..., 2:])
25 |     wh = np.clip(rb - lt, 0, np.inf)
26 |     h_overlaps = wh[..., 0] * wh[..., 1]
27 | 
28 |     if get_bbox_type(bboxes1) == 'hbb' and get_bbox_type(bboxes2) == 'hbb':
29 |         overlaps = h_overlaps
30 |         areas1 = (hbboxes1[..., 2] - hbboxes1[..., 0]) * (
31 |             hbboxes1[..., 3] - hbboxes1[..., 1])
32 | 
33 |         if mode == 'iou':
34 |             areas2 = (hbboxes2[..., 2] - hbboxes2[..., 0]) * (
35 |                 hbboxes2[..., 3] - hbboxes2[..., 1])
36 |             unions = areas1 + areas2 - overlaps
37 |         else:
38 |             unions = areas1
39 | 
40 |     else:
41 |         polys1 = bbox2type(bboxes1, 'poly')
42 |         polys2 = bbox2type(bboxes2, 'poly')
43 |         sg_polys1 = [shgeo.Polygon(p) for p in polys1.reshape(rows, -1, 2)]
44 |         sg_polys2 = [shgeo.Polygon(p) for p in polys2.reshape(cols, -1, 2)]
45 | 
46 |         overlaps = np.zeros(h_overlaps.shape)
47 |         for p in zip(*np.nonzero(h_overlaps)):
48 |             overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
49 | 
50 |         if mode == 'iou':
51 |             unions = np.zeros(h_overlaps.shape, dtype=np.float32)
52 |             for p in zip(*np.nonzero(h_overlaps)):
53 |                 unions[p] = sg_polys1[p[0]].union(sg_polys2[p[-1]]).area
54 |         else:
55 |             unions = np.array([p.area for p in sg_polys1], dtype=np.float32)
56 |             if not is_aligned:
57 |                 unions = unions[..., None]
58 | 
59 |     unions = np.clip(unions, eps, np.inf)
60 |     outputs = overlaps / unions
61 |     if outputs.ndim == 1:
62 |         outputs = outputs[..., None]
63 |     return outputs
64 | 
65 | 
66 | def bbox_areas(bboxes):
67 |     bbox_type = get_bbox_type(bboxes)
68 |     assert bbox_type != 'notype'
69 | 
70 |     if bbox_type == 'hbb':
71 |         areas = (bboxes[..., 2] - bboxes[..., 0]) * (
72 |             bboxes[..., 3] - bboxes[..., 1])
73 | 
74 |     if bbox_type == 'obb':
75 |         areas = bboxes[..., 2] * bboxes[..., 3]
76 | 
77 |     if bbox_type == 'poly':
78 |         areas = np.zeros(bboxes.shape[:-1], dtype=np.float32)
79 |         bboxes = bboxes.reshape(*bboxes.shape[:-1], 4, 2)
80 |         for i in range(4):
81 |             areas += 0.5 * (
82 |                 bboxes[..., i, 0] * bboxes[..., (i + 1) % 4, 1] -
83 |                 bboxes[..., (i + 1) % 4, 0] * bboxes[..., i, 1])
84 |         areas = np.abs(areas)
85 |     return areas
86 | 
87 | 
88 | def bbox_nms(bboxes, scores, iou_thr=0.5, score_thr=0.01):
89 |     assert get_bbox_type(bboxes) != 'notype'
90 |     order = 
scores.argsort()[::-1] 91 | order = order[scores[order] > score_thr] 92 | keep = [] 93 | 94 | while order.size > 0: 95 | i = order[0] 96 | keep.append(i) 97 | 98 | keep_bbox = bboxes[[i]] 99 | other_bboxes = bboxes[order[1:]] 100 | ious = bbox_overlaps(keep_bbox, other_bboxes) 101 | 102 | idx = np.where(ious <= iou_thr)[1] 103 | order = order[idx + 1] 104 | 105 | return np.array(keep, dtype=np.int64) 106 | 107 | 108 | def bbox_area_nms(bboxes, iou_thr=0.5): 109 | assert get_bbox_type(bboxes) != 'notype' 110 | areas = bbox_areas(bboxes) 111 | order = areas.argsort()[::-1] 112 | keep = [] 113 | 114 | while order.size > 0: 115 | i = order[0] 116 | keep.append(i) 117 | 118 | keep_bbox = bboxes[[i]] 119 | other_bboxes = bboxes[order[1:]] 120 | ious = bbox_overlaps(keep_bbox, other_bboxes) 121 | 122 | idx = np.where(ious <= iou_thr)[1] 123 | order = order[idx + 1] 124 | 125 | return np.array(keep, dtype=np.int64) 126 | -------------------------------------------------------------------------------- /jittordet/models/roi_heads/base_roi_head.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab mmdet/models/roi_heads/base_roi_head.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | from abc import ABCMeta, abstractmethod 4 | from typing import Tuple 5 | 6 | import jittor as jt 7 | import jittor.nn as nn 8 | 9 | from jittordet.engine import MODELS, OptConfigType, OptMultiConfig 10 | from jittordet.structures import InstanceList, SampleList 11 | 12 | 13 | class BaseRoIHead(nn.Module, metaclass=ABCMeta): 14 | """Base class for RoIHeads.""" 15 | 16 | def __init__(self, 17 | bbox_roi_extractor: OptMultiConfig = None, 18 | bbox_head: OptConfigType = None, 19 | shared_head: OptConfigType = None, 20 | train_cfg: OptConfigType = None, 21 | test_cfg: OptConfigType = None) -> None: 22 | super().__init__() 23 | self.train_cfg = train_cfg 24 | self.test_cfg = test_cfg 25 | if shared_head is not None: 26 | self.shared_head = MODELS.build(shared_head) 27 | 28 | if bbox_head is not None: 29 | self.init_bbox_head(bbox_roi_extractor, bbox_head) 30 | 31 | self.init_assigner_sampler() 32 | 33 | @property 34 | def with_bbox(self) -> bool: 35 | """bool: whether the RoI head contains a `bbox_head`""" 36 | return hasattr(self, 'bbox_head') and self.bbox_head is not None 37 | 38 | @property 39 | def with_shared_head(self) -> bool: 40 | """bool: whether the RoI head contains a `shared_head`""" 41 | return hasattr(self, 'shared_head') and self.shared_head is not None 42 | 43 | @abstractmethod 44 | def init_bbox_head(self, *args, **kwargs): 45 | """Initialize ``bbox_head``""" 46 | pass 47 | 48 | @abstractmethod 49 | def init_assigner_sampler(self, *args, **kwargs): 50 | """Initialize assigner and sampler.""" 51 | pass 52 | 53 | @abstractmethod 54 | def loss(self, x: Tuple[jt.Var], rpn_results_list: InstanceList, 55 | batch_data_samples: SampleList): 56 | """Perform forward propagation and loss calculation of the roi head on 57 | the features of the upstream network.""" 58 | 59 | def predict(self, 60 | x: Tuple[jt.Var], 61 | rpn_results_list: InstanceList, 62 | batch_data_samples: SampleList, 63 | rescale: bool = False) -> InstanceList: 64 | """Perform forward propagation of the roi head and predict detection 65 | results on the features of the upstream network. 66 | Args: 67 | x (tuple[Tensor]): Features from upstream network. Each 68 | has shape (N, C, H, W). 69 | rpn_results_list (list[:obj:`InstanceData`]): list of region 70 | proposals. 
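
These helpers are pure numpy/shapely, so they can be checked in isolation. A quick sketch with axis-aligned boxes, assuming `get_bbox_type` classifies 4-column arrays as 'hbb':

import numpy as np

from jittordet.ops.bbox_geometry import bbox_nms, bbox_overlaps

boxes = np.array([[0., 0., 10., 10.],
                  [1., 1., 11., 11.],
                  [20., 20., 30., 30.]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7], dtype=np.float32)

ious = bbox_overlaps(boxes, boxes)  # (3, 3) pairwise IoU matrix
keep = bbox_nms(boxes, scores, iou_thr=0.5)
# Box 1 overlaps box 0 with IoU 81 / 119, roughly 0.68 > 0.5, so it is
# suppressed: keep == [0, 2].
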
71 | batch_data_samples (List[:obj:`DetDataSample`]): The Data 72 | Samples. It usually includes information such as 73 | `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. 74 | rescale (bool): Whether to rescale the results to 75 | the original image. Defaults to True. 76 | Returns: 77 | list[obj:`InstanceData`]: Detection results of each image. 78 | Each item usually contains following keys. 79 | - scores (Tensor): Classification scores, has a shape 80 | (num_instance, ) 81 | - labels (Tensor): Labels of bboxes, has a shape 82 | (num_instances, ). 83 | - bboxes (Tensor): Has a shape (num_instances, 4), 84 | the last dimension 4 arrange as (x1, y1, x2, y2). 85 | - masks (Tensor): Has a shape (num_instances, H, W). 86 | """ 87 | assert self.with_bbox, 'Bbox head must be implemented.' 88 | batch_img_metas = [ 89 | data_samples.metainfo for data_samples in batch_data_samples 90 | ] 91 | 92 | # TODO: nms_op in mmcv need be enhanced, the bbox result may get 93 | # difference when not rescale in bbox_head 94 | 95 | # If it has the mask branch, the bbox branch does not need 96 | # to be scaled to the original image scale, because the mask 97 | # branch will scale both bbox and mask at the same time. 98 | bbox_rescale = rescale 99 | results_list = self.predict_bbox( 100 | x, 101 | batch_img_metas, 102 | rpn_results_list, 103 | rcnn_test_cfg=self.test_cfg, 104 | rescale=bbox_rescale) 105 | return results_list 106 | -------------------------------------------------------------------------------- /jittordet/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # modified from mmdetection.datasets.xml_style 2 | import os.path as osp 3 | import xml.etree.ElementTree as ET 4 | 5 | from PIL import Image 6 | 7 | from ..engine import DATASETS 8 | from .base import BaseDetDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class VocDataset(BaseDetDataset): 13 | 14 | METAINFO = { 15 | 'classes': 16 | ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 17 | 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 18 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'), 19 | # PALETTE is a list of color tuples, which is used for visualization. 20 | 'palette': [(106, 0, 228), (119, 11, 32), (165, 42, 42), (0, 0, 192), 21 | (197, 226, 255), (0, 60, 100), (0, 0, 142), (255, 77, 255), 22 | (153, 69, 1), (120, 166, 157), (0, 182, 199), 23 | (0, 226, 252), (182, 182, 255), (0, 0, 230), (220, 20, 60), 24 | (163, 255, 0), (0, 82, 0), (3, 95, 161), (0, 80, 100), 25 | (183, 130, 88)] 26 | } 27 | 28 | def load_data_list(self): 29 | assert self.metainfo.get('classes', None) is not None, \ 30 | 'CLASSES in `VocDataset` can not be None.' 
31 | self.cat2label = { 32 | cat: i 33 | for i, cat in enumerate(self.metainfo['classes']) 34 | } 35 | 36 | data_list = [] 37 | for img_id in open(self.ann_file, 'r'): 38 | img_id = img_id.strip() 39 | img_path = osp.join(self.img_path, f'{img_id}.jpg') 40 | xml_path = osp.join(self.xml_path, f'{img_id}.xml') 41 | 42 | raw_img_info = {} 43 | raw_img_info['img_id'] = img_id 44 | raw_img_info['img_path'] = img_path 45 | raw_img_info['xml_path'] = xml_path 46 | 47 | parsed_data_info = self.parse_data_info(raw_img_info) 48 | data_list.append(parsed_data_info) 49 | return data_list 50 | 51 | def parse_data_info(self, img_info): 52 | data_info = {} 53 | data_info['img_id'] = img_info['img_id'] 54 | data_info['img_path'] = img_info['img_path'] 55 | data_info['xml_path'] = img_info['xml_path'] 56 | 57 | # deal with xml file 58 | raw_ann_info = ET.parse(data_info['xml_path']) 59 | root = raw_ann_info.getroot() 60 | size = root.find('size') 61 | if size is not None: 62 | width = int(size.find('width').text) 63 | height = int(size.find('height').text) 64 | else: 65 | img = Image.open(img_info['img_path']) 66 | height, width = img.height, img.width 67 | del img 68 | 69 | data_info['height'] = height 70 | data_info['width'] = width 71 | 72 | instances = [] 73 | for obj in raw_ann_info.findall('object'): 74 | instance = {} 75 | name = obj.find('name').text 76 | if name not in self._metainfo['classes']: 77 | continue 78 | difficult = obj.find('difficult') 79 | difficult = 0 if difficult is None else int(difficult.text) 80 | bnd_box = obj.find('bndbox') 81 | bbox = [ 82 | int(float(bnd_box.find('xmin').text)) - 1, 83 | int(float(bnd_box.find('ymin').text)) - 1, 84 | int(float(bnd_box.find('xmax').text)) - 1, 85 | int(float(bnd_box.find('ymax').text)) - 1 86 | ] 87 | instance['ignore_flag'] = difficult 88 | instance['bbox'] = bbox 89 | instance['bbox_label'] = self.cat2label[name] 90 | instances.append(instance) 91 | data_info['instances'] = instances 92 | return data_info 93 | 94 | def filter_data(self): 95 | if self.test_mode: 96 | return self.data_list 97 | 98 | filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \ 99 | if self.filter_cfg is not None else False 100 | min_size = self.filter_cfg.get('min_size', 0) \ 101 | if self.filter_cfg is not None else 0 102 | 103 | valid_data_infos = [] 104 | for i, data_info in enumerate(self.data_list): 105 | width = data_info['width'] 106 | height = data_info['height'] 107 | if filter_empty_gt and len(data_info['instances']) == 0: 108 | continue 109 | if min(width, height) >= min_size: 110 | valid_data_infos.append(data_info) 111 | 112 | return valid_data_infos 113 | -------------------------------------------------------------------------------- /jittordet/models/roi_heads/roi_extractors/single_level_roi_extractor.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | from typing import List, Optional, Tuple 5 | 6 | import jittor as jt 7 | 8 | from jittordet.engine import MODELS, ConfigType 9 | from .base_roi_extractor import BaseRoIExtractor 10 | 11 | 12 | @MODELS.register_module() 13 | class SingleRoIExtractor(BaseRoIExtractor): 14 | """Extract RoI features from a single level feature map. 15 | 16 | If there are multiple input feature levels, each RoI is mapped to a level 17 | according to its scale. The mapping rule is proposed in 18 | `FPN `_. 
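
VOC annotations store 1-based, inclusive pixel corners, which is why `parse_data_info` above subtracts 1 from every coordinate. A self-contained sketch of the same XML handling on a hypothetical annotation:

import xml.etree.ElementTree as ET

xml_text = '''<annotation>
  <size><width>500</width><height>375</height></size>
  <object>
    <name>dog</name>
    <difficult>0</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin>
            <xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
</annotation>'''

root = ET.fromstring(xml_text)
for obj in root.findall('object'):
    bnd = obj.find('bndbox')
    # shift the 1-based VOC corners to 0-based image coordinates
    bbox = [int(float(bnd.find(k).text)) - 1
            for k in ('xmin', 'ymin', 'xmax', 'ymax')]
    print(obj.find('name').text, bbox)  # dog [47, 239, 194, 370]
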
19 | Args: 20 | roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and 21 | arguments. 22 | out_channels (int): Output channels of RoI layers. 23 | featmap_strides (List[int]): Strides of input feature maps. 24 | finest_scale (int): Scale threshold of mapping to level 0. 25 | Defaults to 56. 26 | init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \ 27 | dict], optional): Initialization config dict. Defaults to None. 28 | """ 29 | 30 | def __init__(self, 31 | roi_layer: ConfigType, 32 | out_channels: int, 33 | featmap_strides: List[int], 34 | finest_scale: int = 56) -> None: 35 | super().__init__( 36 | roi_layer=roi_layer, 37 | out_channels=out_channels, 38 | featmap_strides=featmap_strides) 39 | self.finest_scale = finest_scale 40 | 41 | def map_roi_levels(self, rois: jt.Var, num_levels: int) -> jt.Var: 42 | """Map rois to corresponding feature levels by scales. 43 | 44 | - scale < finest_scale * 2: level 0 45 | - finest_scale * 2 <= scale < finest_scale * 4: level 1 46 | - finest_scale * 4 <= scale < finest_scale * 8: level 2 47 | - scale >= finest_scale * 8: level 3 48 | Args: 49 | rois (Tensor): Input RoIs, shape (k, 5). 50 | num_levels (int): Total level number. 51 | Returns: 52 | Tensor: Level index (0-based) of each RoI, shape (k, ) 53 | """ 54 | scale = jt.sqrt((rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2])) 55 | target_lvls = jt.floor(jt.log2(scale / self.finest_scale + 1e-6)) 56 | target_lvls = target_lvls.clamp( 57 | min_v=0, max_v=num_levels - 1).astype(jt.int64) 58 | return target_lvls 59 | 60 | def execute(self, 61 | feats: Tuple[jt.Var], 62 | rois: jt.Var, 63 | roi_scale_factor: Optional[float] = None): 64 | """Extractor ROI feats. 65 | 66 | Args: 67 | feats (Tuple[Tensor]): Multi-scale features. 68 | rois (Tensor): RoIs with the shape (n, 5) where the first 69 | column indicates batch id of each RoI. 70 | roi_scale_factor (Optional[float]): RoI scale factor. 71 | Defaults to None. 72 | Returns: 73 | Tensor: RoI feature. 74 | """ 75 | # convert fp32 to fp16 when amp is on 76 | rois = rois.type_as(feats[0]) 77 | out_size = self.roi_layers[0].output_size 78 | num_levels = len(feats) 79 | roi_feats = jt.zeros( 80 | rois.size(0), self.out_channels, *out_size, dtype=feats[0].dtype) 81 | 82 | if num_levels == 1: 83 | if len(rois) == 0: 84 | return roi_feats 85 | return self.roi_layers[0](feats[0], rois) 86 | 87 | target_lvls = self.map_roi_levels(rois, num_levels) 88 | 89 | if roi_scale_factor is not None: 90 | rois = self.roi_rescale(rois, roi_scale_factor) 91 | for i in range(num_levels): 92 | mask = target_lvls == i 93 | inds = mask.nonzero().squeeze(1) 94 | if inds.numel() > 0: 95 | rois_ = rois[inds] 96 | roi_feats_t = self.roi_layers[i](feats[i], rois_) 97 | roi_feats[inds] = roi_feats_t 98 | else: 99 | # Sometimes some pyramid levels will not be used for RoI 100 | # feature extraction and this will cause an incomplete 101 | # computation graph in one GPU, which is different from those 102 | # in other GPUs and will cause a hanging error. 103 | # Therefore, we add it to ensure each feature pyramid is 104 | # included in the computation graph to avoid runtime bugs. 105 | roi_feats += sum( 106 | x.view(-1)[0] 107 | for x in self.parameters()) * 0. + feats[i].sum() * 0. 
108 | return roi_feats 109 | -------------------------------------------------------------------------------- /jittordet/models/dense_heads/retina_head.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab mmdet/models/dense_heads/retina_head.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import jittor.nn as nn 4 | 5 | from jittordet.engine import MODELS 6 | from ..layers import ConvModule 7 | from ..utils import bias_init_with_prob, normal_init 8 | from .anchor_head import AnchorHead 9 | 10 | 11 | @MODELS.register_module() 12 | class RetinaHead(AnchorHead): 13 | r"""An anchor-based head used in `RetinaNet 14 | <https://arxiv.org/abs/1708.02002>`_. 15 | 16 | The head contains two subnetworks. The first classifies anchor boxes and 17 | the second regresses deltas for the anchors. 18 | 19 | Example: 20 | >>> import jittor as jt 21 | >>> self = RetinaHead(11, 7) 22 | >>> x = jt.rand(1, 7, 32, 32) 23 | >>> cls_score, bbox_pred = self.execute_single(x) 24 | >>> # Each anchor predicts a score for each class except background 25 | >>> cls_per_anchor = cls_score.shape[1] / self.num_base_priors 26 | >>> box_per_anchor = bbox_pred.shape[1] / self.num_base_priors 27 | >>> assert cls_per_anchor == self.num_classes 28 | >>> assert box_per_anchor == 4 29 | """ 30 | 31 | def __init__(self, 32 | num_classes, 33 | in_channels, 34 | stacked_convs=4, 35 | conv_cfg=None, 36 | norm_cfg=None, 37 | anchor_generator=dict( 38 | type='AnchorGenerator', 39 | octave_base_scale=4, 40 | scales_per_octave=3, 41 | ratios=[0.5, 1.0, 2.0], 42 | strides=[8, 16, 32, 64, 128]), 43 | **kwargs): 44 | assert stacked_convs >= 0, \ 45 | '`stacked_convs` must be a non-negative integer, ' \ 46 | f'but got {stacked_convs} instead.' 47 | self.stacked_convs = stacked_convs 48 | self.conv_cfg = conv_cfg 49 | self.norm_cfg = norm_cfg 50 | super(RetinaHead, self).__init__( 51 | num_classes, 52 | in_channels, 53 | anchor_generator=anchor_generator, 54 | **kwargs) 55 | 56 | def _init_layers(self): 57 | """Initialize layers of the head.""" 58 | self.relu = nn.ReLU(inplace=True) 59 | self.cls_convs = nn.ModuleList() 60 | self.reg_convs = nn.ModuleList() 61 | in_channels = self.in_channels 62 | for i in range(self.stacked_convs): 63 | self.cls_convs.append( 64 | ConvModule( 65 | in_channels, 66 | self.feat_channels, 67 | 3, 68 | stride=1, 69 | padding=1, 70 | conv_cfg=self.conv_cfg, 71 | norm_cfg=self.norm_cfg)) 72 | self.reg_convs.append( 73 | ConvModule( 74 | in_channels, 75 | self.feat_channels, 76 | 3, 77 | stride=1, 78 | padding=1, 79 | conv_cfg=self.conv_cfg, 80 | norm_cfg=self.norm_cfg)) 81 | in_channels = self.feat_channels 82 | self.retina_cls = nn.Conv2d( 83 | in_channels, 84 | self.num_base_priors * self.cls_out_channels, 85 | 3, 86 | padding=1) 87 | self.retina_reg = nn.Conv2d( 88 | in_channels, self.num_base_priors * 4, 3, padding=1) 89 | 90 | def init_weights(self): 91 | for m in self.cls_convs.modules(): 92 | if isinstance(m, nn.Conv): 93 | normal_init(m, std=0.01) 94 | for m in self.reg_convs.modules(): 95 | if isinstance(m, nn.Conv): 96 | normal_init(m, std=0.01) 97 | 98 | normal_init(self.retina_reg, std=0.01) 99 | bias = bias_init_with_prob(0.01) 100 | normal_init(self.retina_cls, std=0.01, bias=bias) 101 | 102 | def execute_single(self, x): 103 | """Forward feature of a single scale level. 104 | 105 | Args: 106 | x (Tensor): Features of a single scale level.
107 | 108 | Returns: 109 | tuple: 110 | cls_score (Tensor): Cls scores for a single scale level 111 | the channels number is num_anchors * num_classes. 112 | bbox_pred (Tensor): Box energies / deltas for a single scale 113 | level, the channels number is num_anchors * 4. 114 | """ 115 | cls_feat = x 116 | reg_feat = x 117 | for cls_conv in self.cls_convs: 118 | cls_feat = cls_conv(cls_feat) 119 | for reg_conv in self.reg_convs: 120 | reg_feat = reg_conv(reg_feat) 121 | cls_score = self.retina_cls(cls_feat) 122 | bbox_pred = self.retina_reg(reg_feat) 123 | return cls_score, bbox_pred 124 | -------------------------------------------------------------------------------- /jittordet/models/frameworks/single_stage.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. mmdet/models/detectors/single_stage.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | from typing import List, Tuple, Union 4 | 5 | import jittor as jt 6 | 7 | from jittordet.engine import MODELS, ConfigType, OptConfigType 8 | from jittordet.structures import OptSampleList, SampleList 9 | from .base_framework import BaseFramework 10 | 11 | 12 | @MODELS.register_module() 13 | class SingleStageFramework(BaseFramework): 14 | """Base class for single-stage detectors. 15 | 16 | Single-stage detectors directly and densely predict bounding boxes on the 17 | output features of the backbone+neck. 18 | """ 19 | 20 | def __init__(self, 21 | backbone: ConfigType, 22 | neck: OptConfigType = None, 23 | bbox_head: OptConfigType = None, 24 | train_cfg: OptConfigType = None, 25 | test_cfg: OptConfigType = None, 26 | preprocessor: OptConfigType = None) -> None: 27 | super().__init__(preprocessor=preprocessor) 28 | self.backbone = MODELS.build(backbone) 29 | if neck is not None: 30 | self.neck = MODELS.build(neck) 31 | bbox_head.update(train_cfg=train_cfg) 32 | bbox_head.update(test_cfg=test_cfg) 33 | self.bbox_head = MODELS.build(bbox_head) 34 | self.train_cfg = train_cfg 35 | self.test_cfg = test_cfg 36 | 37 | def loss(self, batch_inputs: jt.Var, 38 | batch_data_samples: SampleList) -> Union[dict, list]: 39 | """Calculate losses from a batch of inputs and data samples. 40 | 41 | Args: 42 | batch_inputs (Tensor): Input images of shape (N, C, H, W). 43 | These should usually be mean centered and std scaled. 44 | batch_data_samples (list[:obj:`DetDataSample`]): The batch 45 | data samples. It usually includes information such 46 | as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. 47 | 48 | Returns: 49 | dict: A dictionary of loss components. 50 | """ 51 | x = self.extract_feat(batch_inputs) 52 | losses = self.bbox_head.loss(x, batch_data_samples) 53 | return losses 54 | 55 | def predict(self, 56 | batch_inputs: jt.Var, 57 | batch_data_samples: SampleList, 58 | rescale: bool = True) -> SampleList: 59 | """Predict results from a batch of inputs and data samples with post- 60 | processing. 61 | 62 | Args: 63 | batch_inputs (Tensor): Inputs with shape (N, C, H, W). 64 | batch_data_samples (List[:obj:`DetDataSample`]): The Data 65 | Samples. It usually includes information such as 66 | `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. 67 | rescale (bool): Whether to rescale the results. 68 | Defaults to True. 69 | 70 | Returns: 71 | list[:obj:`DetDataSample`]: Detection results of the 72 | input images. Each DetDataSample usually contain 73 | 'pred_instances'. And the ``pred_instances`` usually 74 | contains following keys. 
75 | 76 | - scores (Tensor): Classification scores, has a shape 77 | (num_instance, ) 78 | - labels (Tensor): Labels of bboxes, has a shape 79 | (num_instances, ). 80 | - bboxes (Tensor): Has a shape (num_instances, 4), 81 | the last dimension 4 arrange as (x1, y1, x2, y2). 82 | """ 83 | x = self.extract_feat(batch_inputs) 84 | results_list = self.bbox_head.predict( 85 | x, batch_data_samples, rescale=rescale) 86 | batch_data_samples = self.add_pred_to_datasample( 87 | batch_data_samples, results_list) 88 | return batch_data_samples 89 | 90 | def _execute( 91 | self, 92 | batch_inputs: jt.Var, 93 | batch_data_samples: OptSampleList = None) -> Tuple[List[jt.Var]]: 94 | """Network forward process. Usually includes backbone, neck and head 95 | forward without any post-processing. 96 | 97 | Args: 98 | batch_inputs (Tensor): Inputs with shape (N, C, H, W). 99 | 100 | Returns: 101 | tuple[list]: A tuple of features from ``bbox_head`` forward. 102 | """ 103 | x = self.extract_feat(batch_inputs) 104 | results = self.bbox_head.execute(x) 105 | return results 106 | 107 | def extract_feat(self, batch_inputs: jt.Var) -> Tuple[jt.Var]: 108 | """Extract features. 109 | 110 | Args: 111 | batch_inputs (Tensor): Image tensor with shape (N, C, H ,W). 112 | 113 | Returns: 114 | tuple[Tensor]: Multi-level features that may have 115 | different resolutions. 116 | """ 117 | x = self.backbone(batch_inputs) 118 | if self.with_neck: 119 | x = self.neck(x) 120 | return x 121 | -------------------------------------------------------------------------------- /jittordet/models/losses/smooth_l1_loss.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. mmdet/models/losses/smooth_l1_loss.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import jittor as jt 4 | import jittor.nn as nn 5 | 6 | from jittordet.engine import MODELS 7 | from .utils import weighted_loss 8 | 9 | 10 | @weighted_loss 11 | def smooth_l1_loss(pred, target, beta=1.0): 12 | """Smooth L1 loss. 13 | 14 | Args: 15 | pred (torch.Tensor): The prediction. 16 | target (torch.Tensor): The learning target of the prediction. 17 | beta (float, optional): The threshold in the piecewise function. 18 | Defaults to 1.0. 19 | 20 | Returns: 21 | torch.Tensor: Calculated loss 22 | """ 23 | assert beta > 0 24 | if target.numel() == 0: 25 | return pred.sum() * 0 26 | 27 | assert pred.size() == target.size() 28 | diff = jt.abs(pred - target) 29 | condition = (diff < beta).astype(diff.dtype) 30 | loss = (0.5 * diff * diff / beta) * condition + \ 31 | (diff - 0.5 * beta) * (1 - condition) 32 | return loss 33 | 34 | 35 | @weighted_loss 36 | def l1_loss(pred, target): 37 | """L1 loss. 38 | 39 | Args: 40 | pred (torch.Tensor): The prediction. 41 | target (torch.Tensor): The learning target of the prediction. 42 | 43 | Returns: 44 | torch.Tensor: Calculated loss 45 | """ 46 | if target.numel() == 0: 47 | return pred.sum() * 0 48 | 49 | assert pred.size() == target.size() 50 | loss = jt.abs(pred - target) 51 | return loss 52 | 53 | 54 | @MODELS.register_module() 55 | class SmoothL1Loss(nn.Module): 56 | """Smooth L1 loss. 57 | 58 | Args: 59 | beta (float, optional): The threshold in the piecewise function. 60 | Defaults to 1.0. 61 | reduction (str, optional): The method to reduce the loss. 62 | Options are "none", "mean" and "sum". Defaults to "mean". 63 | loss_weight (float, optional): The weight of loss. 
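    Example (a minimal doctest-style sketch; values are illustrative and
    the printed form depends on Jittor's output formatting):
        >>> import jittor as jt
        >>> loss_fn = SmoothL1Loss(beta=1.0, reduction='mean')
        >>> pred = jt.array([[0., 0., 0., 0.]])
        >>> target = jt.array([[0.5, 0.5, 2.0, 2.0]])
        >>> # per element: 0.5 * 0.5**2 = 0.125 (twice), 2.0 - 0.5 = 1.5 (twice)
        >>> loss_fn(pred, target).item()  # (0.125 + 0.125 + 1.5 + 1.5) / 4
        0.8125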
64 | """ 65 | 66 | def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0): 67 | super(SmoothL1Loss, self).__init__() 68 | self.beta = beta 69 | self.reduction = reduction 70 | self.loss_weight = loss_weight 71 | 72 | def execute(self, 73 | pred, 74 | target, 75 | weight=None, 76 | avg_factor=None, 77 | reduction_override=None, 78 | **kwargs): 79 | """Forward function. 80 | 81 | Args: 82 | pred (torch.Tensor): The prediction. 83 | target (torch.Tensor): The learning target of the prediction. 84 | weight (torch.Tensor, optional): The weight of loss for each 85 | prediction. Defaults to None. 86 | avg_factor (int, optional): Average factor that is used to average 87 | the loss. Defaults to None. 88 | reduction_override (str, optional): The reduction method used to 89 | override the original reduction method of the loss. 90 | Defaults to None. 91 | """ 92 | assert reduction_override in (None, 'none', 'mean', 'sum') 93 | reduction = ( 94 | reduction_override if reduction_override else self.reduction) 95 | loss_bbox = self.loss_weight * smooth_l1_loss( 96 | pred, 97 | target, 98 | weight, 99 | beta=self.beta, 100 | reduction=reduction, 101 | avg_factor=avg_factor, 102 | **kwargs) 103 | return loss_bbox 104 | 105 | 106 | @MODELS.register_module() 107 | class L1Loss(nn.Module): 108 | """L1 loss. 109 | 110 | Args: 111 | reduction (str, optional): The method to reduce the loss. 112 | Options are "none", "mean" and "sum". 113 | loss_weight (float, optional): The weight of loss. 114 | """ 115 | 116 | def __init__(self, reduction='mean', loss_weight=1.0): 117 | super(L1Loss, self).__init__() 118 | self.reduction = reduction 119 | self.loss_weight = loss_weight 120 | 121 | def execute(self, 122 | pred, 123 | target, 124 | weight=None, 125 | avg_factor=None, 126 | reduction_override=None): 127 | """Forward function. 128 | 129 | Args: 130 | pred (torch.Tensor): The prediction. 131 | target (torch.Tensor): The learning target of the prediction. 132 | weight (torch.Tensor, optional): The weight of loss for each 133 | prediction. Defaults to None. 134 | avg_factor (int, optional): Average factor that is used to average 135 | the loss. Defaults to None. 136 | reduction_override (str, optional): The reduction method used to 137 | override the original reduction method of the loss. 138 | Defaults to None. 139 | """ 140 | assert reduction_override in (None, 'none', 'mean', 'sum') 141 | reduction = ( 142 | reduction_override if reduction_override else self.reduction) 143 | loss_bbox = self.loss_weight * l1_loss( 144 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 145 | return loss_bbox 146 | -------------------------------------------------------------------------------- /jittordet/models/task_utils/samplers/base_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. 2 | # mmdet/models/task_modules/samplers/base_sampler.py 3 | # Copyright (c) OpenMMLab. All rights reserved. 4 | from abc import ABCMeta, abstractmethod 5 | 6 | import jittor as jt 7 | 8 | from jittordet.structures import InstanceData 9 | from ..assigners import AssignResult 10 | from .sampling_result import SamplingResult 11 | 12 | 13 | class BaseSampler(metaclass=ABCMeta): 14 | """Base class of samplers. 15 | 16 | Args: 17 | num (int): Number of samples 18 | pos_fraction (float): Fraction of positive samples 19 | neg_pos_ub (int): Upper bound of the ratio of negative to 20 | positive samples. Defaults to -1 (no upper bound).
21 | add_gt_as_proposals (bool): Whether to add ground truth 22 | boxes as proposals. Defaults to True. 23 | """ 24 | 25 | def __init__(self, 26 | num: int, 27 | pos_fraction: float, 28 | neg_pos_ub: int = -1, 29 | add_gt_as_proposals: bool = True, 30 | **kwargs) -> None: 31 | self.num = num 32 | self.pos_fraction = pos_fraction 33 | self.neg_pos_ub = neg_pos_ub 34 | self.add_gt_as_proposals = add_gt_as_proposals 35 | self.pos_sampler = self 36 | self.neg_sampler = self 37 | 38 | @abstractmethod 39 | def _sample_pos(self, assign_result: AssignResult, num_expected: int, 40 | **kwargs): 41 | """Sample positive samples.""" 42 | pass 43 | 44 | @abstractmethod 45 | def _sample_neg(self, assign_result: AssignResult, num_expected: int, 46 | **kwargs): 47 | """Sample negative samples.""" 48 | pass 49 | 50 | def sample(self, assign_result: AssignResult, pred_instances: InstanceData, 51 | gt_instances: InstanceData, **kwargs) -> SamplingResult: 52 | """Sample positive and negative bboxes. 53 | 54 | This is a simple implementation of bbox sampling given candidates, 55 | assigning results and ground truth bboxes. 56 | 57 | Args: 58 | assign_result (:obj:`AssignResult`): Assigning results. 59 | pred_instances (:obj:`InstanceData`): Instances of model 60 | predictions. It includes ``priors``, and the priors can 61 | be anchors or points, or the bboxes predicted by the 62 | previous stage, has shape (n, 4). The bboxes predicted by 63 | the current model or stage will be named ``bboxes``, 64 | ``labels``, and ``scores``, the same as the ``InstanceData`` 65 | in other places. 66 | gt_instances (:obj:`InstanceData`): Ground truth of instance 67 | annotations. It usually includes ``bboxes``, with shape (k, 4), 68 | and ``labels``, with shape (k, ). 69 | 70 | Returns: 71 | :obj:`SamplingResult`: Sampling result. 72 | 73 | Example: 74 | >>> from mmengine.structures import InstanceData 75 | >>> from mmdet.models.task_modules.samplers import RandomSampler, 76 | >>> from mmdet.models.task_modules.assigners import AssignResult 77 | >>> from mmdet.models.task_modules.samplers. 78 | ... sampling_result import ensure_rng, random_boxes 79 | >>> rng = ensure_rng(None) 80 | >>> assign_result = AssignResult.random(rng=rng) 81 | >>> pred_instances = InstanceData() 82 | >>> pred_instances.priors = random_boxes(assign_result.num_preds, 83 | ... rng=rng) 84 | >>> gt_instances = InstanceData() 85 | >>> gt_instances.bboxes = random_boxes(assign_result.num_gts, 86 | ... rng=rng) 87 | >>> gt_instances.labels = torch.randint( 88 | ... 0, 5, (assign_result.num_gts,), dtype=torch.long) 89 | >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, 90 | >>> add_gt_as_proposals=False) 91 | >>> self = self.sample(assign_result, pred_instances, gt_instances) 92 | """ 93 | gt_bboxes = gt_instances.bboxes 94 | priors = pred_instances.priors 95 | gt_labels = gt_instances.labels 96 | if len(priors.shape) < 2: 97 | priors = priors[None, :] 98 | 99 | gt_flags = jt.zeros((priors.shape[0], ), dtype=jt.uint8) 100 | if self.add_gt_as_proposals and len(gt_bboxes) > 0: 101 | priors = jt.concat([gt_bboxes, priors], dim=0) 102 | assign_result.add_gt_(gt_labels) 103 | gt_ones = jt.ones(gt_bboxes.shape[0], dtype=jt.uint8) 104 | gt_flags = jt.concat([gt_ones, gt_flags]) 105 | 106 | num_expected_pos = int(self.num * self.pos_fraction) 107 | pos_inds = self.pos_sampler._sample_pos( 108 | assign_result, num_expected_pos, bboxes=priors, **kwargs) 109 | # We found that sampled indices have duplicated items occasionally. 
110 | # (a quirk inherited from the original PyTorch implementation) 111 | pos_inds = pos_inds.unique() 112 | num_sampled_pos = pos_inds.numel() 113 | num_expected_neg = self.num - num_sampled_pos 114 | if self.neg_pos_ub >= 0: 115 | _pos = max(1, num_sampled_pos) 116 | neg_upper_bound = int(self.neg_pos_ub * _pos) 117 | if num_expected_neg > neg_upper_bound: 118 | num_expected_neg = neg_upper_bound 119 | neg_inds = self.neg_sampler._sample_neg( 120 | assign_result, num_expected_neg, bboxes=priors, **kwargs) 121 | neg_inds = neg_inds.unique() 122 | 123 | sampling_result = SamplingResult( 124 | pos_inds=pos_inds, 125 | neg_inds=neg_inds, 126 | priors=priors, 127 | gt_bboxes=gt_bboxes, 128 | assign_result=assign_result, 129 | gt_flags=gt_flags) 130 | return sampling_result 131 | -------------------------------------------------------------------------------- /jittordet/models/losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenMMLab. mmdet/models/losses/focal_loss.py 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | import jittor.nn as nn 4 | 5 | from jittordet.engine import MODELS 6 | from .cross_entropy_loss import binary_cross_entropy_with_logits 7 | from .utils import weight_reduce_loss 8 | 9 | 10 | # This method is only for debugging 11 | def py_sigmoid_focal_loss(pred, 12 | target, 13 | weight=None, 14 | gamma=2.0, 15 | alpha=0.25, 16 | reduction='mean', 17 | avg_factor=None): 18 | """Jittor version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_. 19 | 20 | Args: 21 | pred (torch.Tensor): The prediction with shape (N, C), C is the 22 | number of classes 23 | target (torch.Tensor): The learning label of the prediction. 24 | weight (torch.Tensor, optional): Sample-wise loss weight. 25 | gamma (float, optional): The gamma for calculating the modulating 26 | factor. Defaults to 2.0. 27 | alpha (float, optional): A balanced form for Focal Loss. 28 | Defaults to 0.25. 29 | reduction (str, optional): The method used to reduce the loss into 30 | a scalar. Defaults to 'mean'. 31 | avg_factor (int, optional): Average factor that is used to average 32 | the loss. Defaults to None. 33 | """ 34 | pred_sigmoid = pred.sigmoid() 35 | target = target.astype(pred.dtype) 36 | pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target) 37 | focal_weight = (alpha * target + (1 - alpha) * 38 | (1 - target)) * pt.pow(gamma) 39 | loss = binary_cross_entropy_with_logits( 40 | pred, target, reduction='none') * focal_weight 41 | if weight is not None: 42 | if weight.shape != loss.shape: 43 | if weight.size(0) == loss.size(0): 44 | # For most cases, weight is of shape (num_priors, ), 45 | # which means it does not have the second axis num_class 46 | weight = weight.view(-1, 1) 47 | else: 48 | # Sometimes, weight per anchor per class is also needed. e.g. 49 | # in FSAF. But it may be flattened of shape 50 | # (num_priors x num_class, ), while loss is still of shape 51 | # (num_priors, num_class). 52 | assert weight.numel() == loss.numel() 53 | weight = weight.view(loss.size(0), -1) 54 | assert weight.ndim == loss.ndim 55 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 56 | return loss 57 | 58 | 59 | @MODELS.register_module() 60 | class FocalLoss(nn.Module): 61 | 62 | def __init__(self, 63 | use_sigmoid=True, 64 | gamma=2.0, 65 | alpha=0.25, 66 | reduction='mean', 67 | loss_weight=1.0): 68 | """`Focal Loss <https://arxiv.org/abs/1708.02002>`_ 69 | 70 | Args: 71 | use_sigmoid (bool, optional): Whether the prediction uses 72 | sigmoid instead of softmax. Defaults to True.
73 | gamma (float, optional): The gamma for calculating the modulating 74 | factor. Defaults to 2.0. 75 | alpha (float, optional): A balanced form for Focal Loss. 76 | Defaults to 0.25. 77 | reduction (str, optional): The method used to reduce the loss into 78 | a scalar. Defaults to 'mean'. Options are "none", "mean" and 79 | "sum". 80 | loss_weight (float, optional): Weight of loss. Defaults to 1.0. 81 | activated (bool, optional): Whether the input is activated. 82 | If True, it means the input has been activated and can be 83 | treated as probabilities. Else, it should be treated as logits. 84 | Defaults to False. 85 | """ 86 | super(FocalLoss, self).__init__() 87 | assert use_sigmoid is True, 'Only sigmoid focal loss supported now.' 88 | self.use_sigmoid = use_sigmoid 89 | self.gamma = gamma 90 | self.alpha = alpha 91 | self.reduction = reduction 92 | self.loss_weight = loss_weight 93 | 94 | def execute(self, 95 | pred, 96 | target, 97 | weight=None, 98 | avg_factor=None, 99 | reduction_override=None): 100 | """Forward function. 101 | 102 | Args: 103 | pred (torch.Tensor): The prediction. 104 | target (torch.Tensor): The learning label of the prediction. 105 | weight (torch.Tensor, optional): The weight of loss for each 106 | prediction. Defaults to None. 107 | avg_factor (int, optional): Average factor that is used to average 108 | the loss. Defaults to None. 109 | reduction_override (str, optional): The reduction method used to 110 | override the original reduction method of the loss. 111 | Options are "none", "mean" and "sum". 112 | 113 | Returns: 114 | torch.Tensor: The calculated loss 115 | """ 116 | assert reduction_override in (None, 'none', 'mean', 'sum') 117 | reduction = ( 118 | reduction_override if reduction_override else self.reduction) 119 | if self.use_sigmoid: 120 | num_classes = pred.size(1) 121 | target = nn.one_hot(target, num_classes=num_classes + 1) 122 | target = target[:, :num_classes] 123 | loss_cls = self.loss_weight * py_sigmoid_focal_loss( 124 | pred, 125 | target, 126 | weight, 127 | gamma=self.gamma, 128 | alpha=self.alpha, 129 | reduction=reduction, 130 | avg_factor=avg_factor) 131 | else: 132 | raise NotImplementedError 133 | return loss_cls 134 | -------------------------------------------------------------------------------- /jittordet/models/losses/kd_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
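# Note on the temperature T used below (an illustrative explanation):
# softmax(logits / T) with T > 1 softens both the student and teacher
# distributions, and the loss is rescaled by T * T so its gradient
# magnitude stays roughly comparable across temperatures (cf. Hinton et
# al., "Distilling the Knowledge in a Neural Network").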
2 | import jittor 3 | import jittor.nn as nn 4 | 5 | from jittordet.engine import MODELS 6 | from .cross_entropy_loss import binary_cross_entropy_with_logits 7 | from .utils import weight_reduce_loss, weighted_loss 8 | 9 | 10 | @weighted_loss 11 | def knowledge_distillation_kl_div_loss(pred, 12 | soft_label, 13 | T, 14 | class_reduction='mean', 15 | detach_target=True): 16 | assert pred.size() == soft_label.size() 17 | target = nn.softmax(soft_label / T, dim=1) 18 | if detach_target: 19 | target = target.detach() 20 | 21 | kd_loss_func = nn.KLDivLoss(reduction='none') 22 | kd_loss = kd_loss_func(nn.log_softmax(pred / T, dim=1), target) 23 | 24 | if class_reduction == 'mean': 25 | kd_loss = kd_loss.mean(1) 26 | elif class_reduction == 'sum': 27 | kd_loss = kd_loss.sum(1) 28 | else: 29 | raise NotImplementedError 30 | kd_loss = kd_loss * (T * T) 31 | return kd_loss 32 | 33 | 34 | def kd_quality_focal_loss(pred, 35 | target, 36 | weight=None, 37 | beta=1, 38 | reduction='mean', 39 | avg_factor=None): 40 | num_classes = pred.size(1) 41 | if weight is not None: 42 | weight = weight[:, None].repeat(1, num_classes) 43 | 44 | target = target.detach().sigmoid() 45 | loss = binary_cross_entropy_with_logits(pred, target, reduction='none') 46 | focal_weight = jittor.abs(pred.sigmoid() - target).pow(beta) 47 | loss = loss * focal_weight 48 | loss = weight_reduce_loss(loss, weight, reduction, avg_factor) 49 | return loss 50 | 51 | 52 | @MODELS.register_module() 53 | class KnowledgeDistillationKLDivLoss(nn.Module): 54 | """Loss function for knowledge distillation using KL divergence. 55 | 56 | Args: 57 | reduction (str): Options are `'none'`, `'mean'` and `'sum'`. 58 | loss_weight (float): Loss weight of current loss. 59 | T (int): Temperature for distillation. 60 | """ 61 | 62 | def __init__(self, 63 | class_reduction='mean', 64 | reduction='mean', 65 | loss_weight=1.0, 66 | T=10): 67 | super(KnowledgeDistillationKLDivLoss, self).__init__() 68 | assert T >= 1 69 | self.class_reduction = class_reduction 70 | self.reduction = reduction 71 | self.loss_weight = loss_weight 72 | self.T = T 73 | 74 | def execute(self, 75 | pred, 76 | soft_label, 77 | weight=None, 78 | avg_factor=None, 79 | reduction_override=None): 80 | """Forward function. 81 | 82 | Args: 83 | pred (Tensor): Predicted logits with shape (N, n + 1). 84 | soft_label (Tensor): Target logits with shape (N, n + 1). 85 | weight (torch.Tensor, optional): The weight of loss for each 86 | prediction. Defaults to None. 87 | avg_factor (int, optional): Average factor that is used to average 88 | the loss. Defaults to None. 89 | reduction_override (str, optional): The reduction method used to 90 | override the original reduction method of the loss. 91 | Defaults to None. 92 | """ 93 | assert reduction_override in (None, 'none', 'mean', 'sum') 94 | 95 | reduction = ( 96 | reduction_override if reduction_override else self.reduction) 97 | 98 | loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss( 99 | pred, 100 | soft_label, 101 | weight, 102 | class_reduction=self.class_reduction, 103 | reduction=reduction, 104 | avg_factor=avg_factor, 105 | T=self.T) 106 | 107 | return loss_kd 108 | 109 | 110 | @MODELS.register_module() 111 | class KDQualityFocalLoss(nn.Module): 112 | 113 | def __init__(self, 114 | use_sigmoid=True, 115 | beta=1.0, 116 | reduction='mean', 117 | loss_weight=1.0): 118 | super(KDQualityFocalLoss, self).__init__() 119 | assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
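        # Illustrative note: kd_quality_focal_loss above treats the
        # teacher's detached sigmoid scores as soft targets and re-weights
        # the per-class BCE by |sigmoid(pred) - target| ** beta, i.e. a
        # soft-target variant of quality focal loss.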
120 | self.use_sigmoid = use_sigmoid 121 | self.beta = beta 122 | self.reduction = reduction 123 | self.loss_weight = loss_weight 124 | 125 | def execute(self, 126 | pred, 127 | target, 128 | weight=None, 129 | avg_factor=None, 130 | reduction_override=None): 131 | """Forward function. 132 | 133 | Args: 134 | pred (torch.Tensor): Predicted joint representation of 135 | classification and quality (IoU) estimation with shape (N, C), 136 | C is the number of classes. 137 | target (tuple([torch.Tensor])): Target category label with shape 138 | (N,) and target quality label with shape (N,). 139 | weight (torch.Tensor, optional): The weight of loss for each 140 | prediction. Defaults to None. 141 | avg_factor (int, optional): Average factor that is used to average 142 | the loss. Defaults to None. 143 | reduction_override (str, optional): The reduction method used to 144 | override the original reduction method of the loss. 145 | Defaults to None. 146 | """ 147 | assert reduction_override in (None, 'none', 'mean', 'sum') 148 | reduction = ( 149 | reduction_override if reduction_override else self.reduction) 150 | if self.use_sigmoid: 151 | loss = self.loss_weight * kd_quality_focal_loss( 152 | pred, 153 | target, 154 | weight, 155 | beta=self.beta, 156 | reduction=reduction, 157 | avg_factor=avg_factor) 158 | else: 159 | raise NotImplementedError 160 | return loss 161 | -------------------------------------------------------------------------------- /jittordet/datasets/coco.py: -------------------------------------------------------------------------------- 1 | # Modified from mmdetection.dataset.coco 2 | import copy 3 | import os.path as osp 4 | 5 | from pycocotools.coco import COCO 6 | 7 | from ..engine import DATASETS 8 | from .base import BaseDetDataset 9 | 10 | 11 | @DATASETS.register_module() 12 | class CocoDataset(BaseDetDataset): 13 | """Dataset for COCO.""" 14 | 15 | METAINFO = { 16 | 'classes': 17 | ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 18 | 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 19 | 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 20 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 21 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 22 | 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 23 | 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 24 | 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 25 | 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 26 | 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 27 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 28 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 29 | 'scissors', 'teddy bear', 'hair drier', 'toothbrush'), 30 | # palette is a list of color tuples, which is used for visualization. 
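        # (one (R, G, B) tuple per name in 'classes' above, in the same
        # order, following the mmdetection convention)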
31 | 'palette': 32 | [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), 33 | (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), 34 | (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), 35 | (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), 36 | (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), 37 | (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), 38 | (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), 39 | (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), 40 | (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), 41 | (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), 42 | (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), 43 | (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), 44 | (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), 45 | (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), 46 | (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), 47 | (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), 48 | (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), 49 | (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), 50 | (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), 51 | (246, 0, 122), (191, 162, 208)] 52 | } 53 | 54 | def load_data_list(self): 55 | self.coco = COCO(self.ann_file) 56 | # The order of returned `cat_ids` will not 57 | # change with the order of the `classes` 58 | self.cat_ids = self.coco.getCatIds(self.metainfo['classes']) 59 | self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} 60 | self.cat_img_map = copy.deepcopy(self.coco.catToImgs) 61 | 62 | img_ids = self.coco.getImgIds() 63 | data_list = [] 64 | for img_id in img_ids: 65 | raw_img_info = self.coco.loadImgs([img_id])[0] 66 | raw_img_info['img_id'] = img_id 67 | 68 | ann_ids = self.coco.getAnnIds(imgIds=[img_id]) 69 | raw_ann_info = self.coco.loadAnns(ann_ids) 70 | 71 | parsed_data_info = self.parse_data_info(raw_ann_info, raw_img_info) 72 | data_list.append(parsed_data_info) 73 | 74 | del self.coco 75 | 76 | return data_list 77 | 78 | def parse_data_info(self, ann_info, img_info): 79 | data_info = {} 80 | 81 | img_path = osp.join(self.img_path, img_info['file_name']) 82 | data_info['img_path'] = img_path 83 | data_info['img_id'] = img_info['img_id'] 84 | data_info['height'] = img_info['height'] 85 | data_info['width'] = img_info['width'] 86 | 87 | instances = [] 88 | for i, ann in enumerate(ann_info): 89 | instance = {} 90 | 91 | if ann.get('ignore', False): 92 | continue 93 | x1, y1, w, h = ann['bbox'] 94 | inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0)) 95 | inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0)) 96 | if inter_w * inter_h == 0: 97 | continue 98 | if ann['area'] <= 0 or w < 1 or h < 1: 99 | continue 100 | if ann['category_id'] not in self.cat_ids: 101 | continue 102 | bbox = [x1, y1, x1 + w, y1 + h] 103 | 104 | if ann.get('iscrowd', False): 105 | instance['ignore_flag'] = 1 106 | else: 107 | instance['ignore_flag'] = 0 108 | instance['bbox'] = bbox 109 | instance['bbox_label'] = self.cat2label[ann['category_id']] 110 | 111 | if ann.get('segmentation', None): 112 | instance['mask'] = ann['segmentation'] 113 | 114 | instances.append(instance) 115 | data_info['instances'] = instances 116 | return data_info 117 | 118 | def filter_data(self): 119 | if self.test_mode: 120 | return self.data_list 121 | 122 | if self.filter_cfg is None: 123 | return 
self.data_list 124 | 125 | filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) 126 | min_size = self.filter_cfg.get('min_size', 0) 127 | 128 | # obtain images that contain annotation 129 | ids_with_ann = set(data_info['img_id'] for data_info in self.data_list) 130 | # obtain images that contain annotations of the required categories 131 | ids_in_cat = set() 132 | for i, class_id in enumerate(self.cat_ids): 133 | ids_in_cat |= set(self.cat_img_map[class_id]) 134 | # merge the image id sets of the two conditions and use the merged set 135 | # to filter out images if self.filter_empty_gt=True 136 | ids_in_cat &= ids_with_ann 137 | 138 | valid_data_infos = [] 139 | for i, data_info in enumerate(self.data_list): 140 | img_id = data_info['img_id'] 141 | width = data_info['width'] 142 | height = data_info['height'] 143 | if filter_empty_gt and img_id not in ids_in_cat: 144 | continue 145 | if min(width, height) >= min_size: 146 | valid_data_infos.append(data_info) 147 | 148 | return valid_data_infos 149 | --------------------------------------------------------------------------------
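A minimal end-to-end sketch of how the registry pieces above fit together. The `MODELS.build` call and the `RetinaHead` arguments are taken from the files above; the concrete config values, and the assumption that `AnchorHead`'s remaining arguments have usable defaults, are illustrative only:

from jittordet.engine import MODELS

# Build a dense head from a plain dict config, mirroring how
# SingleStageFramework builds its backbone, neck and bbox_head.
head_cfg = dict(
    type='RetinaHead',
    num_classes=20,   # e.g. the 20 VOC classes
    in_channels=256,  # FPN output channels
    stacked_convs=4,
)
retina_head = MODELS.build(head_cfg)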